From 92c2423d36b4a82d216fbc47f55f7e3435450a21 Mon Sep 17 00:00:00 2001
From: danielmasny
Date: Thu, 5 Sep 2024 14:52:20 -0700
Subject: [PATCH 001/191] outline for malicious shuffle

---
 .../src/protocol/ipa_prf/shuffle/malicious.rs | 81 +++++++++++++++++++
 ipa-core/src/protocol/ipa_prf/shuffle/mod.rs  |  2 +
 ipa-core/src/protocol/ipa_prf/shuffle/step.rs |  3 +
 3 files changed, 86 insertions(+)
 create mode 100644 ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs

diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs
new file mode 100644
index 000000000..a0dc6ef0c
--- /dev/null
+++ b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs
@@ -0,0 +1,81 @@
+use std::ops::Add;
+
+use proptest::num::usize;
+use rand::distributions::{Distribution, Standard};
+
+use crate::{
+    error::Error,
+    ff::{boolean_array::BooleanArray, Gf32Bit},
+    protocol::{
+        context::Context,
+        ipa_prf::shuffle::{base::shuffle, step::OPRFShuffleStep},
+        prss::SharedRandomness,
+        RecordId,
+    },
+    secret_sharing::replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing},
+};
+
+/// This function executes the maliciously secure shuffle protocol on the input: `shares`.
+///
+/// ## Errors
+/// Propagates network, multiplication and conversion errors from sub functions.
+pub async fn malicious_shuffle<C, I, S>(ctx: C, shares: I) -> Result<Vec<AdditiveShare<S>>, Error>
+where
+    C: Context,
+    I: IntoIterator<Item = AdditiveShare<S>>,
+    I::IntoIter: ExactSizeIterator,
+    S: BooleanArray,
+    for<'a> &'a S: Add<S, Output = S>,
+    for<'a> &'a S: Add<&'a S, Output = S>,
+    Standard: Distribution<S>,
+{
+    // compute amount of MAC keys
+    let amount_of_keys: usize = usize::try_from(S::BITS).unwrap() + 31 / 32;
+    // generate MAC keys
+    let keys = (0..amount_of_keys)
+        .map(|i| ctx.prss().generate_fields(RecordId::from(i)))
+        .map(|(left, right)| AdditiveShare::new(left, right))
+        .collect::<Vec<AdditiveShare<Gf32Bit>>>();
+
+    // call
+    // async fn compute_tags(
+    //     ctx: C,
+    //     keys: &[AdditiveShare<Gf32Bit>],
+    //     rows: &[AdditiveShare<S>],
+    // ) -> Result<Vec<Vec<AdditiveShare<Gf32Bit>>>, Error>
+    //
+    // i.e. let shares_and_tags = compute_tags(ctx.narrow(&OPRFShuffleStep::GenerateTags, keys, shares).await?
+    // placeholder
+    let shares_and_tags =
+        vec![vec![AdditiveShare::<Gf32Bit>::ZERO; amount_of_keys + 1]; shares.into_iter().len()];
+
+    // call
+    // pub async fn shuffle(
+    //     ctx: C,
+    //     shares: I,
+    // ) -> Result<(Vec<AdditiveShare<S>>, IntermediateShuffleMessages<S>), Error>
+    //
+    // i.e. let (output_shares, messages) = shuffle(ctx.narrow(&OPRFShuffleStep::ShuffleProtocol, shares_and_tags).await?
+    // placeholder
+    let output_shares = shuffle(
+        ctx.narrow(&OPRFShuffleStep::ShuffleProtocol),
+        shares_and_tags,
+    )
+    .await?;
+
+    // call
+    // async fn verify_shuffle(
+    //     ctx: C,
+    //     key_shares: &[AdditiveShare<Gf32Bit>],
+    //     shuffled_shares: &[AdditiveShare<S>],
+    //     messages: IntermediateShuffleMessages<S>,
+    // ) -> Result<(), Error>
+    //
+    // i.e. verify_shuffle(ctx.narrow(&OPRFShuffleStep::VerifyShuffle), keys, output_shares, messages).await?
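+    //
+    // A sketch of the MAC this outline implements (names taken from the
+    // comments above): each row is converted into 32-bit chunks r_1, ..., r_n
+    // over `Gf32Bit` and extended with the tag
+    //
+    //     t = k_1 * r_1 + ... + k_n * r_n,
+    //
+    // where k_1, ..., k_n are the PRSS-generated MAC keys. The tag travels
+    // with its row through the shuffle, so `verify_shuffle` can hash the
+    // tagged rows each helper observed and compare hashes; an additive
+    // modification of any row invalidates its tag with overwhelming
+    // probability.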
+ + // truncate tags from output_shares + // create function to do this + + // placeholder + Ok(vec![AdditiveShare::ZERO; 1]) +} diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs b/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs index e0868fb96..555a06e72 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs @@ -22,6 +22,8 @@ use crate::{ }; pub mod base; +#[allow(dead_code)] +pub mod malicious; #[cfg(descriptive_gate)] mod sharded; pub(crate) mod step; diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/step.rs b/ipa-core/src/protocol/ipa_prf/shuffle/step.rs index e5014d775..4fe86990a 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/step.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/step.rs @@ -9,4 +9,7 @@ pub(crate) enum OPRFShuffleStep { TransferCHat, TransferX2, TransferY1, + GenerateTags, + ShuffleProtocol, + VerifyShuffle, } From d73552cb2e320e7b6ba7770dd93dc9929137af5f Mon Sep 17 00:00:00 2001 From: danielmasny Date: Fri, 6 Sep 2024 19:17:42 -0700 Subject: [PATCH 002/191] malicious shuffle + stalling test, probably issue with seq_join(parallel_join) --- .../src/protocol/ipa_prf/shuffle/malicious.rs | 225 +++++++++++++----- 1 file changed, 170 insertions(+), 55 deletions(-) diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs index dd2dfb4b3..e3517858c 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs @@ -34,67 +34,84 @@ use crate::{ /// /// ## Errors /// Propagates network, multiplication and conversion errors from sub functions. -pub async fn malicious_shuffle(ctx: C, shares: I) -> Result>, Error> +/// +/// ## Panics +/// Panics when `S::Bits + 32 != B::Bits` or type conversions fail. +pub async fn malicious_shuffle( + ctx: C, + shares: I, +) -> Result>, Error> where C: Context, + S: BooleanArray, + B: BooleanArray, I: IntoIterator>, I::IntoIter: ExactSizeIterator, - S: BooleanArray, - for<'a> &'a S: Add, - for<'a> &'a S: Add<&'a S, Output = S>, - Standard: Distribution, + ::IntoIter: Send, + for<'a> &'a B: Add, + for<'a> &'a B: Add<&'a B, Output = B>, + Standard: Distribution, { // compute amount of MAC keys let amount_of_keys: usize = usize::try_from(S::BITS).unwrap() + 31 / 32; - // generate MAC keys - let keys = (0..amount_of_keys) - .map(|i| ctx.prss().generate_fields(RecordId::from(i))) - .map(|(left, right)| AdditiveShare::new(left, right)) - .collect::>>(); - - // call - // async fn compute_tags( - // ctx: C, - // keys: &[AdditiveShare], - // rows: &[AdditiveShare], - // ) -> Result>, Error> + // // generate MAC keys + let keys = vec![AdditiveShare::ZERO; amount_of_keys]; + // = (0..amount_of_keys) + // .map(|i| ctx.prss().generate_fields(RecordId::from(i))) + // .map(|(left, right)| AdditiveShare::new(left, right)) + // .collect::>>(); + + // compute and append tags to rows + let shares_and_tags: Vec> = + compute_and_add_tags(ctx.narrow(&OPRFShuffleStep::GenerateTags), &keys, shares).await?; + + // // shuffle + // let (shuffled_shares, messages) = shuffle( + // ctx.narrow(&OPRFShuffleStep::ShuffleProtocol), + // shares_and_tags, + // ) + // .await?; // - // i.e. let shares_and_tags = compute_tags(ctx.narrow(&OPRFShuffleStep::GenerateTags, keys, shares).await? 
- // placeholder - let shares_and_tags = - vec![vec![AdditiveShare::::ZERO; amount_of_keys + 1]; shares.into_iter().len()]; - - // call - // pub async fn shuffle( - // ctx: C, - // shares: I, - // ) -> Result<(Vec>, IntermediateShuffleMessages), Error> - // - // i.e. let (output_shares, messages) = shuffle(ctx.narrow(&OPRFShuffleStep::ShuffleProtocol, shares_and_tags).await? - // placeholder - let output_shares = shuffle( - ctx.narrow(&OPRFShuffleStep::ShuffleProtocol), - shares_and_tags, - ) - .await?; - - // call - // async fn verify_shuffle( - // ctx: C, - // key_shares: &[AdditiveShare], - // shuffled_shares: &[AdditiveShare], - // messages: IntermediateShuffleMessages, - // ) -> Result<(), Error> + // // verify the shuffle + // verify_shuffle( + // ctx.narrow(&OPRFShuffleStep::VerifyShuffle), + // &keys, + // &shuffled_shares, + // messages, + // ) + // .await?; // - // i.e. verify_shuffle(ctx.narrow(&OPRFShuffleStep::VerifyShuffle), keys, output_shares, messages).await? + // // truncate tags from output_shares + // Ok(truncate_tags(&shuffled_shares)) - // truncate tags from output_shares - // create function to do this - - // placeholder Ok(vec![AdditiveShare::ZERO; 1]) } +/// This function truncates the tags from the output shares of the shuffle protocol +/// +/// ## Panics +/// Panics when `S::Bits > B::Bits`. +fn truncate_tags(shares_and_tags: &[AdditiveShare]) -> Vec> +where + S: BooleanArray, + B: BooleanArray, +{ + let tag_offset = usize::try_from((S::BITS + 7) / 8).unwrap(); + shares_and_tags + .into_iter() + .map(|row_with_tag| { + let mut buf_left = GenericArray::default(); + let mut buf_right = GenericArray::default(); + row_with_tag.left().serialize(&mut buf_left); + row_with_tag.right().serialize(&mut buf_right); + AdditiveShare::new( + S::deserialize(GenericArray::from_slice(&buf_left[0..tag_offset])).unwrap(), + S::deserialize(GenericArray::from_slice(&buf_right[0..tag_offset])).unwrap(), + ) + }) + .collect() +} + /// This function verifies the `shuffled_shares` and the `IntermediateShuffleMessages`. /// /// ## Errors @@ -350,19 +367,28 @@ async fn reveal_keys( /// ## Panics /// When conversion fails, when `S::Bits + 32 != B::Bits` /// or when `rows` is empty or elements in `rows` have length `0`. 
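 ///
 /// As an illustrative example (not part of this patch): a `BA64` row converts
 /// to two `Gf32Bit` chunks `(r1, r2)`, and with keys `(k1, k2)` the function
 /// outputs that row extended by the 32-bit tag `k1 * r1 + k2 * r2`.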
-async fn compute_and_add_tags( +async fn compute_and_add_tags( ctx: C, keys: &[AdditiveShare], - rows: &[AdditiveShare], -) -> Result>, Error> { - let length = rows.len(); + rows: I, +) -> Result>, Error> +where + C: Context, + S: BooleanArray, + B: BooleanArray, + I: IntoIterator>, + I::IntoIter: ExactSizeIterator, + ::IntoIter: Send, +{ + let row_iterator = rows.into_iter(); + let length = row_iterator.len(); let row_length = keys.len(); // make sure total records is not 0 debug_assert!(length * row_length != 0); let tag_ctx = ctx.set_total_records(TotalRecords::specified(length * row_length)?); let p_ctx = &tag_ctx; - let futures = rows.iter().enumerate().map(|(i, row)| async move { + let futures = row_iterator.enumerate().map(|(i, row)| async move { let row_entries_iterator = row.to_gf32bit()?; // compute tags via inner product between row and keys let row_tag = p_ctx @@ -381,7 +407,7 @@ async fn compute_and_add_tags( .iter() .fold(AdditiveShare::::ZERO, |acc, x| acc + x); // combine row and row_tag - Ok::, Error>(concatenate_row_and_tag::(row, &row_tag)) + Ok::, Error>(concatenate_row_and_tag::(&row, &row_tag)) }); seq_join(ctx.active_work(), iter(futures)) @@ -428,6 +454,95 @@ mod tests { test_fixture::{Reconstruct, Runner, TestWorld}, }; + pub async fn wrapper(ctx: C, shares: I) + where + C: Context, + S: BooleanArray, + B: BooleanArray, + I: IntoIterator>, + I::IntoIter: ExactSizeIterator, + ::IntoIter: Send, + for<'a> &'a B: Add, + for<'a> &'a B: Add<&'a B, Output = B>, + Standard: Distribution, + { + // compute amount of MAC keys + let amount_of_keys: usize = usize::try_from(S::BITS).unwrap() + 31 / 32; + // // generate MAC keys + let keys = vec![AdditiveShare::ZERO; amount_of_keys]; + + // compute and append tags to rows + let shares_and_tags: Vec> = + compute_and_add_tags(ctx.narrow(&OPRFShuffleStep::GenerateTags), &keys, shares) + .await + .unwrap(); + } + + #[test] + fn minimal_stall() { + const RECORD_AMOUNT: usize = 1; + run(|| async { + let world = TestWorld::default(); + let mut rng = thread_rng(); + let records = (0..RECORD_AMOUNT) + .map(|_| rng.gen::()) + .collect::>(); + + world + .semi_honest(records.into_iter(), |ctx, (row_shares)| async move { + wrapper::<_, BA32, BA64, _>(ctx, row_shares).await; + }) + .await; + }); + } + + /// This test checks the correctness of the malicious shuffle. + /// It does not check the security against malicious behavior. + #[test] + fn check_shuffle_correctness() { + const RECORD_AMOUNT: usize = 10; + run(|| async { + let world = TestWorld::default(); + let mut rng = thread_rng(); + // using Gf32Bit here since it implements cmp such that vec can later be sorted + let mut records = (0..RECORD_AMOUNT) + .map(|_| rng.gen()) + .collect::>(); + + let records_boolean_array = records + .iter() + .map(|row| { + let mut buf = GenericArray::default(); + row.serialize(&mut buf); + BA32::deserialize(&buf).unwrap() + }) + .collect::>(); + + let result = world + .semi_honest( + records_boolean_array.into_iter(), + |ctx, records| async move { + malicious_shuffle::<_, BA32, BA64, _>(ctx, records) + .await + .unwrap() + }, + ) + .await + .reconstruct(); + + let mut result_galois = result + .iter() + .map(|row| { + let mut buf = GenericArray::default(); + row.serialize(&mut buf); + Gf32Bit::deserialize(&buf).unwrap() + }) + .collect::>(); + + assert_eq!(records.sort(), result_galois.sort()); + }); + } + /// This test checks the correctness of the malicious shuffle /// when all parties behave honestly /// and all the MAC keys are `Gf32Bit::ONE`. 
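     ///
     /// (With every MAC key fixed to `Gf32Bit::ONE`, a row's tag degenerates
     /// to the XOR of its 32-bit chunks, since multiplication by one is the
     /// identity and addition in `Gf32Bit` is XOR, which keeps the expected
     /// tags easy to compute by hand.)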
@@ -557,7 +672,7 @@ mod tests { // convert key let mac_key: Vec> = key_shares.to_gf32bit().unwrap().collect::>(); - compute_and_add_tags(ctx, &mac_key, &row_shares) + compute_and_add_tags(ctx, &mac_key, row_shares) .await .unwrap() }, From 64145b80eeb21abce3fadca2406694fcc23b6143 Mon Sep 17 00:00:00 2001 From: danielmasny Date: Mon, 9 Sep 2024 11:21:40 -0700 Subject: [PATCH 003/191] fix issue by setting row length incorrectly within malicious shuffle (thanks to Andy) --- .../src/protocol/ipa_prf/shuffle/malicious.rs | 109 ++++++------------ 1 file changed, 35 insertions(+), 74 deletions(-) diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs index e3517858c..cfd044fe7 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs @@ -6,7 +6,6 @@ use futures_util::{ stream::iter, }; use generic_array::GenericArray; -use proptest::num::usize; use rand::distributions::{Distribution, Standard}; use crate::{ @@ -53,38 +52,35 @@ where Standard: Distribution, { // compute amount of MAC keys - let amount_of_keys: usize = usize::try_from(S::BITS).unwrap() + 31 / 32; + let amount_of_keys: usize = (usize::try_from(S::BITS).unwrap() + 31) / 32; // // generate MAC keys - let keys = vec![AdditiveShare::ZERO; amount_of_keys]; - // = (0..amount_of_keys) - // .map(|i| ctx.prss().generate_fields(RecordId::from(i))) - // .map(|(left, right)| AdditiveShare::new(left, right)) - // .collect::>>(); + let keys = (0..amount_of_keys) + .map(|i| ctx.prss().generate_fields(RecordId::from(i))) + .map(|(left, right)| AdditiveShare::new(left, right)) + .collect::>>(); // compute and append tags to rows let shares_and_tags: Vec> = compute_and_add_tags(ctx.narrow(&OPRFShuffleStep::GenerateTags), &keys, shares).await?; - // // shuffle - // let (shuffled_shares, messages) = shuffle( - // ctx.narrow(&OPRFShuffleStep::ShuffleProtocol), - // shares_and_tags, - // ) - // .await?; - // - // // verify the shuffle - // verify_shuffle( - // ctx.narrow(&OPRFShuffleStep::VerifyShuffle), - // &keys, - // &shuffled_shares, - // messages, - // ) - // .await?; - // - // // truncate tags from output_shares - // Ok(truncate_tags(&shuffled_shares)) - - Ok(vec![AdditiveShare::ZERO; 1]) + // shuffle + let (shuffled_shares, messages) = shuffle( + ctx.narrow(&OPRFShuffleStep::ShuffleProtocol), + shares_and_tags, + ) + .await?; + + // verify the shuffle + verify_shuffle( + ctx.narrow(&OPRFShuffleStep::VerifyShuffle), + &keys, + &shuffled_shares, + messages, + ) + .await?; + + // truncate tags from output_shares + Ok(truncate_tags(&shuffled_shares)) } /// This function truncates the tags from the output shares of the shuffle protocol @@ -98,7 +94,7 @@ where { let tag_offset = usize::try_from((S::BITS + 7) / 8).unwrap(); shares_and_tags - .into_iter() + .iter() .map(|row_with_tag| { let mut buf_left = GenericArray::default(); let mut buf_right = GenericArray::default(); @@ -454,48 +450,6 @@ mod tests { test_fixture::{Reconstruct, Runner, TestWorld}, }; - pub async fn wrapper(ctx: C, shares: I) - where - C: Context, - S: BooleanArray, - B: BooleanArray, - I: IntoIterator>, - I::IntoIter: ExactSizeIterator, - ::IntoIter: Send, - for<'a> &'a B: Add, - for<'a> &'a B: Add<&'a B, Output = B>, - Standard: Distribution, - { - // compute amount of MAC keys - let amount_of_keys: usize = usize::try_from(S::BITS).unwrap() + 31 / 32; - // // generate MAC keys - let keys = vec![AdditiveShare::ZERO; amount_of_keys]; - - // compute and append 
tags to rows - let shares_and_tags: Vec> = - compute_and_add_tags(ctx.narrow(&OPRFShuffleStep::GenerateTags), &keys, shares) - .await - .unwrap(); - } - - #[test] - fn minimal_stall() { - const RECORD_AMOUNT: usize = 1; - run(|| async { - let world = TestWorld::default(); - let mut rng = thread_rng(); - let records = (0..RECORD_AMOUNT) - .map(|_| rng.gen::()) - .collect::>(); - - world - .semi_honest(records.into_iter(), |ctx, (row_shares)| async move { - wrapper::<_, BA32, BA64, _>(ctx, row_shares).await; - }) - .await; - }); - } - /// This test checks the correctness of the malicious shuffle. /// It does not check the security against malicious behavior. #[test] @@ -539,7 +493,10 @@ mod tests { }) .collect::>(); - assert_eq!(records.sort(), result_galois.sort()); + records.sort(); + result_galois.sort(); + + assert_eq!(records, result_galois); }); } @@ -672,9 +629,13 @@ mod tests { // convert key let mac_key: Vec> = key_shares.to_gf32bit().unwrap().collect::>(); - compute_and_add_tags(ctx, &mac_key, row_shares) - .await - .unwrap() + compute_and_add_tags( + ctx.narrow(&OPRFShuffleStep::GenerateTags), + &mac_key, + row_shares, + ) + .await + .unwrap() }, ) .await From 2ee5d35aef8d4ec6fb344543f861c8cc60dc7c62 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Mon, 9 Sep 2024 16:29:59 -0700 Subject: [PATCH 004/191] Remove some console spam --- ipa-core/src/cli/playbook/mod.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/ipa-core/src/cli/playbook/mod.rs b/ipa-core/src/cli/playbook/mod.rs index df43911d8..74bf9484d 100644 --- a/ipa-core/src/cli/playbook/mod.rs +++ b/ipa-core/src/cli/playbook/mod.rs @@ -146,7 +146,6 @@ pub fn validate_dp( } else { next_actual_f64 }; - println!("next_actual_f64 = {next_actual_f64}, next_actual_f64_shifted = {next_actual_f64_shifted}"); let (_, std) = truncated_discrete_laplace.mean_and_std(); let tolerance_factor = 20.0; // set so this fails randomly with small probability From 2b2f51715890e045ce423c447db4aafdb946ffe3 Mon Sep 17 00:00:00 2001 From: danielmasny Date: Mon, 9 Sep 2024 21:41:10 -0700 Subject: [PATCH 005/191] test when party is malicious succeed, i.e. 
detect malicious behavior + fix bug in compute row hash for BAs that are not multiples of 32 bits --- .../src/protocol/ipa_prf/shuffle/malicious.rs | 301 +++++++++++++++--- 1 file changed, 260 insertions(+), 41 deletions(-) diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs index cfd044fe7..dae440ea8 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs @@ -71,7 +71,7 @@ where .await?; // verify the shuffle - verify_shuffle( + verify_shuffle::<_, S, B>( ctx.narrow(&OPRFShuffleStep::VerifyShuffle), &keys, &shuffled_shares, @@ -96,28 +96,43 @@ where shares_and_tags .iter() .map(|row_with_tag| { - let mut buf_left = GenericArray::default(); - let mut buf_right = GenericArray::default(); - row_with_tag.left().serialize(&mut buf_left); - row_with_tag.right().serialize(&mut buf_right); AdditiveShare::new( - S::deserialize(GenericArray::from_slice(&buf_left[0..tag_offset])).unwrap(), - S::deserialize(GenericArray::from_slice(&buf_right[0..tag_offset])).unwrap(), + split_row_and_tag(row_with_tag.left(), tag_offset).0, + split_row_and_tag(row_with_tag.right(), tag_offset).0, ) }) .collect() } +/// This function splits a row with tag into +/// a row without tag and a tag +/// +/// ## Panics +/// Panics when the lengths are incorrect: +/// `S` in bytes needs to be equal to `tag_offset`. +/// `B` in bytes needs to be equal to `tag_offset + 4`. +fn split_row_and_tag( + row_with_tag: B, + tag_offset: usize, +) -> (S, Gf32Bit) { + let mut buf = GenericArray::default(); + row_with_tag.serialize(&mut buf); + ( + S::deserialize(GenericArray::from_slice(&buf.as_slice()[0..tag_offset])).unwrap(), + Gf32Bit::deserialize(GenericArray::from_slice(&buf.as_slice()[tag_offset..])).unwrap(), + ) +} + /// This function verifies the `shuffled_shares` and the `IntermediateShuffleMessages`. /// /// ## Errors /// Propagates network errors. /// Further, returns an error when messages are inconsistent with the MAC tags. -async fn verify_shuffle( +async fn verify_shuffle( ctx: C, key_shares: &[AdditiveShare], - shuffled_shares: &[AdditiveShare], - messages: IntermediateShuffleMessages, + shuffled_shares: &[AdditiveShare], + messages: IntermediateShuffleMessages, ) -> Result<(), Error> { // reveal keys let k_ctx = ctx @@ -127,11 +142,15 @@ async fn verify_shuffle( // verify messages and shares match ctx.role() { - Role::H1 => h1_verify(ctx, &keys, shuffled_shares, messages.get_x1_or_y1()).await, - Role::H2 => h2_verify(ctx, &keys, shuffled_shares, messages.get_x2_or_y2()).await, + Role::H1 => { + h1_verify::<_, S, B>(ctx, &keys, shuffled_shares, messages.get_x1_or_y1()).await + } + Role::H2 => { + h2_verify::<_, S, B>(ctx, &keys, shuffled_shares, messages.get_x2_or_y2()).await + } Role::H3 => { let (y1, y2) = messages.get_both_x_or_ys(); - h3_verify(ctx, &keys, shuffled_shares, y1, y2).await + h3_verify::<_, S, B>(ctx, &keys, shuffled_shares, y1, y2).await } } } @@ -145,17 +164,17 @@ async fn verify_shuffle( /// Propagates network errors. Further it returns an error when /// `hash_x1 != hash_y1` or `hash_c_h2 != hash_a_xor_b` /// or `hash_c_h3 != hash_a_xor_b`. 
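 ///
 /// Here `hash_x1` and `hash_a_xor_b` are computed locally via
 /// `compute_row_hash`, while `hash_y1`, `hash_c_h2` and `hash_c_h3` arrive
 /// from the other helpers; together these equalities bind H1's view of the
 /// shuffle to the views of H2 and H3.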
-async fn h1_verify( +async fn h1_verify( ctx: C, keys: &[StdArray], - share_a_and_b: &[AdditiveShare], - x1: Vec, + share_a_and_b: &[AdditiveShare], + x1: Vec, ) -> Result<(), Error> { // compute hashes // compute hash for x1 - let hash_x1 = compute_row_hash(keys, x1); + let hash_x1 = compute_row_hash::(keys, x1); // compute hash for A xor B - let hash_a_xor_b = compute_row_hash( + let hash_a_xor_b = compute_row_hash::( keys, share_a_and_b .iter() @@ -212,17 +231,17 @@ async fn h1_verify( /// ## Errors /// Propagates network errors. Further it returns an error when /// `hash_x2 != hash_y2`. -async fn h2_verify( +async fn h2_verify( ctx: C, keys: &[StdArray], - share_b_and_c: &[AdditiveShare], - x2: Vec, + share_b_and_c: &[AdditiveShare], + x2: Vec, ) -> Result<(), Error> { // compute hashes // compute hash for x2 - let hash_x2 = compute_row_hash(keys, x2); + let hash_x2 = compute_row_hash::(keys, x2); // compute hash for C - let hash_c = compute_row_hash( + let hash_c = compute_row_hash::( keys, share_b_and_c.iter().map(ReplicatedSecretSharing::right), ); @@ -260,20 +279,20 @@ async fn h2_verify( /// /// ## Errors /// Propagates network errors. -async fn h3_verify( +async fn h3_verify( ctx: C, keys: &[StdArray], - share_c_and_a: &[AdditiveShare], - y1: Vec, - y2: Vec, + share_c_and_a: &[AdditiveShare], + y1: Vec, + y2: Vec, ) -> Result<(), Error> { // compute hashes // compute hash for y1 - let hash_y1 = compute_row_hash(keys, y1); + let hash_y1 = compute_row_hash::(keys, y1); // compute hash for y2 - let hash_y2 = compute_row_hash(keys, y2); + let hash_y2 = compute_row_hash::(keys, y2); // compute hash for C - let hash_c = compute_row_hash( + let hash_c = compute_row_hash::( keys, share_c_and_a.iter().map(ReplicatedSecretSharing::left), ); @@ -304,16 +323,23 @@ async fn h3_verify( /// /// ## Panics /// Panics when conversion from `BooleanArray` to `Vec(keys: &[StdArray], row_iterator: I) -> Hash +fn compute_row_hash(keys: &[StdArray], row_iterator: I) -> Hash where S: BooleanArray, - I: IntoIterator, + B: BooleanArray, + I: IntoIterator, { - let iterator = row_iterator - .into_iter() - .map(|row| >>::try_into(row).unwrap()); - compute_hash(iterator.map(|row| { - row.into_iter() + let tag_offset = usize::try_from((B::BITS + 7) / 8).unwrap() - 4; + + let iterator = row_iterator.into_iter().map(|row_with_tag| { + let (row, tag) = split_row_and_tag(row_with_tag, tag_offset); + >>::try_into(row) + .unwrap() + .into_iter() + .chain(iter::once(tag)) + }); + compute_hash(iterator.map(|row_iterator| { + row_iterator .zip(keys) .fold(Gf32Bit::ZERO, |acc, (row_entry, key)| { acc + row_entry * *key.first() @@ -444,12 +470,64 @@ mod tests { boolean_array::{BA112, BA144, BA20, BA32, BA64}, Serializable, }, + helpers::in_memory_config::{MaliciousHelper, MaliciousHelperContext}, protocol::ipa_prf::shuffle::base::shuffle, secret_sharing::SharedValue, test_executor::run, - test_fixture::{Reconstruct, Runner, TestWorld}, + test_fixture::{Reconstruct, Runner, TestWorld, TestWorldConfig}, }; + /// Test the hashing of `BA112` and tag equality. 
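+    /// (`BA112` is deliberately not a multiple of 32 bits: it needs
+    /// `ceil(112 / 32) = 4` MAC keys, exercising the row-hash case this
+    /// commit fixes.)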
+ #[test] + fn hash() { + run(|| async { + let world = TestWorld::default(); + + let mut rng = thread_rng(); + let record = rng.gen::(); + + let (keys, result) = world + .semi_honest(record, |ctx, record| async move { + // compute amount of MAC keys + let amount_of_keys: usize = (usize::try_from(BA112::BITS).unwrap() + 31) / 32; + // // generate MAC keys + let keys = (0..amount_of_keys) + .map(|i| ctx.prss().generate_fields(RecordId::from(i))) + .map(|(left, right)| AdditiveShare::new(left, right)) + .collect::>>(); + + // compute and append tags to rows + let shares_and_tags: Vec> = compute_and_add_tags( + ctx.narrow(&OPRFShuffleStep::GenerateTags), + &keys, + iter::once(record), + ) + .await + .unwrap(); + + (keys, shares_and_tags) + }) + .await + .reconstruct(); + + let result_ba = BA112::deserialize_from_slice(&result[0].as_raw_slice()[0..14]); + + assert_eq!(record, result_ba); + + let tag = >>::try_into(record) + .unwrap() + .iter() + .zip(keys) + .fold(Gf32Bit::ZERO, |acc, (entry, key)| acc + *entry * key); + + let tag_mpc = >>::try_into(BA32::deserialize_from_slice( + &result[0].as_raw_slice()[14..18], + )) + .unwrap(); + assert_eq!(tag, tag_mpc[0]); + }); + } + /// This test checks the correctness of the malicious shuffle. /// It does not check the security against malicious behavior. #[test] @@ -500,6 +578,30 @@ mod tests { }); } + /// This tests checks that the shuffling of `BA112` + /// does not return an error + /// nor panic. + #[test] + fn shuffle_ba112_succeeds() { + const RECORD_AMOUNT: usize = 10; + run(|| async { + let world = TestWorld::default(); + let mut rng = thread_rng(); + + let records = (0..RECORD_AMOUNT) + .map(|_| rng.gen()) + .collect::>(); + + world + .semi_honest(records.into_iter(), |ctx, records| async move { + malicious_shuffle::<_, BA112, BA144, _>(ctx, records) + .await + .unwrap() + }) + .await; + }); + } + /// This test checks the correctness of the malicious shuffle /// when all parties behave honestly /// and all the MAC keys are `Gf32Bit::ONE`. 
@@ -527,9 +629,14 @@ mod tests { // run shuffle let (shares, messages) = shuffle(ctx.narrow("shuffle"), rows).await.unwrap(); // verify it - verify_shuffle(ctx.narrow("verify"), &key_shares, &shares, messages) - .await - .unwrap(); + verify_shuffle::<_, BA32, BA64>( + ctx.narrow("verify"), + &key_shares, + &shares, + messages, + ) + .await + .unwrap(); }) .await; }); @@ -675,4 +782,116 @@ mod tests { fn bad_initialization_too_small() { check_tags::(); } + + #[allow(clippy::ptr_arg)] // to match StreamInterceptor trait + fn interceptor_h1_to_h2(ctx: &MaliciousHelperContext, data: &mut Vec) { + // H3 runs an additive attack against H1 (on the right) by + // adding a 1 to the left part of share it is holding + if ctx.gate.as_ref().contains("transfer_x2") && ctx.dest == Role::H2 { + data[0] ^= 1u8; + } + } + + #[allow(clippy::ptr_arg)] // to match StreamInterceptor trait + fn interceptor_h2_to_h3(ctx: &MaliciousHelperContext, data: &mut Vec) { + // H3 runs an additive attack against H1 (on the right) by + // adding a 1 to the left part of share it is holding + if ctx.gate.as_ref().contains("transfer_y1") && ctx.dest == Role::H3 { + data[0] ^= 1u8; + } + } + + #[allow(clippy::ptr_arg)] // to match StreamInterceptor trait + fn interceptor_h3_to_h2(ctx: &MaliciousHelperContext, data: &mut Vec) { + // H3 runs an additive attack against H1 (on the right) by + // adding a 1 to the left part of share it is holding + if ctx.gate.as_ref().contains("transfer_c_hat") && ctx.dest == Role::H2 { + data[0] ^= 1u8; + } + } + + /// This test checks that the malicious sort fails + /// under a simple bit flip attack by H1. + /// + /// `x2` will be inconsistent which is checked by `H2`. + #[test] + #[should_panic(expected = "X2 is inconsistent")] + fn fail_under_bit_flip_attack_on_x2() { + const RECORD_AMOUNT: usize = 10; + + run(move || async move { + let mut rng = thread_rng(); + let mut config = TestWorldConfig::default(); + config.stream_interceptor = + MaliciousHelper::new(Role::H1, config.role_assignment(), interceptor_h1_to_h2); + + let world = TestWorld::new_with(config); + let records = (0..RECORD_AMOUNT).map(|_| rng.gen()).collect::>(); + let [_, h2, _] = world + .semi_honest(records.into_iter(), |ctx, shares| async move { + malicious_shuffle::<_, BA32, BA64, _>(ctx, shares).await + }) + .await; + + let _ = h2.unwrap(); + }); + } + + /// This test checks that the malicious sort fails + /// under a simple bit flip attack by H2. + /// + /// `y1` will be inconsistent which is checked by `H1`. + #[test] + #[should_panic(expected = "Y1 is inconsistent")] + fn fail_under_bit_flip_attack_on_y1() { + const RECORD_AMOUNT: usize = 10; + + run(move || async move { + let mut rng = thread_rng(); + let mut config = TestWorldConfig::default(); + config.stream_interceptor = + MaliciousHelper::new(Role::H2, config.role_assignment(), interceptor_h2_to_h3); + + let world = TestWorld::new_with(config); + let records = (0..RECORD_AMOUNT).map(|_| rng.gen()).collect::>(); + let [h1, _, _] = world + .malicious(records.into_iter(), |ctx, shares| async move { + malicious_shuffle::<_, BA32, BA64, _>(ctx, shares).await + }) + .await; + let _ = h1.unwrap(); + }); + } + + /// This test checks that the malicious sort fails + /// under a simple bit flip attack by H3. + /// + /// `c` from `H2` will be inconsistent + /// which is checked by `H1`. 
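+    /// (H2's own check only compares `x2` against `y2`, neither of which this
+    /// interceptor touches, so H2 passes; H1 is the helper that compares the
+    /// hashes of `c_hat` reported by H2 and H3 against `A xor B`.)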
+ #[test] + #[should_panic(expected = "C from H2 is inconsistent")] + fn fail_under_bit_flip_attack_on_c() { + const RECORD_AMOUNT: usize = 10; + + run(move || async move { + let mut rng = thread_rng(); + let mut config = TestWorldConfig::default(); + config.stream_interceptor = + MaliciousHelper::new(Role::H3, config.role_assignment(), interceptor_h3_to_h2); + + let world = TestWorld::new_with(config); + let records = (0..RECORD_AMOUNT).map(|_| rng.gen()).collect::>(); + let [h1, h2, _] = world + .semi_honest(records.into_iter(), |ctx, shares| async move { + malicious_shuffle::<_, BA32, BA64, _>(ctx, shares).await + }) + .await; + + // x2 should be consistent with y2 + let _ = h2.unwrap(); + + // but this should fail + let _ = h1.unwrap(); + }); + } } From 0a27ab0e4c8a9d6f51cec62ba90ec327178c7b1c Mon Sep 17 00:00:00 2001 From: danielmasny Date: Mon, 9 Sep 2024 21:48:03 -0700 Subject: [PATCH 006/191] change comments in interceptors to test malicious shuffle --- ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs index dae440ea8..4c515c97a 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs @@ -785,8 +785,8 @@ mod tests { #[allow(clippy::ptr_arg)] // to match StreamInterceptor trait fn interceptor_h1_to_h2(ctx: &MaliciousHelperContext, data: &mut Vec) { - // H3 runs an additive attack against H1 (on the right) by - // adding a 1 to the left part of share it is holding + // H1 runs an additive attack against H2 by + // changing x2 if ctx.gate.as_ref().contains("transfer_x2") && ctx.dest == Role::H2 { data[0] ^= 1u8; } @@ -794,8 +794,8 @@ mod tests { #[allow(clippy::ptr_arg)] // to match StreamInterceptor trait fn interceptor_h2_to_h3(ctx: &MaliciousHelperContext, data: &mut Vec) { - // H3 runs an additive attack against H1 (on the right) by - // adding a 1 to the left part of share it is holding + // H2 runs an additive attack against H3 by + // changing y1 if ctx.gate.as_ref().contains("transfer_y1") && ctx.dest == Role::H3 { data[0] ^= 1u8; } @@ -803,8 +803,8 @@ mod tests { #[allow(clippy::ptr_arg)] // to match StreamInterceptor trait fn interceptor_h3_to_h2(ctx: &MaliciousHelperContext, data: &mut Vec) { - // H3 runs an additive attack against H1 (on the right) by - // adding a 1 to the left part of share it is holding + // H3 runs an additive attack against H2 by + // changing c_hat_2 if ctx.gate.as_ref().contains("transfer_c_hat") && ctx.dest == Role::H2 { data[0] ^= 1u8; } From 22a0a8b7f7aa916796fe4cf9a636ff830a4e4380 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Tue, 10 Sep 2024 10:22:52 -0700 Subject: [PATCH 007/191] Malicious security for top-level OPRF IPA protocol (#1252) --- ipa-core/src/bin/report_collector.rs | 27 +++++-- ipa-core/src/cli/crypto.rs | 2 +- ipa-core/src/helpers/transport/query/mod.rs | 9 ++- ipa-core/src/net/http_serde.rs | 10 ++- .../src/net/server/handlers/query/create.rs | 30 +++++-- ipa-core/src/protocol/ipa_prf/mod.rs | 79 ++++++++++++++----- .../src/protocol/ipa_prf/prf_sharding/mod.rs | 18 ++--- ipa-core/src/query/executor.rs | 35 ++++---- ipa-core/src/query/processor.rs | 2 +- ipa-core/src/query/runner/oprf_ipa.rs | 57 +++++++++---- ipa-core/src/secret_sharing/vector/traits.rs | 2 +- ipa-core/src/test_fixture/ipa.rs | 12 +-- ipa-core/tests/common/mod.rs | 4 +- ipa-core/tests/compact_gate.rs | 6 ++ 
ipa-core/tests/encrypted_input.rs | 2 +- ipa-core/tests/helper_networks.rs | 6 ++ 16 files changed, 212 insertions(+), 89 deletions(-) diff --git a/ipa-core/src/bin/report_collector.rs b/ipa-core/src/bin/report_collector.rs index 6debb6ead..58e2752d1 100644 --- a/ipa-core/src/bin/report_collector.rs +++ b/ipa-core/src/bin/report_collector.rs @@ -97,8 +97,11 @@ enum ReportCollectorCommand { }, /// Apply differential privacy noise to IPA inputs ApplyDpNoise(ApplyDpArgs), - /// Execute OPRF IPA in a semi-honest majority setting - OprfIpa(IpaQueryConfig), + /// Execute OPRF IPA in a semi-honest setting + #[command(visible_alias = "oprf-ipa")] + SemiHonestOprfIpa(IpaQueryConfig), + /// Execute OPRF IPA in an honest majority (one malicious helper) setting + MaliciousOprfIpa(IpaQueryConfig), } #[derive(Debug, clap::Args)] @@ -129,7 +132,7 @@ async fn main() -> Result<(), Box> { gen_args, } => gen_inputs(count, seed, args.output_file, gen_args)?, ReportCollectorCommand::ApplyDpNoise(ref dp_args) => apply_dp_noise(&args, dp_args)?, - ReportCollectorCommand::OprfIpa(config) => { + ReportCollectorCommand::SemiHonestOprfIpa(config) => { ipa( &args, &network, @@ -140,6 +143,17 @@ async fn main() -> Result<(), Box> { ) .await? } + ReportCollectorCommand::MaliciousOprfIpa(config) => { + ipa( + &args, + &network, + IpaSecurityModel::Malicious, + config, + &clients, + IpaQueryStyle::Oprf, + ) + .await? + } }; Ok(()) @@ -180,13 +194,12 @@ async fn ipa( query_style: IpaQueryStyle, ) -> Result<(), Box> { let input = InputSource::from(&args.input); - let query_type: QueryType; - match (security_model, &query_style) { + let query_type = match (security_model, &query_style) { (IpaSecurityModel::SemiHonest, IpaQueryStyle::Oprf) => { - query_type = QueryType::OprfIpa(ipa_query_config); + QueryType::SemiHonestOprfIpa(ipa_query_config) } (IpaSecurityModel::Malicious, IpaQueryStyle::Oprf) => { - panic!("OPRF for malicious is not implemented as yet") + QueryType::MaliciousOprfIpa(ipa_query_config) } }; diff --git a/ipa-core/src/cli/crypto.rs b/ipa-core/src/cli/crypto.rs index 69acd1e11..f7ccca4f8 100644 --- a/ipa-core/src/cli/crypto.rs +++ b/ipa-core/src/cli/crypto.rs @@ -649,7 +649,7 @@ public_key = "cfdbaaff16b30aa8a4ab07eaad2cdd80458208a1317aefbb807e46dce596617e" .expect("manually constructed for test"), )])); - OprfIpaQuery::>::new( + OprfIpaQuery::<_, BA16, KeyRegistry>::new( query_config, private_registry, ) diff --git a/ipa-core/src/helpers/transport/query/mod.rs b/ipa-core/src/helpers/transport/query/mod.rs index 27850ab8d..f509185a5 100644 --- a/ipa-core/src/helpers/transport/query/mod.rs +++ b/ipa-core/src/helpers/transport/query/mod.rs @@ -198,14 +198,16 @@ pub enum QueryType { TestMultiply, #[cfg(any(test, feature = "test-fixture", feature = "cli"))] TestAddInPrimeField, - OprfIpa(IpaQueryConfig), + SemiHonestOprfIpa(IpaQueryConfig), + MaliciousOprfIpa(IpaQueryConfig), } impl QueryType { /// TODO: strum pub const TEST_MULTIPLY_STR: &'static str = "test-multiply"; pub const TEST_ADD_STR: &'static str = "test-add"; - pub const OPRF_IPA_STR: &'static str = "oprf_ipa"; + pub const SEMI_HONEST_OPRF_IPA_STR: &'static str = "semi-honest-oprf-ipa"; + pub const MALICIOUS_OPRF_IPA_STR: &'static str = "malicious-oprf-ipa"; } /// TODO: should this `AsRef` impl (used for `Substep`) take into account config of IPA? 
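 /// The strings above are also what clients send on the wire: the report
 /// collector CLI selects them via the `semi-honest-oprf-ipa` and
 /// `malicious-oprf-ipa` subcommands, with `oprf-ipa` kept as a visible alias
 /// for the semi-honest variant.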
@@ -216,7 +218,8 @@ impl AsRef for QueryType { QueryType::TestMultiply => Self::TEST_MULTIPLY_STR, #[cfg(any(test, feature = "cli", feature = "test-fixture"))] QueryType::TestAddInPrimeField => Self::TEST_ADD_STR, - QueryType::OprfIpa(_) => Self::OPRF_IPA_STR, + QueryType::SemiHonestOprfIpa(_) => Self::SEMI_HONEST_OPRF_IPA_STR, + QueryType::MaliciousOprfIpa(_) => Self::MALICIOUS_OPRF_IPA_STR, } } } diff --git a/ipa-core/src/net/http_serde.rs b/ipa-core/src/net/http_serde.rs index 92f9ebf48..687fd2f19 100644 --- a/ipa-core/src/net/http_serde.rs +++ b/ipa-core/src/net/http_serde.rs @@ -122,9 +122,13 @@ pub mod query { QueryType::TEST_MULTIPLY_STR => Ok(QueryType::TestMultiply), #[cfg(any(test, feature = "cli", feature = "test-fixture"))] QueryType::TEST_ADD_STR => Ok(QueryType::TestAddInPrimeField), - QueryType::OPRF_IPA_STR => { + QueryType::SEMI_HONEST_OPRF_IPA_STR => { let Query(q) = req.extract().await?; - Ok(QueryType::OprfIpa(q)) + Ok(QueryType::SemiHonestOprfIpa(q)) + } + QueryType::MALICIOUS_OPRF_IPA_STR => { + let Query(q) = req.extract().await?; + Ok(QueryType::MaliciousOprfIpa(q)) } other => Err(Error::bad_query_value("query_type", other)), }?; @@ -148,7 +152,7 @@ pub mod query { match self.query_type { #[cfg(any(test, feature = "test-fixture", feature = "cli"))] QueryType::TestMultiply | QueryType::TestAddInPrimeField => Ok(()), - QueryType::OprfIpa(config) => { + QueryType::SemiHonestOprfIpa(config) | QueryType::MaliciousOprfIpa(config) => { write!( f, "&per_user_credit_cap={}&max_breakdown_key={}&with_dp={}&epsilon={}", diff --git a/ipa-core/src/net/server/handlers/query/create.rs b/ipa-core/src/net/server/handlers/query/create.rs index aa4577ec4..f56c0b8d2 100644 --- a/ipa-core/src/net/server/handlers/query/create.rs +++ b/ipa-core/src/net/server/handlers/query/create.rs @@ -90,7 +90,7 @@ mod tests { async fn create_test_ipa_no_attr_window() { create_test( QueryConfig::new( - QueryType::OprfIpa(IpaQueryConfig { + QueryType::SemiHonestOprfIpa(IpaQueryConfig { per_user_credit_cap: 1, max_breakdown_key: 1, attribution_window_seconds: None, @@ -107,10 +107,30 @@ mod tests { } #[tokio::test] - async fn create_test_ipa_no_attr_window_with_dp() { + async fn create_test_semi_honest_ipa_no_attr_window_with_dp_default_padding() { create_test( QueryConfig::new( - QueryType::OprfIpa(IpaQueryConfig { + QueryType::SemiHonestOprfIpa(IpaQueryConfig { + per_user_credit_cap: 8, + max_breakdown_key: 20, + attribution_window_seconds: None, + with_dp: 1, + epsilon: 5.0, + plaintext_match_keys: true, + }), + FieldType::Fp32BitPrime, + 1, + ) + .unwrap(), + ) + .await; + } + + #[tokio::test] + async fn create_test_malicious_ipa_no_attr_window_with_dp_default_padding() { + create_test( + QueryConfig::new( + QueryType::MaliciousOprfIpa(IpaQueryConfig { per_user_credit_cap: 8, max_breakdown_key: 20, attribution_window_seconds: None, @@ -131,7 +151,7 @@ mod tests { create_test(QueryConfig { size: 1.try_into().unwrap(), field_type: FieldType::Fp32BitPrime, - query_type: QueryType::OprfIpa(IpaQueryConfig { + query_type: QueryType::SemiHonestOprfIpa(IpaQueryConfig { per_user_credit_cap: 1, max_breakdown_key: 1, attribution_window_seconds: NonZeroU32::new(86_400), @@ -238,7 +258,7 @@ mod tests { fn default() -> Self { Self { field_type: format!("{:?}", FieldType::Fp32BitPrime), - query_type: QueryType::OPRF_IPA_STR.to_string(), + query_type: QueryType::SEMI_HONEST_OPRF_IPA_STR.to_string(), per_user_credit_cap: "1".into(), max_breakdown_key: "1".into(), attribution_window_seconds: None, diff --git 
a/ipa-core/src/protocol/ipa_prf/mod.rs b/ipa-core/src/protocol/ipa_prf/mod.rs index 6dc108524..7fdfa7ef0 100644 --- a/ipa-core/src/protocol/ipa_prf/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/mod.rs @@ -20,10 +20,7 @@ use crate::{ }, protocol::{ basics::{BooleanArrayMul, BooleanProtocols, Reveal}, - context::{ - dzkp_validator::DZKPValidator, Context, DZKPUpgraded, DZKPUpgradedSemiHonestContext, - MacUpgraded, SemiHonestContext, UpgradableContext, UpgradedSemiHonestContext, - }, + context::{dzkp_validator::DZKPValidator, DZKPUpgraded, MacUpgraded, UpgradableContext}, ipa_prf::{ boolean_ops::convert_to_fp25519, oprf_padding::apply_dp_padding, @@ -40,7 +37,6 @@ use crate::{ SharedValue, TransposeFrom, Vectorizable, }, seq_join::seq_join, - sharding::NotSharded, }; pub(crate) mod aggregation; @@ -220,25 +216,32 @@ where /// Propagates errors from config issues or while running the protocol /// # Panics /// Propagates errors from config issues or while running the protocol -pub async fn oprf_ipa<'ctx, BK, TV, HV, TS, const SS_BITS: usize, const B: usize>( - ctx: SemiHonestContext<'ctx>, +pub async fn oprf_ipa<'ctx, C, BK, TV, HV, TS, const SS_BITS: usize, const B: usize>( + ctx: C, input_rows: Vec>, attribution_window_seconds: Option, dp_params: DpMechanism, dp_padding_params: PaddingParameters, ) -> Result>, Error> where + C: UpgradableContext + 'ctx, BK: BreakdownKey, TV: BooleanArray + U128Conversions, HV: BooleanArray + U128Conversions, TS: BooleanArray + U128Conversions, Boolean: FieldSimd, - Replicated: - BooleanProtocols, B>, - Replicated: BooleanProtocols, B>, - for<'a> Replicated: BooleanArrayMul>, - for<'a> Replicated: BooleanArrayMul>, - for<'a> Replicated: BooleanArrayMul>, + Replicated: BooleanProtocols>, + Replicated: BooleanProtocols, B>, + Replicated: BooleanProtocols, AGG_CHUNK>, + Replicated: BooleanProtocols, CONV_CHUNK>, + Replicated: BooleanProtocols, SORT_CHUNK>, + Replicated: + PrfSharing, PRF_CHUNK, Field = Fp25519> + FromPrss, + Replicated: + Reveal, Output = >::Array>, + Replicated: BooleanArrayMul>, + Replicated: BooleanArrayMul>, + Replicated: BooleanArrayMul>, BitDecomposed>: for<'a> TransposeFrom<&'a Vec>, Error = LengthError>, BitDecomposed>: @@ -444,7 +447,47 @@ pub mod tests { let mut result: Vec<_> = world .semi_honest(records.into_iter(), |ctx, input_rows| async move { - oprf_ipa::( + oprf_ipa::<_, BA5, BA3, BA16, BA20, 5, 32>( + ctx, + input_rows, + None, + dp_params, + padding_params, + ) + .await + .unwrap() + }) + .await + .reconstruct(); + result.truncate(EXPECTED.len()); + assert_eq!( + result.iter().map(|&v| v.as_u128()).collect::>(), + EXPECTED, + ); + }); + } + + #[test] + fn malicious() { + const EXPECTED: &[u128] = &[0, 2, 5, 0, 0, 0, 0, 0]; + + run(|| async { + let world = TestWorld::default(); + + let records: Vec = vec![ + test_input(0, 12345, false, 1, 0), + test_input(5, 12345, false, 2, 0), + test_input(10, 12345, true, 0, 5), + test_input(0, 68362, false, 1, 0), + test_input(20, 68362, true, 0, 2), + ]; // trigger value of 2 attributes to earlier source row with breakdown 1 and trigger + // value of 5 attributes to source row with breakdown 2. 
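+            // Expected histogram: breakdown key 1 accumulates the trigger
+            // value 2 and breakdown key 2 accumulates the trigger value 5;
+            // every other bucket stays 0, matching EXPECTED above.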
+ let dp_params = DpMechanism::NoDp; + let padding_params = PaddingParameters::relaxed(); + + let mut result: Vec<_> = world + .malicious(records.into_iter(), |ctx, input_rows| async move { + oprf_ipa::<_, BA5, BA3, BA16, BA20, 5, 32>( ctx, input_rows, None, @@ -501,7 +544,7 @@ pub mod tests { ]; let mut result: Vec<_> = world .semi_honest(records.into_iter(), |ctx, input_rows| async move { - oprf_ipa::( + oprf_ipa::<_, BA5, BA3, BA16, BA20, SS_BITS, B>( ctx, input_rows, None, @@ -562,7 +605,7 @@ pub mod tests { let mut result: Vec<_> = world .semi_honest(records.into_iter(), |ctx, input_rows| async move { - oprf_ipa::( + oprf_ipa::<_, BA5, BA3, BA8, BA20, 5, 32>( ctx, input_rows, None, @@ -598,7 +641,7 @@ pub mod tests { let mut result: Vec<_> = world .semi_honest(records.into_iter(), |ctx, input_rows| async move { - oprf_ipa::( + oprf_ipa::<_, BA5, BA3, BA8, BA20, 5, 32>( ctx, input_rows, None, @@ -652,7 +695,7 @@ pub mod tests { let padding_params = PaddingParameters::no_padding(); let mut result: Vec<_> = world .semi_honest(records.into_iter(), |ctx, input_rows| async move { - oprf_ipa::( + oprf_ipa::<_, BA8, BA3, BA16, BA20, 5, 256>( ctx, input_rows, None, diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs index 482123d73..e91880b92 100644 --- a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs @@ -487,9 +487,9 @@ where Replicated: BooleanProtocols>, Replicated: BooleanProtocols, B>, Replicated: BooleanProtocols, AGG_CHUNK>, - for<'a> Replicated: BooleanArrayMul>, - for<'a> Replicated: BooleanArrayMul>, - for<'a> Replicated: BooleanArrayMul>, + Replicated: BooleanArrayMul>, + Replicated: BooleanArrayMul>, + Replicated: BooleanArrayMul>, BitDecomposed>: for<'a> TransposeFrom<&'a Vec>, Error = LengthError>, BitDecomposed>: @@ -568,9 +568,9 @@ where BK: BreakdownKey, TV: BooleanArray + U128Conversions, TS: BooleanArray + U128Conversions, - for<'a> Replicated: BooleanArrayMul, - for<'a> Replicated: BooleanArrayMul, - for<'a> Replicated: BooleanArrayMul, + Replicated: BooleanArrayMul, + Replicated: BooleanArrayMul, + Replicated: BooleanArrayMul, { let chunked_user_results = input @@ -608,9 +608,9 @@ where BK: BooleanArray + U128Conversions, TV: BooleanArray + U128Conversions, TS: BooleanArray + U128Conversions, - for<'a> Replicated: BooleanArrayMul, - for<'a> Replicated: BooleanArrayMul, - for<'a> Replicated: BooleanArrayMul, + Replicated: BooleanArrayMul, + Replicated: BooleanArrayMul, + Replicated: BooleanArrayMul, { assert!(!rows_for_user.is_empty()); if rows_for_user.len() == 1 { diff --git a/ipa-core/src/query/executor.rs b/ipa-core/src/query/executor.rs index 923f3fafe..c24ab8b5f 100644 --- a/ipa-core/src/query/executor.rs +++ b/ipa-core/src/query/executor.rs @@ -19,19 +19,30 @@ use rand_core::SeedableRng; use shuttle::future as tokio; use typenum::Unsigned; +#[cfg(any( + test, + feature = "cli", + feature = "test-fixture", + feature = "weak-field" +))] +use crate::ff::FieldType; #[cfg(any(test, feature = "cli", feature = "test-fixture"))] use crate::{ ff::Fp32BitPrime, query::runner::execute_test_multiply, query::runner::test_add_in_prime_field, }; use crate::{ - ff::{boolean_array::BA32, FieldType, Serializable}, + ff::{boolean_array::BA32, Serializable}, helpers::{ negotiate_prss, query::{QueryConfig, QueryType}, BodyStream, Gateway, }, hpke::PrivateKeyRegistry, - protocol::{context::SemiHonestContext, prss::Endpoint as PrssEndpoint, Gate}, + protocol::{ + 
context::{MaliciousContext, SemiHonestContext}, + prss::Endpoint as PrssEndpoint, + Gate, + }, query::{ runner::{OprfIpaQuery, QueryResult}, state::RunningQuery, @@ -101,35 +112,29 @@ pub fn execute( } // TODO(953): This is really using BA32, not Fp32bitPrime. The `FieldType` mechanism needs // to be reworked. - (QueryType::OprfIpa(ipa_config), FieldType::Fp32BitPrime) => do_query( + (QueryType::SemiHonestOprfIpa(ipa_config), _) => do_query( config, gateway, input, move |prss, gateway, config, input| { let ctx = SemiHonestContext::new(prss, gateway); Box::pin( - OprfIpaQuery::::new(ipa_config, key_registry) + OprfIpaQuery::<_, BA32, R>::new(ipa_config, key_registry) .execute(ctx, config.size, input) .then(|res| ready(res.map(|out| Box::new(out) as Box))), ) }, ), - // TODO(953): This is not doing anything differently than the Fp32BitPrime case, except - // using 16 bits for histogram values - #[cfg(any(test, feature = "weak-field"))] - (QueryType::OprfIpa(ipa_config), FieldType::Fp31) => do_query( + (QueryType::MaliciousOprfIpa(ipa_config), _) => do_query( config, gateway, input, move |prss, gateway, config, input| { - let ctx = SemiHonestContext::new(prss, gateway); + let ctx = MaliciousContext::new(prss, gateway); Box::pin( - OprfIpaQuery::::new( - ipa_config, - key_registry, - ) - .execute(ctx, config.size, input) - .then(|res| ready(res.map(|out| Box::new(out) as Box))), + OprfIpaQuery::<_, BA32, R>::new(ipa_config, key_registry) + .execute(ctx, config.size, input) + .then(|res| ready(res.map(|out| Box::new(out) as Box))), ) }, ), diff --git a/ipa-core/src/query/processor.rs b/ipa-core/src/query/processor.rs index a779b5fa6..c399bb019 100644 --- a/ipa-core/src/query/processor.rs +++ b/ipa-core/src/query/processor.rs @@ -677,7 +677,7 @@ mod tests { QueryConfig { size: record_count.try_into().unwrap(), field_type: FieldType::Fp31, - query_type: QueryType::OprfIpa(IpaQueryConfig { + query_type: QueryType::SemiHonestOprfIpa(IpaQueryConfig { per_user_credit_cap: 8, max_breakdown_key: 3, attribution_window_seconds: None, diff --git a/ipa-core/src/query/runner/oprf_ipa.rs b/ipa-core/src/query/runner/oprf_ipa.rs index 320d2246b..99c59b727 100644 --- a/ipa-core/src/query/runner/oprf_ipa.rs +++ b/ipa-core/src/query/runner/oprf_ipa.rs @@ -8,6 +8,8 @@ use crate::{ ff::{ boolean::Boolean, boolean_array::{BooleanArray, BA20, BA3, BA8}, + curve_points::RP25519, + ec_prime_field::Fp25519, Field, Serializable, U128Conversions, }, helpers::{ @@ -16,26 +18,31 @@ use crate::{ }, hpke::PrivateKeyRegistry, protocol::{ - basics::ShareKnownValue, - context::{Context, SemiHonestContext}, - ipa_prf::{oprf_ipa, oprf_padding::PaddingParameters, OPRFIPAInputRow}, + basics::{BooleanArrayMul, Reveal, ShareKnownValue}, + context::{DZKPUpgraded, MacUpgraded, UpgradableContext}, + ipa_prf::{ + oprf_ipa, oprf_padding::PaddingParameters, prf_eval::PrfSharing, OPRFIPAInputRow, + AGG_CHUNK, CONV_CHUNK, PRF_CHUNK, SORT_CHUNK, + }, + prss::FromPrss, step::ProtocolStep::IpaPrf, + BooleanProtocols, }, report::{EncryptedOprfReport, EventType}, secret_sharing::{ replicated::semi_honest::{AdditiveShare as Replicated, AdditiveShare}, - BitDecomposed, SharedValue, TransposeFrom, + BitDecomposed, SharedValue, TransposeFrom, Vectorizable, }, sync::Arc, }; -pub struct OprfIpaQuery<'a, HV, R: PrivateKeyRegistry> { +pub struct OprfIpaQuery { config: IpaQueryConfig, key_registry: Arc, - phantom_data: PhantomData<&'a HV>, + phantom_data: PhantomData<(C, HV)>, } -impl<'a, HV, R: PrivateKeyRegistry> OprfIpaQuery<'a, HV, R> { +impl 
OprfIpaQuery { pub fn new(config: IpaQueryConfig, key_registry: Arc) -> Self { Self { config, @@ -46,11 +53,24 @@ impl<'a, HV, R: PrivateKeyRegistry> OprfIpaQuery<'a, HV, R> { } #[allow(clippy::too_many_lines)] -impl<'ctx, HV, R> OprfIpaQuery<'ctx, HV, R> +impl OprfIpaQuery where + C: UpgradableContext, HV: BooleanArray + U128Conversions, R: PrivateKeyRegistry, - Replicated: Serializable + ShareKnownValue, Boolean>, + Replicated: Serializable + ShareKnownValue, + Replicated: BooleanProtocols>, + Replicated: BooleanProtocols, 256>, + Replicated: BooleanProtocols, AGG_CHUNK>, + Replicated: BooleanProtocols, CONV_CHUNK>, + Replicated: BooleanProtocols, SORT_CHUNK>, + Replicated: + PrfSharing, PRF_CHUNK, Field = Fp25519> + FromPrss, + Replicated: + Reveal, Output = >::Array>, + Replicated: BooleanArrayMul>, + Replicated: BooleanArrayMul>, + Replicated: BooleanArrayMul>, Vec>: for<'a> TransposeFrom<&'a BitDecomposed>, Error = LengthError>, BitDecomposed>: @@ -59,7 +79,7 @@ where #[tracing::instrument("oprf_ipa_query", skip_all, fields(sz=%query_size))] pub async fn execute( self, - ctx: SemiHonestContext<'ctx>, + ctx: C, query_size: QuerySize, input_stream: BodyStream, ) -> Result>, Error> { @@ -127,11 +147,11 @@ where #[cfg(not(any(test, feature = "cli", feature = "test-fixture")))] let padding_params = PaddingParameters::default(); match config.per_user_credit_cap { - 8 => oprf_ipa::(ctx, input, aws, dp_params, padding_params).await, - 16 => oprf_ipa::(ctx, input, aws, dp_params, padding_params).await, - 32 => oprf_ipa::(ctx, input, aws, dp_params, padding_params).await, - 64 => oprf_ipa::(ctx, input, aws, dp_params, padding_params).await, - 128 => oprf_ipa::(ctx, input, aws, dp_params, padding_params).await, + 8 => oprf_ipa::<_, BA8, BA3, HV, BA20, 3, 256>(ctx, input, aws, dp_params, padding_params).await, + 16 => oprf_ipa::<_, BA8, BA3, HV, BA20, 4, 256>(ctx, input, aws, dp_params, padding_params).await, + 32 => oprf_ipa::<_, BA8, BA3, HV, BA20, 5, 256>(ctx, input, aws, dp_params, padding_params).await, + 64 => oprf_ipa::<_, BA8, BA3, HV, BA20, 6, 256>(ctx, input, aws, dp_params, padding_params).await, + 128 => oprf_ipa::<_, BA8, BA3, HV, BA20, 7, 256>(ctx, input, aws, dp_params, padding_params).await, _ => panic!( "Invalid value specified for per-user cap: {:?}. 
Must be one of 8, 16, 32, 64, or 128.", config.per_user_credit_cap @@ -243,8 +263,11 @@ mod tests { }; let input = BodyStream::from(buffer); - OprfIpaQuery::>::new(query_config, Arc::clone(&key_registry)) - .execute(ctx, query_size, input) + OprfIpaQuery::<_, BA16, KeyRegistry>::new( + query_config, + Arc::clone(&key_registry), + ) + .execute(ctx, query_size, input) })) .await; diff --git a/ipa-core/src/secret_sharing/vector/traits.rs b/ipa-core/src/secret_sharing/vector/traits.rs index b44316b70..8194fb103 100644 --- a/ipa-core/src/secret_sharing/vector/traits.rs +++ b/ipa-core/src/secret_sharing/vector/traits.rs @@ -108,7 +108,7 @@ pub trait SharedValueArray: pub trait FieldArray: SharedValueArray + FromRandom - + for<'a> Mul + + Mul + for<'a> Mul<&'a F, Output = Self> + for<'a> Mul<&'a Self, Output = Self> { diff --git a/ipa-core/src/test_fixture/ipa.rs b/ipa-core/src/test_fixture/ipa.rs index b14324c8e..fd91081db 100644 --- a/ipa-core/src/test_fixture/ipa.rs +++ b/ipa-core/src/test_fixture/ipa.rs @@ -219,7 +219,7 @@ pub async fn test_oprf_ipa( world.semi_honest( records.into_iter(), |ctx, input_rows: Vec>| async move { - oprf_ipa::(ctx, input_rows, aws, dp_params, padding_params) + oprf_ipa::<_, BA5, BA8, BA32, BA20, 8, 32>(ctx, input_rows, aws, dp_params, padding_params) .await .unwrap() }, @@ -231,19 +231,19 @@ pub async fn test_oprf_ipa( |ctx, input_rows: Vec>| async move { match config.per_user_credit_cap { - 8 => oprf_ipa::(ctx, input_rows, aws, dp_params, padding_params) + 8 => oprf_ipa::<_, BA8, BA3, BA32, BA20, 3, 256>(ctx, input_rows, aws, dp_params, padding_params) .await .unwrap(), - 16 => oprf_ipa::(ctx, input_rows, aws, dp_params, padding_params) + 16 => oprf_ipa::<_, BA8, BA3, BA32, BA20, 4, 256>(ctx, input_rows, aws, dp_params, padding_params) .await .unwrap(), - 32 => oprf_ipa::(ctx, input_rows, aws, dp_params, padding_params) + 32 => oprf_ipa::<_, BA8, BA3, BA32, BA20, 5, 256>(ctx, input_rows, aws, dp_params, padding_params) .await .unwrap(), - 64 => oprf_ipa::(ctx, input_rows, aws, dp_params, padding_params) + 64 => oprf_ipa::<_, BA8, BA3, BA32, BA20, 6, 256>(ctx, input_rows, aws, dp_params, padding_params) .await .unwrap(), - 128 => oprf_ipa::(ctx, input_rows, aws, dp_params, padding_params) + 128 => oprf_ipa::<_, BA8, BA3, BA32, BA20, 7, 256>(ctx, input_rows, aws, dp_params, padding_params) .await .unwrap(), _ => diff --git a/ipa-core/tests/common/mod.rs b/ipa-core/tests/common/mod.rs index 56cdc37ce..d04d99efd 100644 --- a/ipa-core/tests/common/mod.rs +++ b/ipa-core/tests/common/mod.rs @@ -264,8 +264,8 @@ pub fn test_ipa_with_config(mode: IpaSecurityModel, https: bool, config: IpaQuer } let protocol = match mode { - IpaSecurityModel::SemiHonest => "oprf-ipa", - IpaSecurityModel::Malicious => "malicious-ipa", + IpaSecurityModel::SemiHonest => "semi-honest-oprf-ipa", + IpaSecurityModel::Malicious => "malicious-oprf-ipa", }; command .arg(protocol) diff --git a/ipa-core/tests/compact_gate.rs b/ipa-core/tests/compact_gate.rs index f847275f3..c67be42bc 100644 --- a/ipa-core/tests/compact_gate.rs +++ b/ipa-core/tests/compact_gate.rs @@ -28,6 +28,12 @@ fn compact_gate_cap_8_no_window_semi_honest() { test_compact_gate(IpaSecurityModel::SemiHonest, 8, 0); } +#[test] +#[ignore] // TODO +fn compact_gate_cap_8_no_window_malicious() { + test_compact_gate(IpaSecurityModel::Malicious, 8, 0); +} + #[test] fn compact_gate_cap_8_with_window_semi_honest() { test_compact_gate(IpaSecurityModel::SemiHonest, 8, 86400); diff --git a/ipa-core/tests/encrypted_input.rs 
b/ipa-core/tests/encrypted_input.rs index 8a7853344..23f5e1aa3 100644 --- a/ipa-core/tests/encrypted_input.rs +++ b/ipa-core/tests/encrypted_input.rs @@ -180,7 +180,7 @@ public_key = "b900be35da06106a83ed73c33f733e03e4ea5888b7ea4c912ab270b0b0f8381e" .expect("manually constructed for test"), )])); - OprfIpaQuery::>::new( + OprfIpaQuery::<_, BA16, KeyRegistry>::new( query_config, private_registry, ) diff --git a/ipa-core/tests/helper_networks.rs b/ipa-core/tests/helper_networks.rs index d8e8fedd0..8bbcc0622 100644 --- a/ipa-core/tests/helper_networks.rs +++ b/ipa-core/tests/helper_networks.rs @@ -54,6 +54,12 @@ fn https_semi_honest_ipa() { test_ipa(IpaSecurityModel::SemiHonest, true); } +#[test] +#[cfg(all(test, web_test, not(feature = "compact-gate")))] // TODO: enable for compact gate +fn https_malicious_ipa() { + test_ipa(IpaSecurityModel::Malicious, true); +} + /// Similar to [`network`] tests, but it uses keygen + confgen CLIs to generate helper client config /// and then just runs test multiply to make sure helpers are up and running /// From a15116a6d8fc1b3bf58b0b3fc337f5896bb13bb1 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Tue, 10 Sep 2024 18:29:46 -0700 Subject: [PATCH 008/191] Update compact gate for malicious security --- .../src/protocol/basics/mul/dzkp_malicious.rs | 4 +- ipa-core/src/protocol/basics/mul/mod.rs | 2 +- ipa-core/src/protocol/context/batcher.rs | 19 ++-- .../src/protocol/context/dzkp_malicious.rs | 38 ++++---- .../src/protocol/context/dzkp_validator.rs | 95 ++++++++++++------- ipa-core/src/protocol/context/malicious.rs | 39 ++++++-- ipa-core/src/protocol/context/mod.rs | 19 +++- ipa-core/src/protocol/context/semi_honest.rs | 21 ++-- ipa-core/src/protocol/context/step.rs | 40 +++++--- ipa-core/src/protocol/context/validator.rs | 4 +- ipa-core/src/protocol/dp/mod.rs | 16 +++- .../src/protocol/ipa_prf/aggregation/mod.rs | 30 ++++-- .../src/protocol/ipa_prf/aggregation/step.rs | 4 + .../boolean_ops/share_conversion_aby.rs | 8 +- ipa-core/src/protocol/ipa_prf/mod.rs | 34 ++++--- ipa-core/src/protocol/ipa_prf/prf_eval.rs | 1 - .../src/protocol/ipa_prf/prf_sharding/mod.rs | 51 +++++----- .../src/protocol/ipa_prf/prf_sharding/step.rs | 2 + ipa-core/src/protocol/ipa_prf/quicksort.rs | 18 ++-- ipa-core/src/protocol/ipa_prf/step.rs | 14 ++- .../ipa_prf/validation_protocol/validation.rs | 6 +- ipa-core/src/protocol/step.rs | 2 - ipa-core/src/secret_sharing/vector/impls.rs | 9 +- ipa-core/src/test_fixture/world.rs | 4 +- ipa-core/tests/compact_gate.rs | 1 - ipa-core/tests/helper_networks.rs | 2 +- 26 files changed, 309 insertions(+), 174 deletions(-) diff --git a/ipa-core/src/protocol/basics/mul/dzkp_malicious.rs b/ipa-core/src/protocol/basics/mul/dzkp_malicious.rs index 74df755c3..23a96c982 100644 --- a/ipa-core/src/protocol/basics/mul/dzkp_malicious.rs +++ b/ipa-core/src/protocol/basics/mul/dzkp_malicious.rs @@ -84,7 +84,7 @@ mod test { ff::boolean::Boolean, protocol::{ basics::SecureMul, - context::{dzkp_validator::DZKPValidator, Context, UpgradableContext}, + context::{dzkp_validator::DZKPValidator, Context, UpgradableContext, TEST_DZKP_STEPS}, RecordId, }, rand::{thread_rng, Rng}, @@ -101,7 +101,7 @@ mod test { let res = world .malicious((a, b), |ctx, (a, b)| async move { - let validator = ctx.dzkp_validator(10); + let validator = ctx.dzkp_validator(TEST_DZKP_STEPS, 10); let mctx = validator.context(); let result = a .multiply(&b, mctx.set_total_records(1), RecordId::from(0)) diff --git a/ipa-core/src/protocol/basics/mul/mod.rs 
index 3194be9fb..260d7a4b8 100644
--- a/ipa-core/src/protocol/basics/mul/mod.rs
+++ b/ipa-core/src/protocol/basics/mul/mod.rs
@@ -27,7 +27,7 @@ use crate::{
 mod dzkp_malicious;
 pub(crate) mod malicious;
 mod semi_honest;
-pub(in crate::protocol) mod step;
+pub(crate) mod step;
 
 pub use semi_honest::sh_multiply as semi_honest_multiply;
 
diff --git a/ipa-core/src/protocol/context/batcher.rs b/ipa-core/src/protocol/context/batcher.rs
index 96337b629..f7021d988 100644
--- a/ipa-core/src/protocol/context/batcher.rs
+++ b/ipa-core/src/protocol/context/batcher.rs
@@ -3,12 +3,7 @@ use std::{cmp::min, collections::VecDeque, future::Future};
 use bitvec::{bitvec, prelude::BitVec};
 use tokio::sync::watch;
 
-use crate::{
-    error::Error,
-    helpers::TotalRecords,
-    protocol::RecordId,
-    sync::{Arc, Mutex},
-};
+use crate::{error::Error, helpers::TotalRecords, protocol::RecordId, sync::Mutex};
 
 /// Manages validation of batches of records for malicious protocols.
 ///
@@ -88,14 +83,14 @@ impl<'a, B> Batcher<'a, B> {
         records_per_batch: usize,
         total_records: T,
         batch_constructor: Box<dyn Fn(usize) -> B + Send + 'a>,
-    ) -> Arc<Mutex<Self>> {
-        Arc::new(Mutex::new(Self {
+    ) -> Mutex<Self> {
+        Mutex::new(Self {
             batches: VecDeque::new(),
             first_batch: 0,
             records_per_batch,
             total_records: total_records.into(),
             batch_constructor,
-        }))
+        })
     }
 
     pub fn set_total_records<T: Into<TotalRecords>>(&mut self, total_records: T) {
@@ -550,7 +545,7 @@ mod tests {
                 .push(i);
         }
 
-        let batcher = Arc::into_inner(batcher).unwrap().into_inner().unwrap();
+        let batcher = batcher.into_inner().unwrap();
         assert_eq!(batcher.into_single_batch(), vec![0, 1]);
     }
 
@@ -568,7 +563,7 @@
                 .push(i);
         }
 
-        let batcher = Arc::into_inner(batcher).unwrap().into_inner().unwrap();
+        let batcher = batcher.into_inner().unwrap();
         batcher.into_single_batch();
     }
 
@@ -602,7 +597,7 @@
         });
         assert_eq!(try_join(fut1, fut2).await.unwrap(), ((), ()));
 
-        let batcher = Arc::into_inner(batcher).unwrap().into_inner().unwrap();
+        let batcher = batcher.into_inner().unwrap();
         batcher.into_single_batch();
     }
 }
diff --git a/ipa-core/src/protocol/context/dzkp_malicious.rs b/ipa-core/src/protocol/context/dzkp_malicious.rs
index b553e41b5..73cfda40c 100644
--- a/ipa-core/src/protocol/context/dzkp_malicious.rs
+++ b/ipa-core/src/protocol/context/dzkp_malicious.rs
@@ -11,33 +11,33 @@ use crate::{
     helpers::{MpcMessage, MpcReceivingEnd, Role, SendingEnd, TotalRecords},
     protocol::{
         context::{
-            batcher::Batcher,
-            dzkp_validator::{Batch, Segment},
+            dzkp_validator::{Batch, MaliciousDZKPValidatorInner, Segment},
             prss::InstrumentedIndexedSharedRandomness,
-            step::ZeroKnowledgeProofValidateStep,
+            step::DzkpBatchStep,
             Context as ContextTrait, DZKPContext, InstrumentedSequentialSharedRandomness,
             MaliciousContext,
         },
         Gate, RecordId,
     },
     seq_join::SeqJoin,
-    sync::{Arc, Mutex, Weak},
+    sync::{Arc, Weak},
 };
 
-pub(super) type DzkpBatcher<'a> = Mutex<Batcher<'a, Batch>>;
-
 /// Represents protocol context in malicious setting when using zero-knowledge proofs,
 /// i.e. secure against one active adversary in 3 party MPC ring.
 #[derive(Clone)]
 pub struct DZKPUpgraded<'a> {
-    batcher: Weak<DzkpBatcher<'a>>,
+    validator_inner: Weak<MaliciousDZKPValidatorInner<'a>>,
     base_ctx: MaliciousContext<'a>,
 }
 
 impl<'a> DZKPUpgraded<'a> {
-    pub(super) fn new(batch: &Arc<DzkpBatcher<'a>>, base_ctx: MaliciousContext<'a>) -> Self {
+    pub(super) fn new(
+        validator_inner: &Arc<MaliciousDZKPValidatorInner<'a>>,
+        base_ctx: MaliciousContext<'a>,
+    ) -> Self {
         Self {
-            batcher: Arc::downgrade(batch),
+            validator_inner: Arc::downgrade(validator_inner),
             base_ctx,
         }
     }
@@ -49,10 +49,10 @@ impl<'a> DZKPUpgraded<'a> {
     }
 
     fn with_batch<C: FnOnce(&mut Batch) -> T, T>(&self, record_id: RecordId, action: C) -> T {
-        let batcher = self.batcher.upgrade().expect("Validator is active");
+        let validator_inner = self.validator_inner.upgrade().expect("Validator is active");
 
-        let mut batch = batcher.lock().unwrap();
-        let state = batch.get_batch(record_id);
+        let mut batcher = validator_inner.batcher.lock().unwrap();
+        let state = batcher.get_batch(record_id);
         (action)(&mut state.batch)
     }
 }
@@ -60,18 +60,16 @@ impl<'a> DZKPUpgraded<'a> {
 #[async_trait]
 impl<'a> DZKPContext for DZKPUpgraded<'a> {
     async fn validate_record(&self, record_id: RecordId) -> Result<(), Error> {
-        let validation_future = self
+        let validator_inner = self.validator_inner.upgrade().expect("validator is active");
+
+        let ctx = validator_inner.validate_ctx.clone();
+
+        let validation_future = validator_inner
             .batcher
-            .upgrade()
-            .expect("Validation batch is active")
             .lock()
             .unwrap()
             .validate_record(record_id, |batch_idx, batch| {
-                batch.validate(
-                    self.base_ctx
-                        .narrow(&ZeroKnowledgeProofValidateStep::DZKPValidate(batch_idx))
-                        .validator_context(),
-                )
+                batch.validate(ctx.narrow(&DzkpBatchStep(batch_idx)))
             });
 
         validation_future.await
diff --git a/ipa-core/src/protocol/context/dzkp_validator.rs b/ipa-core/src/protocol/context/dzkp_validator.rs
index 73a32e62b..517d4db46 100644
--- a/ipa-core/src/protocol/context/dzkp_validator.rs
+++ b/ipa-core/src/protocol/context/dzkp_validator.rs
@@ -3,6 +3,7 @@ use std::{cmp, collections::BTreeMap, fmt::Debug, future::ready};
 use async_trait::async_trait;
 use bitvec::prelude::{BitArray, BitSlice, Lsb0};
 use futures::{stream, Future, FutureExt, Stream, StreamExt};
+use ipa_step::StepNarrow;
 
 use crate::{
     error::{BoxError, Error},
@@ -12,17 +13,17 @@
         context::{
             batcher::Batcher,
             dzkp_field::{DZKPBaseField, UVTupleBlock},
-            dzkp_malicious::{DZKPUpgraded as MaliciousDZKPUpgraded, DzkpBatcher},
+            dzkp_malicious::DZKPUpgraded as MaliciousDZKPUpgraded,
             dzkp_semi_honest::DZKPUpgraded as SemiHonestDZKPUpgraded,
-            step::ZeroKnowledgeProofValidateStep as Step,
-            Base, Context, DZKPContext, MaliciousContext,
+            step::{DzkpSingleBatchStep, DzkpValidationProtocolStep as Step},
+            Base, Context, DZKPContext, MaliciousContext, MaliciousProtocolSteps,
         },
         ipa_prf::validation_protocol::{proof_generation::ProofBatch, validation::BatchToVerify},
         Gate, RecordId,
     },
     seq_join::{seq_join, SeqJoin},
     sharding::ShardBinding,
-    sync::Arc,
+    sync::{Arc, Mutex},
 };
 
 pub type Array256Bit = BitArray<[u8; 32], Lsb0>;
@@ -696,15 +697,23 @@ impl<'a, B: ShardBinding> DZKPValidator for SemiHonestDZKPValidator<'a, B> {
     }
 }
 
+type DzkpBatcher<'a> = Batcher<'a, Batch>;
+
+/// The DZKP validator, and all associated contexts, each hold a reference to a single
+/// instance of `MaliciousDZKPValidatorInner`.
+pub(super) struct MaliciousDZKPValidatorInner<'a> {
+    pub(super) batcher: Mutex<DzkpBatcher<'a>>,
+    pub(super) validate_ctx: Base<'a>,
+}
+
 /// `MaliciousDZKPValidator` corresponds to pub struct `Malicious` and implements the trait `DZKPValidator`
 /// The implementation of `validate` of the `DZKPValidator` trait depends on generic `DF`
 pub struct MaliciousDZKPValidator<'a> {
     // This is an `Option` because we want to consume it in `DZKPValidator::validate`,
     // but we also want to implement `Drop`. Note that the `is_verified` check in `Drop`
     // does nothing when `batcher_ref` is already `None`.
-    batcher_ref: Option<Arc<DzkpBatcher<'a>>>,
+    inner_ref: Option<Arc<MaliciousDZKPValidatorInner<'a>>>,
     protocol_ctx: MaliciousDZKPUpgraded<'a>,
-    validate_ctx: MaliciousContext<'a>,
 }
 
 #[async_trait]
@@ -716,30 +725,31 @@ impl<'a> DZKPValidator for MaliciousDZKPValidator<'a> {
     }
 
     fn set_total_records<T: Into<TotalRecords>>(&mut self, total_records: T) {
-        self.batcher_ref
+        self.inner_ref
             .as_ref()
-            .unwrap()
+            .expect("validator should be active")
+            .batcher
             .lock()
             .unwrap()
             .set_total_records(total_records);
     }
 
     async fn validate(mut self) -> Result<(), Error> {
-        let batcher_arc = self
-            .batcher_ref
+        let arc = self
+            .inner_ref
             .take()
             .expect("nothing else should be consuming the batcher");
-        let batcher_mutex = Arc::into_inner(batcher_arc)
+        let MaliciousDZKPValidatorInner {
+            batcher: batcher_mutex,
+            validate_ctx,
+        } = Arc::into_inner(arc)
             .expect("validator should hold the only strong reference to batcher");
+
         let batcher = batcher_mutex.into_inner().unwrap();
 
         batcher
             .into_single_batch()
-            .validate(
-                self.validate_ctx
-                    .narrow(&Step::DZKPValidate(0))
-                    .validator_context(),
-            )
+            .validate(validate_ctx.narrow(&DzkpSingleBatchStep))
             .await
     }
 
@@ -749,7 +759,13 @@
 
     /// ## Errors
     /// Errors when there are `MultiplicationInputs` that have not been verified.
     fn is_verified(&self) -> Result<(), Error> {
-        let batcher = self.batcher_ref.as_ref().unwrap().lock().unwrap();
+        let batcher = self
+            .inner_ref
+            .as_ref()
+            .expect("validator should be active")
+            .batcher
+            .lock()
+            .unwrap();
         if batcher.is_empty() {
             Ok(())
         } else {
@@ -760,7 +776,16 @@
 
 impl<'a> MaliciousDZKPValidator<'a> {
     #[must_use]
-    pub fn new(ctx: MaliciousContext<'a>, max_multiplications_per_gate: usize) -> Self {
+    #[allow(clippy::needless_pass_by_value)]
+    pub fn new<S>(
+        ctx: MaliciousContext<'a>,
+        steps: MaliciousProtocolSteps<S>,
+        max_multiplications_per_gate: usize,
+    ) -> Self
+    where
+        Gate: StepNarrow<S>,
+        S: ipa_step::Step + ?Sized,
+    {
         let batcher = Batcher::new(
             max_multiplications_per_gate,
             ctx.total_records(),
@@ -771,19 +796,21 @@
             )
             }),
         );
-        let protocol_ctx =
-            MaliciousDZKPUpgraded::new(&batcher, ctx.narrow(&Step::DZKPMaliciousProtocol));
+        let inner = Arc::new(MaliciousDZKPValidatorInner {
+            batcher,
+            validate_ctx: ctx.narrow(steps.validate).validator_context(),
+        });
+        let protocol_ctx = MaliciousDZKPUpgraded::new(&inner, ctx.narrow(steps.protocol));
         Self {
-            batcher_ref: Some(batcher),
+            inner_ref: Some(inner),
             protocol_ctx,
-            validate_ctx: ctx,
         }
     }
 }
 
 impl<'a> Drop for MaliciousDZKPValidator<'a> {
     fn drop(&mut self) {
-        if self.batcher_ref.is_some() {
+        if self.inner_ref.is_some() {
             self.is_verified().unwrap();
         }
     }
@@ -811,10 +838,9 @@ mod tests {
         context::{
             dzkp_field::{DZKPCompatibleField, BLOCK_SIZE},
             dzkp_validator::{
-                Batch, DZKPValidator, Segment, SegmentEntry, Step, BIT_ARRAY_LEN,
-                TARGET_PROOF_SIZE,
+                Batch, DZKPValidator, Segment, SegmentEntry, BIT_ARRAY_LEN, TARGET_PROOF_SIZE,
             },
-            Context, UpgradableContext,
+            Context, UpgradableContext, TEST_DZKP_STEPS,
         },
         Gate, RecordId,
     },
@@ -839,11 +865,8 @@
             .malicious(
                 original_inputs.clone().into_iter(),
                 |ctx, input_shares| async move {
-                    let v = ctx.dzkp_validator(COUNT);
-                    let m_ctx = v
-                        .context()
-                        .narrow(&Step::DZKPMaliciousProtocol)
-                        .set_total_records(COUNT - 1);
+                    let v = ctx.dzkp_validator(TEST_DZKP_STEPS, COUNT);
+                    let m_ctx = v.context().set_total_records(COUNT - 1);
 
                     let m_results = seq_join(
                         NonZeroUsize::new(COUNT).unwrap(),
@@ -922,7 +945,7 @@
             .map(|(ctx, input_shares)| async move {
                 let v = ctx
                     .set_total_records(count - 1)
                    .dzkp_validator(TEST_DZKP_STEPS, ctx.active_work().get());
                 let m_ctx = v.context();
 
                 let m_results = v
@@ -951,7 +974,7 @@
             .into_iter()
             .zip([h1_shares, h2_shares, h3_shares])
             .map(|(ctx, input_shares)| async move {
-                let v = ctx.dzkp_validator(max_multiplications_per_gate);
+                let v = ctx.dzkp_validator(TEST_DZKP_STEPS, max_multiplications_per_gate);
                 let m_ctx = v.context();
 
                 let m_results = v
@@ -1311,14 +1334,16 @@
 
         let [h1_batch, h2_batch, h3_batch] = world
             .malicious((a, b), |ctx, (a, b)| async move {
-                let mut validator = ctx.dzkp_validator(10);
+                let mut validator = ctx.dzkp_validator(TEST_DZKP_STEPS, 10);
                 let mctx = validator.context();
                 let _ = a
                     .multiply(&b, mctx.set_total_records(1), RecordId::from(0))
                     .await
                     .unwrap();
 
-                let batcher_mutex = Arc::into_inner(validator.batcher_ref.take().unwrap()).unwrap();
+                let batcher_mutex = Arc::into_inner(validator.inner_ref.take().unwrap())
+                    .unwrap()
+                    .batcher;
                 batcher_mutex.into_inner().unwrap().into_single_batch()
             })
             .await;
diff --git a/ipa-core/src/protocol/context/malicious.rs b/ipa-core/src/protocol/context/malicious.rs
index d001b8ab6..a93a8edfb 100644
--- a/ipa-core/src/protocol/context/malicious.rs
+++ b/ipa-core/src/protocol/context/malicious.rs
@@ -13,11 +13,14 @@ use crate::{
     protocol::{
         basics::mul::{semi_honest_multiply, step::MaliciousMultiplyStep::RandomnessForValidation},
         context::{
-            batcher::Batcher, dzkp_validator::MaliciousDZKPValidator,
-            prss::InstrumentedIndexedSharedRandomness, step::UpgradeStep, upgrade::Upgradable,
-            validator, validator::BatchValidator, Base, Context as ContextTrait,
-            InstrumentedSequentialSharedRandomness, SpecialAccessToUpgradedContext,
-            UpgradableContext, UpgradedContext,
+            batcher::Batcher,
+            dzkp_validator::MaliciousDZKPValidator,
+            prss::InstrumentedIndexedSharedRandomness,
+            step::UpgradeStep,
+            upgrade::Upgradable,
+            validator::{self, BatchValidator},
+            Base, Context as ContextTrait, InstrumentedSequentialSharedRandomness,
+            SpecialAccessToUpgradedContext, UpgradableContext, UpgradedContext,
         },
         prss::{Endpoint as PrssEndpoint, FromPrss},
         Gate, RecordId,
@@ -31,6 +34,20 @@ use crate::{
     sync::Arc,
 };
 
+pub struct MaliciousProtocolSteps<'a, S: Step + ?Sized> {
+    pub protocol: &'a S,
+    pub validate: &'a S,
+}
+
+#[cfg(all(feature = "in-memory-infra", any(test, feature = "test-fixture")))]
+pub(crate) const TEST_DZKP_STEPS: MaliciousProtocolSteps<
+    'static,
+    super::step::MaliciousProtocolStep,
+> = MaliciousProtocolSteps {
+    protocol: &super::step::MaliciousProtocolStep::MaliciousProtocol,
+    validate: &super::step::MaliciousProtocolStep::Validate,
+};
+
 #[derive(Clone)]
 pub struct Context<'a> {
     inner: Base<'a>,
@@ -122,8 +139,16 @@ impl<'a> UpgradableContext for Context<'a> {
 
     type DZKPValidator = MaliciousDZKPValidator<'a>;
 
-    fn dzkp_validator(self, max_multiplications_per_gate: usize) -> Self::DZKPValidator {
-        MaliciousDZKPValidator::new(self, max_multiplications_per_gate)
+    fn dzkp_validator<S>(
+        self,
+        steps: MaliciousProtocolSteps<S>,
+        max_multiplications_per_gate: usize,
+    ) -> Self::DZKPValidator
+    where
+        Gate: StepNarrow<S>,
+        S: Step + ?Sized,
+    {
+        MaliciousDZKPValidator::new(self, steps, max_multiplications_per_gate)
     }
 }
 
diff --git a/ipa-core/src/protocol/context/mod.rs b/ipa-core/src/protocol/context/mod.rs
index af2e0142b..4a090bae3 100644
--- a/ipa-core/src/protocol/context/mod.rs
+++ b/ipa-core/src/protocol/context/mod.rs
@@ -9,9 +9,6 @@ pub mod step;
 pub mod upgrade;
 
 mod batcher;
-/// Validators are not used in IPA v3 yet. Once we make use of MAC-based validation,
-/// this flag can be removed
-#[allow(dead_code)]
 pub mod validator;
 
 use std::{collections::HashMap, iter, num::NonZeroUsize, pin::pin};
@@ -21,13 +18,18 @@
 pub use dzkp_malicious::DZKPUpgraded as DZKPUpgradedMaliciousContext;
 pub use dzkp_semi_honest::DZKPUpgraded as DZKPUpgradedSemiHonestContext;
 use futures::{stream, Stream, StreamExt};
 use ipa_step::{Step, StepNarrow};
-pub use malicious::{Context as MaliciousContext, Upgraded as UpgradedMaliciousContext};
+pub use malicious::{
+    Context as MaliciousContext, MaliciousProtocolSteps, Upgraded as UpgradedMaliciousContext,
+};
 use prss::{InstrumentedIndexedSharedRandomness, InstrumentedSequentialSharedRandomness};
 pub use semi_honest::Upgraded as UpgradedSemiHonestContext;
 pub use validator::Validator;
 
 pub type SemiHonestContext<'a, B = NotSharded> = semi_honest::Context<'a, B>;
 pub type ShardedSemiHonestContext<'a> = semi_honest::Context<'a, Sharded>;
 
+#[cfg(all(feature = "in-memory-infra", any(test, feature = "test-fixture")))]
+pub(crate) use malicious::TEST_DZKP_STEPS;
+
 use crate::{
     error::Error,
     helpers::{
@@ -109,7 +111,14 @@ pub trait UpgradableContext: Context {
 
     type DZKPValidator: DZKPValidator;
 
-    fn dzkp_validator(self, max_multiplications_per_gate: usize) -> Self::DZKPValidator;
+    fn dzkp_validator<S>(
+        self,
+        steps: MaliciousProtocolSteps<S>,
+        max_multiplications_per_gate: usize,
+    ) -> Self::DZKPValidator
+    where
+        Gate: StepNarrow<S>,
+        S: Step + ?Sized;
 }
 
 pub type MacUpgraded<C, F> = <<C as UpgradableContext>::Validator<F> as Validator<F>>::Context;
 
diff --git a/ipa-core/src/protocol/context/semi_honest.rs b/ipa-core/src/protocol/context/semi_honest.rs
index 1be359879..65f4e644e 100644
--- a/ipa-core/src/protocol/context/semi_honest.rs
+++ b/ipa-core/src/protocol/context/semi_honest.rs
@@ -16,9 +16,10 @@ use crate::{
     },
     protocol::{
         context::{
-            dzkp_validator::SemiHonestDZKPValidator, upgrade::Upgradable,
-            validator::SemiHonest as Validator, Base, InstrumentedIndexedSharedRandomness,
-            InstrumentedSequentialSharedRandomness, ShardedContext, SpecialAccessToUpgradedContext,
+            dzkp_validator::SemiHonestDZKPValidator, step::MaliciousProtocolStep,
+            upgrade::Upgradable, validator::SemiHonest as Validator, Base, Context as _,
+            InstrumentedIndexedSharedRandomness, InstrumentedSequentialSharedRandomness,
+            MaliciousProtocolSteps, ShardedContext, SpecialAccessToUpgradedContext,
             UpgradableContext, UpgradedContext,
         },
         prss::Endpoint as PrssEndpoint,
@@ -151,13 +152,21 @@ impl<'a, B: ShardBinding> UpgradableContext for Context<'a, B> {
     type Validator<F: ExtendableField> = Validator<'a, B, F>;
 
     fn validator<F: ExtendableField>(self) -> Self::Validator<F> {
-        Self::Validator::new(self.inner)
+        Self::Validator::new(self.inner.narrow(&MaliciousProtocolStep::MaliciousProtocol))
     }
 
     type DZKPValidator = SemiHonestDZKPValidator<'a, B>;
 
-    fn dzkp_validator(self, _max_multiplications_per_gate: usize) -> Self::DZKPValidator {
-        Self::DZKPValidator::new(self.inner)
+    fn dzkp_validator<S>(
+        self,
+        steps: MaliciousProtocolSteps<S>,
+        _max_multiplications_per_gate: usize,
+    ) -> Self::DZKPValidator
+    where
+        S: ipa_step::Step + ?Sized,
+        Gate: StepNarrow<S>,
+    {
+        Self::DZKPValidator::new(self.inner.narrow(steps.protocol))
     }
 }
 
diff --git a/ipa-core/src/protocol/context/step.rs b/ipa-core/src/protocol/context/step.rs
index a3fcdab02..d650340f8 100644
--- a/ipa-core/src/protocol/context/step.rs
+++ b/ipa-core/src/protocol/context/step.rs
@@ -2,7 +2,7 @@ use ipa_step_derive::CompactStep;
 
 /// Upgrades all use this step to distinguish protocol steps from the step that is used to upgrade inputs.
#[derive(CompactStep)] -#[step(name = "upgrade")] +#[step(name = "upgrade", child = crate::protocol::basics::mul::step::MaliciousMultiplyStep)] pub(crate) struct UpgradeStep; /// Steps used by the validation component of malicious protocol execution. @@ -10,8 +10,10 @@ pub(crate) struct UpgradeStep; #[derive(CompactStep)] pub(crate) enum MaliciousProtocolStep { /// For the execution of the malicious protocol. + #[step(child = crate::protocol::ipa_prf::step::PrfStep)] MaliciousProtocol, /// The final validation steps. + #[step(child = ValidateStep)] Validate, } @@ -22,23 +24,37 @@ pub(crate) enum ValidateStep { /// Reveal the value of `r`, necessary for validation. RevealR, /// Check that there is no disagreement between accumulated values. + #[step(child = crate::protocol::basics::step::CheckZeroStep)] CheckZero, } -/// Steps used by the validation component of the DZKP +// This really is only for DZKPs and not for MACs. The MAC protocol uses record IDs to +// count batches. DZKP probably should do the same to avoid the fixed upper limit. #[derive(CompactStep)] -pub(crate) enum ZeroKnowledgeProofValidateStep { - /// For the execution of the malicious protocol. - DZKPMaliciousProtocol, - /// Step for computing `p * q` between proof verifiers - PTimesQ, - /// Step for producing challenge between proof verifiers - Challenge, - /// Steps for validating the DZK proofs for each batch. - #[step(count = 256)] - DZKPValidate(usize), +#[step(count = 192, child = DzkpValidationProtocolStep)] +pub(crate) struct DzkpBatchStep(pub usize); + +// This is used when we don't do batched verification, to avoid paying for x256 as many +// steps in compact gate. +#[derive(CompactStep)] +#[step(child = DzkpValidationProtocolStep)] +pub(crate) struct DzkpSingleBatchStep; + +#[derive(CompactStep)] +pub(crate) enum DzkpValidationProtocolStep { /// Step for proof generation GenerateProof, + /// Step for producing challenge between proof verifiers + Challenge, /// Step for proof verification + #[step(child = DzkpProofVerifyStep)] VerifyProof, } + +#[derive(CompactStep)] +pub(crate) enum DzkpProofVerifyStep { + /// Step for computing `p * q` between proof verifiers + PTimesQ, + /// Step for computing `G_diff` between proof verifiers + Diff, +} diff --git a/ipa-core/src/protocol/context/validator.rs b/ipa-core/src/protocol/context/validator.rs index 63a212a2c..33303fb9b 100644 --- a/ipa-core/src/protocol/context/validator.rs +++ b/ipa-core/src/protocol/context/validator.rs @@ -221,11 +221,11 @@ impl<'a, F: ExtendableField> BatchValidator<'a, F> { Self { protocol_ctx: ctx.narrow(&Step::MaliciousProtocol), - batches_ref: Batcher::new( + batches_ref: Arc::new(Batcher::new( records_per_batch, total_records, Box::new(move |batch_index| Malicious::new(ctx.clone(), batch_index)), - ), + )), } } } diff --git a/ipa-core/src/protocol/dp/mod.rs b/ipa-core/src/protocol/dp/mod.rs index a44b73e6d..bb847eabd 100644 --- a/ipa-core/src/protocol/dp/mod.rs +++ b/ipa-core/src/protocol/dp/mod.rs @@ -15,12 +15,16 @@ use crate::{ helpers::{query::DpMechanism, Direction, Role, TotalRecords}, protocol::{ boolean::step::ThirtyTwoBitStep, - context::{dzkp_validator::DZKPValidator, Context, DZKPUpgraded, UpgradableContext}, + context::{ + dzkp_validator::DZKPValidator, Context, DZKPUpgraded, MaliciousProtocolSteps, + UpgradableContext, + }, dp::step::{ApplyDpNoise, DPStep}, ipa_prf::{ aggregation::{aggregate_values, aggregate_values_proof_chunk}, boolean_ops::addition_sequential::integer_add, oprf_padding::insecure::OPRFPaddingDp, + 
step::IpaPrfStep,
         },
         prss::{FromPrss, SharedRandomness},
         BooleanProtocols, RecordId,
@@ -224,6 +228,7 @@
 /// # Panics
 /// may panic from asserts down in `gen_binomial_noise`
 ///
+#[allow(clippy::too_many_lines)]
 pub async fn dp_for_histogram<C, const B: usize, OV, const SS_BITS: usize>(
     ctx: C,
     histogram_bin_values: BitDecomposed<AdditiveShare<Boolean, B>>,
@@ -240,6 +245,10 @@
 where
     BitDecomposed<AdditiveShare<Boolean, B>>:
         for<'a> TransposeFrom<&'a [AdditiveShare<OV>; B], Error = Infallible>,
 {
+    let steps = MaliciousProtocolSteps {
+        protocol: &IpaPrfStep::DifferentialPrivacy,
+        validate: &IpaPrfStep::DifferentialPrivacyValidate,
+    };
     match dp_params {
         DpMechanism::NoDp => Ok(Vec::transposed_from(&histogram_bin_values)?),
         DpMechanism::Binomial { epsilon } => {
@@ -286,7 +295,8 @@
                 "num_bernoulli of {num_bernoulli} may result in excessively large DZKP"
             );
             }
-            let dp_validator = ctx.dzkp_validator(num_bernoulli);
+
+            let dp_validator = ctx.dzkp_validator(steps, num_bernoulli);
 
             let noisy_histogram = apply_dp_noise::<_, B, OV>(
                 dp_validator.context(),
@@ -328,7 +338,7 @@
                 OV::BITS,
             );
 
-            let dp_validator = ctx.dzkp_validator(1);
+            let dp_validator = ctx.dzkp_validator(steps, 1);
 
             let noised_output = apply_laplace_noise_pass::<_, OV, B>(
                 &dp_validator.context().narrow(&DPStep::LaplacePass1),
diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs b/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs
index 9248e5a65..69663f5d5 100644
--- a/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs
+++ b/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs
@@ -18,7 +18,7 @@ use crate::{
         boolean::{step::ThirtyTwoBitStep, NBitStep},
         context::{
             dzkp_validator::{DZKPValidator, TARGET_PROOF_SIZE},
-            Context, DZKPContext, UpgradableContext,
+            Context, DZKPContext, MaliciousProtocolSteps, UpgradableContext,
         },
         ipa_prf::{
             aggregation::step::{AggregateChunkStep, AggregateValuesStep, AggregationStep as Step},
@@ -153,9 +153,14 @@
     let move_to_bucket_records =
         TotalRecords::specified(div_round_up(contributions_stream_len, Const::<AGG_CHUNK>))?;
     let validator = ctx
-        .narrow(&Step::MoveToBucket)
         .set_total_records(move_to_bucket_records)
-        .dzkp_validator(move_to_bucket_chunk_size);
+        .dzkp_validator(
+            MaliciousProtocolSteps {
+                protocol: &Step::MoveToBucket,
+                validate: &Step::MoveToBucketValidate,
+            },
+            move_to_bucket_chunk_size,
+        );
     let bucket_ctx = validator.context();
     // move each value to the correct bucket
     let row_contribution_chunk_stream = process_stream_by_chunks(
@@ -223,11 +228,17 @@
     });
     let mut intermediate_results = Vec::new();
     let mut chunk_counter = 0;
+
     for chunk in chunks {
-        let ctx = ctx.narrow(&Step::AggregateChunk(chunk_counter));
         chunk_counter += 1;
         let stream = aggregation_input.by_ref().take(chunk);
-        let validator = ctx.dzkp_validator(agg_proof_chunk);
+        let validator = ctx.clone().dzkp_validator(
+            MaliciousProtocolSteps {
+                protocol: &Step::AggregateChunk(chunk_counter),
+                validate: &Step::AggregateChunkValidate(chunk_counter),
+            },
+            agg_proof_chunk,
+        );
         let result =
             aggregate_values::<_, HV, B>(validator.context(), stream.boxed(), chunk).await?;
         validator.validate().await?;
@@ -235,9 +246,14 @@
     }
 
     if intermediate_results.len() > 1 {
-        let ctx = ctx.narrow(&Step::AggregateChunk(chunk_counter));
-        let validator = ctx.dzkp_validator(agg_proof_chunk);
         let stream_len = intermediate_results.len();
+        let validator = ctx.dzkp_validator(
+            MaliciousProtocolSteps {
+                protocol: &Step::AggregateChunk(chunk_counter),
+                validate: &Step::AggregateChunkValidate(chunk_counter),
+            },
+            agg_proof_chunk,
+        );
         let aggregated_result = aggregate_values::<_, HV, B>(
             validator.context(),
            stream::iter(intermediate_results).boxed(),
diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/step.rs b/ipa-core/src/protocol/ipa_prf/aggregation/step.rs
index f75654088..2b1995fd2 100644
--- a/ipa-core/src/protocol/ipa_prf/aggregation/step.rs
+++ b/ipa-core/src/protocol/ipa_prf/aggregation/step.rs
@@ -12,8 +12,12 @@ pub(crate) enum AggregationStep {
     RevealStep,
     #[step(child = BucketStep)]
     MoveToBucket,
+    #[step(child = crate::protocol::context::step::DzkpBatchStep)]
+    MoveToBucketValidate,
     #[step(count = 32, child = AggregateChunkStep)]
     AggregateChunk(usize),
+    #[step(count = 32, child = crate::protocol::context::step::DzkpSingleBatchStep)]
+    AggregateChunkValidate(usize),
 }
 
 /// the number of steps must be kept in sync with `MAX_BREAKDOWNS` defined
diff --git a/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs b/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs
index 22a6c4ae5..05e3d3b62 100644
--- a/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs
+++ b/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs
@@ -380,7 +380,7 @@ mod tests {
         ff::{boolean_array::BA64, Serializable},
         helpers::{repeat_n, stream::process_slice_by_chunks},
         protocol::{
-            context::{dzkp_validator::DZKPValidator, UpgradableContext},
+            context::{dzkp_validator::DZKPValidator, UpgradableContext, TEST_DZKP_STEPS},
             ipa_prf::{CONV_CHUNK, CONV_PROOF_CHUNK, PRF_CHUNK},
         },
         rand::thread_rng,
@@ -415,7 +415,7 @@
         let [res0, res1, res2] = world
             .semi_honest(records.into_iter(), |ctx, records| async move {
                 let c_ctx = ctx.set_total_records((COUNT + CONV_CHUNK - 1) / CONV_CHUNK);
-                let validator = &c_ctx.dzkp_validator(CONV_PROOF_CHUNK);
+                let validator = &c_ctx.dzkp_validator(TEST_DZKP_STEPS, CONV_PROOF_CHUNK);
                 let m_ctx = validator.context();
                 seq_join(
                     m_ctx.active_work(),
@@ -477,7 +477,7 @@
         let [res0, res1, res2] = world
             .malicious(records.into_iter(), |ctx, records| async move {
                 let c_ctx = ctx.set_total_records(TOTAL_RECORDS);
-                let validator = &c_ctx.dzkp_validator(PROOF_CHUNK);
+                let validator = &c_ctx.dzkp_validator(TEST_DZKP_STEPS, PROOF_CHUNK);
                 let m_ctx = validator.context();
                 seq_join(
                     m_ctx.active_work(),
@@ -518,7 +518,7 @@
         TestWorld::default()
             .semi_honest(iter::empty::<BA64>(), |ctx, _records| async move {
                 let c_ctx = ctx.set_total_records(1);
-                let validator = &c_ctx.dzkp_validator(1);
+                let validator = &c_ctx.dzkp_validator(TEST_DZKP_STEPS, 1);
                 let m_ctx = validator.context();
                 let match_keys = BitDecomposed::new(repeat_n(
                     AdditiveShare::<Boolean, CONV_CHUNK>::ZERO,
diff --git a/ipa-core/src/protocol/ipa_prf/mod.rs b/ipa-core/src/protocol/ipa_prf/mod.rs
index 7fdfa7ef0..994713911 100644
--- a/ipa-core/src/protocol/ipa_prf/mod.rs
+++ b/ipa-core/src/protocol/ipa_prf/mod.rs
@@ -20,7 +20,10 @@ use crate::{
     },
     protocol::{
         basics::{BooleanArrayMul, BooleanProtocols, Reveal},
-        context::{dzkp_validator::DZKPValidator, DZKPUpgraded, MacUpgraded, UpgradableContext},
+        context::{
+            dzkp_validator::DZKPValidator, DZKPUpgraded, MacUpgraded, MaliciousProtocolSteps,
+            UpgradableContext,
+        },
         ipa_prf::{
             boolean_ops::convert_to_fp25519,
             oprf_padding::apply_dp_padding,
@@ -28,6 +31,7 @@
             prf_sharding::{
                 attribute_cap_aggregate, histograms_ranges_sortkeys, PrfShardedIpaInputRow,
             },
+            step::IpaPrfStep,
         },
         prss::FromPrss,
         RecordId,
@@ -296,12 +300,8 @@
     )
     .await?;
 
-    let noisy_output_histogram = dp_for_histogram::<_, B, HV, SS_BITS>(
-        ctx.narrow(&Step::DifferentialPrivacy),
-        output_histogram,
-        dp_params,
-    )
-    .await?;
+    let noisy_output_histogram =
+        dp_for_histogram::<_, B, HV, SS_BITS>(ctx, output_histogram, dp_params).await?;
 
     Ok(noisy_output_histogram)
 }
@@ -330,11 +330,15 @@
     let conv_records =
         TotalRecords::specified(div_round_up(input_rows.len(), Const::<CONV_CHUNK>))?;
     let eval_records = TotalRecords::specified(div_round_up(input_rows.len(), Const::<PRF_CHUNK>))?;
-    let convert_ctx = ctx
-        .narrow(&Step::ConvertFp25519)
-        .set_total_records(conv_records);
+    let convert_ctx = ctx.set_total_records(conv_records);
 
-    let validator = convert_ctx.dzkp_validator(CONV_PROOF_CHUNK);
+    let validator = convert_ctx.dzkp_validator(
+        MaliciousProtocolSteps {
+            protocol: &Step::ConvertFp25519,
+            validate: &Step::ConvertFp25519Validate,
+        },
+        CONV_PROOF_CHUNK,
+    );
     let m_ctx = validator.context();
 
     let curve_pts = seq_join(
@@ -354,9 +358,11 @@
         .try_collect::<Vec<_>>()
         .await?;
 
-    let eval_ctx = ctx.narrow(&Step::EvalPrf).set_total_records(eval_records);
-    let prf_key = gen_prf_key(&eval_ctx);
-    let validator = eval_ctx.validator::<Fp25519>();
+    let prf_key = gen_prf_key(&ctx.narrow(&IpaPrfStep::PrfKeyGen));
+    let validator = ctx
+        .narrow(&Step::EvalPrf)
+        .set_total_records(eval_records)
+        .validator::<Fp25519>();
     let eval_ctx = validator.context();
 
     let prf_of_match_keys = seq_join(
diff --git a/ipa-core/src/protocol/ipa_prf/prf_eval.rs b/ipa-core/src/protocol/ipa_prf/prf_eval.rs
index 1870c9ef7..60cc41bdc 100644
--- a/ipa-core/src/protocol/ipa_prf/prf_eval.rs
+++ b/ipa-core/src/protocol/ipa_prf/prf_eval.rs
@@ -78,7 +78,6 @@
 where
     C: UpgradableContext,
     Fp25519: Vectorizable<N>,
 {
-    let ctx = ctx.narrow(&Step::PRFKeyGen);
     let v: AdditiveShare<Fp25519, 1> = ctx.prss().generate(RecordId::FIRST);
 
     v.expand()
diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs
index e91880b92..69b645055 100644
--- a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs
+++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs
@@ -30,9 +30,10 @@ use crate::{
         },
         context::{
             dzkp_validator::{DZKPValidator, TARGET_PROOF_SIZE},
-            Context, DZKPContext, DZKPUpgraded, UpgradableContext,
+            Context, DZKPContext, DZKPUpgraded, MaliciousProtocolSteps, UpgradableContext,
         },
         ipa_prf::{
+            aggregation::{aggregate_values_proof_chunk, step::AggregationStep},
             boolean_ops::{
                 addition_sequential::integer_add,
                 comparison_and_subtraction_sequential::{compare_gt, integer_sub},
@@ -385,18 +386,10 @@
     (histogram, ranges)
 }
 
-fn set_up_contexts<C>(
-    root_ctx: C,
-    chunk_size: usize,
-    histogram: &[usize],
-) -> Result<(C::DZKPValidator, Vec<DZKPUpgraded<C>>), Error>
+fn set_up_contexts<C>(ctx: &C, histogram: &[usize]) -> Result<Vec<C>, Error>
 where
-    C: UpgradableContext,
+    C: Context,
 {
-    let mut dzkp_validator = root_ctx.dzkp_validator(chunk_size);
-    let ctx = dzkp_validator.context();
-    dzkp_validator.set_total_records(TotalRecords::specified(histogram[1]).unwrap());
-
     let mut context_per_row_depth = Vec::with_capacity(histogram.len());
     for (row_number, num_users_having_that_row_number) in histogram.iter().enumerate() {
         if row_number == 0 {
@@ -409,7 +402,7 @@
             context_per_row_depth.push(ctx_for_row_number);
         }
     }
-    Ok((dzkp_validator, context_per_row_depth))
+    Ok(context_per_row_depth)
 }
 
 ///
@@ -513,8 +506,15 @@
     // Tricky hacks to work around the limitations of our current infrastructure
     let num_outputs = input_rows.len() - histogram[0];
-    let (dzkp_validator, ctx_for_row_number) =
-        set_up_contexts(sh_ctx.narrow(&Step::Attribute), chunk_size, histogram)?;
+    let mut dzkp_validator = sh_ctx.clone().dzkp_validator(
+        MaliciousProtocolSteps {
+            protocol: &Step::Attribute,
+            validate: &Step::AttributeValidate,
+        },
+        chunk_size,
+    );
+    dzkp_validator.set_total_records(TotalRecords::specified(histogram[1]).unwrap());
+    let ctx_for_row_number = set_up_contexts(&dzkp_validator.context(), histogram)?;
 
     // Chunk the incoming stream of records into stream of vectors of records with the same PRF
     let mut input_stream = stream::iter(input_rows);
@@ -535,23 +535,28 @@
         attribution_window_seconds,
     );
 
-    let aggregation_validator = sh_ctx.narrow(&Step::Aggregate).dzkp_validator(0);
-    let ctx = aggregation_validator.context();
+    let ctx = sh_ctx.narrow(&Step::Aggregate);
 
     // New aggregation is still experimental, we need proofs that it is private,
     // hence it is only enabled behind a feature flag.
     if cfg!(feature = "reveal-aggregation") {
         // If there was any error in attribution we stop the execution with an error
         tracing::warn!("Using the experimental aggregation based on revealing breakdown keys");
+        let validator = ctx.dzkp_validator(
+            MaliciousProtocolSteps {
+                protocol: &AggregationStep::AggregateChunk(0),
+                validate: &AggregationStep::AggregateChunkValidate(0),
+            },
+            aggregate_values_proof_chunk(B, usize::try_from(TV::BITS).unwrap()),
+        );
         let user_contributions = flattened_user_results.try_collect::<Vec<_>>().await?;
-        breakdown_reveal_aggregation::<_, _, _, HV, B>(ctx, user_contributions).await
+        let result =
+            breakdown_reveal_aggregation::<_, _, _, HV, B>(validator.context(), user_contributions)
+                .await;
+        validator.validate().await?;
+        result
     } else {
-        aggregate_contributions::<_, _, _, _, HV, B>(
-            sh_ctx.narrow(&Step::Aggregate),
-            flattened_user_results,
-            num_outputs,
-        )
-        .await
+        aggregate_contributions::<_, _, _, _, HV, B>(ctx, flattened_user_results, num_outputs).await
     }
 }
 
diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/step.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/step.rs
index d3f7c9edb..e32de5a40 100644
--- a/ipa-core/src/protocol/ipa_prf/prf_sharding/step.rs
+++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/step.rs
@@ -16,6 +16,8 @@ impl From<usize> for UserNthRowStep {
 pub(crate) enum AttributionStep {
     #[step(child = UserNthRowStep)]
     Attribute,
+    #[step(child = crate::protocol::context::step::DzkpBatchStep)]
+    AttributeValidate,
     #[step(child = crate::protocol::ipa_prf::aggregation::step::AggregationStep)]
     Aggregate,
 }
diff --git a/ipa-core/src/protocol/ipa_prf/quicksort.rs b/ipa-core/src/protocol/ipa_prf/quicksort.rs
index c3794c222..c198dbe90 100644
--- a/ipa-core/src/protocol/ipa_prf/quicksort.rs
+++ b/ipa-core/src/protocol/ipa_prf/quicksort.rs
@@ -14,7 +14,10 @@ use crate::{
     protocol::{
         basics::reveal,
         boolean::{step::ThirtyTwoBitStep, NBitStep},
-        context::{dzkp_validator::DZKPValidator, Context, DZKPUpgraded, UpgradableContext},
+        context::{
+            dzkp_validator::DZKPValidator, Context, DZKPUpgraded, MaliciousProtocolSteps,
+            UpgradableContext,
+        },
         ipa_prf::{
             boolean_ops::comparison_and_subtraction_sequential::compare_gt,
             step::{QuicksortPassStep, QuicksortStep as Step},
@@ -166,12 +169,15 @@
     let total_records_usize = div_round_up(num_comparisons_needed, Const::<SORT_CHUNK>);
     let total_records = TotalRecords::specified(total_records_usize)
         .expect("num_comparisons_needed should not be zero");
-    let v = ctx
-        .narrow(&Step::QuicksortPass(quicksort_pass))
-        .set_total_records(total_records)
+    let v = ctx.set_total_records(total_records).dzkp_validator(
+        MaliciousProtocolSteps {
+            protocol: &Step::QuicksortPass(quicksort_pass),
+            validate: &Step::QuicksortPassValidate(quicksort_pass),
+        },
         // TODO: use something like this when validating in chunks
-            //.dzkp_validator(TARGET_PROOF_SIZE / usize::try_from(K::BITS).unwrap() / SORT_CHUNK);
-            .dzkp_validator(total_records_usize);
+        // `TARGET_PROOF_SIZE / usize::try_from(K::BITS).unwrap() / SORT_CHUNK`
+        total_records_usize,
+    );
     let c = v.context();
     let cmp_ctx = c.narrow(&QuicksortPassStep::Compare);
     let rvl_ctx = c.narrow(&QuicksortPassStep::Reveal);
diff --git a/ipa-core/src/protocol/ipa_prf/step.rs b/ipa-core/src/protocol/ipa_prf/step.rs
index 2f0ab5d92..633f157cc 100644
--- a/ipa-core/src/protocol/ipa_prf/step.rs
+++ b/ipa-core/src/protocol/ipa_prf/step.rs
@@ -6,10 +6,12 @@ pub(crate) enum IpaPrfStep {
     PaddingDp,
     #[step(child = crate::protocol::ipa_prf::shuffle::step::OPRFShuffleStep)]
     Shuffle,
-    // ConvertInputRowsToPrf,
     #[step(child = crate::protocol::ipa_prf::boolean_ops::step::Fp25519ConversionStep)]
     ConvertFp25519,
-    #[step(child = PrfStep)]
+    #[step(child = crate::protocol::context::step::DzkpBatchStep)]
+    ConvertFp25519Validate,
+    PrfKeyGen,
+    #[step(child = crate::protocol::context::step::MaliciousProtocolStep)]
     EvalPrf,
     #[step(child = QuicksortStep)]
     SortByTimestamp,
@@ -17,6 +19,8 @@
     Attribution,
     #[step(child = crate::protocol::dp::step::DPStep, name = "dp")]
     DifferentialPrivacy,
+    #[step(child = crate::protocol::context::step::DzkpSingleBatchStep)]
+    DifferentialPrivacyValidate,
 }
 
 #[derive(CompactStep)]
@@ -24,6 +28,8 @@ pub(crate) enum QuicksortStep {
     /// Sort up to 1B rows. We can't exceed that limit for other reasons as well `record_id`.
     #[step(count = 30, child = crate::protocol::ipa_prf::step::QuicksortPassStep)]
     QuicksortPass(usize),
+    #[step(count = 30, child = crate::protocol::context::step::DzkpSingleBatchStep)]
+    QuicksortPassValidate(usize),
 }
 
 #[derive(CompactStep)]
@@ -35,10 +41,12 @@ pub(crate) enum QuicksortPassStep {
 
 #[derive(CompactStep)]
 pub(crate) enum PrfStep {
-    PRFKeyGen,
     GenRandomMask,
+    #[step(child = crate::protocol::context::step::UpgradeStep)]
     UpgradeY,
+    #[step(child = crate::protocol::context::step::UpgradeStep)]
     UpgradeMask,
+    #[step(child = crate::protocol::basics::mul::step::MaliciousMultiplyStep)]
     MultMaskWithPRFInput,
     RevealR,
     Revealz,
diff --git a/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs b/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs
index 456c7cdf0..bbb994f70 100644
--- a/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs
+++ b/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs
@@ -10,7 +10,7 @@ use crate::{
         Direction, TotalRecords,
     },
     protocol::{
-        context::{step::ZeroKnowledgeProofValidateStep as Step, Context},
+        context::{step::DzkpProofVerifyStep as Step, Context},
         ipa_prf::{
             malicious_security::{
                 prover::{LargeProofGenerator, SmallProofGenerator},
@@ -245,7 +245,9 @@ impl BatchToVerify {
 
         // send dif_left to the right
         let length = diff_left.len();
-        let communication_ctx = ctx.set_total_records(TotalRecords::specified(length)?);
+        let communication_ctx = ctx
+            .narrow(&Step::Diff)
+            .set_total_records(TotalRecords::specified(length)?);
 
         let send_channel =
             communication_ctx.send_channel::<Fp61BitPrime>(ctx.role().peer(Direction::Right));
diff --git a/ipa-core/src/protocol/step.rs b/ipa-core/src/protocol/step.rs
index a38d50593..604052222 100644
--- a/ipa-core/src/protocol/step.rs
+++ b/ipa-core/src/protocol/step.rs
@@ -43,8 +43,6 @@ pub enum DeadCodeStep {
     SaturatedSubtraction,
     #[step(child = crate::protocol::ipa_prf::prf_sharding::step::FeatureLabelDotProductStep)]
     FeatureLabelDotProduct,
-    #[step(child = crate::protocol::context::step::ZeroKnowledgeProofValidateStep)]
-    ZeroKnowledgeProofValidate,
     #[step(child = crate::protocol::ipa_prf::boolean_ops::step::MultiplicationStep)]
     Multiplication,
 }
diff --git a/ipa-core/src/secret_sharing/vector/impls.rs b/ipa-core/src/secret_sharing/vector/impls.rs
index 536840fa3..974bfd259 100644
--- a/ipa-core/src/secret_sharing/vector/impls.rs
+++ b/ipa-core/src/secret_sharing/vector/impls.rs
@@ -65,7 +65,10 @@ macro_rules! boolean_vector {
                 error::Error,
                 protocol::{
                     basics::select,
-                    context::{dzkp_validator::DZKPValidator, Context, UpgradableContext},
+                    context::{
+                        dzkp_validator::DZKPValidator, Context, UpgradableContext,
+                        TEST_DZKP_STEPS,
+                    },
                     RecordId,
                 },
                 rand::{thread_rng, Rng},
@@ -89,7 +92,7 @@
 
                 let futures = zip(context.iter(), zip(bit_shares, zip(a_shares, b_shares)))
                     .map(|(ctx, (bit_share, (a_share, b_share)))| async move {
-                        let v = ctx.clone().dzkp_validator(1);
+                        let v = ctx.clone().dzkp_validator(TEST_DZKP_STEPS, 1);
                         let m_ctx = v.context();
 
                         let result = select(
@@ -129,7 +132,7 @@
 
                 let futures = zip(context.iter(), zip(bit_shares, zip(a_shares, b_shares)))
                     .map(|(ctx, (bit_share, (a_share, b_share)))| async move {
-                        let v = ctx.clone().dzkp_validator(1);
+                        let v = ctx.clone().dzkp_validator(TEST_DZKP_STEPS, 1);
                         let sh_ctx = v.context();
 
                         let result = select(
diff --git a/ipa-core/src/test_fixture/world.rs b/ipa-core/src/test_fixture/world.rs
index 6fba504cf..f92326c9b 100644
--- a/ipa-core/src/test_fixture/world.rs
+++ b/ipa-core/src/test_fixture/world.rs
@@ -24,7 +24,7 @@ use crate::{
         dzkp_validator::DZKPValidator, upgrade::Upgradable, Context,
         DZKPUpgradedMaliciousContext, MaliciousContext, SemiHonestContext,
         ShardedSemiHonestContext, UpgradableContext, UpgradedContext, UpgradedMaliciousContext,
-        UpgradedSemiHonestContext, Validator,
+        UpgradedSemiHonestContext, Validator, TEST_DZKP_STEPS,
     },
     prss::Endpoint as PrssEndpoint,
     Gate, QueryId, RecordId,
@@ -676,7 +676,7 @@ impl Runner for TestWorld {
        R: Future<Output = O> + Send,
     {
         self.malicious(input, |ctx, share| async {
-            let v = ctx.dzkp_validator(10);
+            let v = ctx.dzkp_validator(TEST_DZKP_STEPS, 10);
             let m_ctx = v.context();
             let m_result = helper_fn(m_ctx, share).await;
             v.validate().await.unwrap();
diff --git a/ipa-core/tests/compact_gate.rs b/ipa-core/tests/compact_gate.rs
index c67be42bc..1ded887c9 100644
--- a/ipa-core/tests/compact_gate.rs
+++ b/ipa-core/tests/compact_gate.rs
@@ -29,7 +29,6 @@ fn compact_gate_cap_8_no_window_semi_honest() {
 }
 
 #[test]
-#[ignore] // TODO
 fn compact_gate_cap_8_no_window_malicious() {
     test_compact_gate(IpaSecurityModel::Malicious, 8, 0);
 }
diff --git a/ipa-core/tests/helper_networks.rs b/ipa-core/tests/helper_networks.rs
index 8bbcc0622..9fd69bb30 100644
--- a/ipa-core/tests/helper_networks.rs
+++ b/ipa-core/tests/helper_networks.rs
@@ -55,7 +55,7 @@ fn https_semi_honest_ipa() {
 }
 
 #[test]
-#[cfg(all(test, web_test, not(feature = "compact-gate")))] // TODO: enable for compact gate
+#[cfg(all(test, web_test))]
 fn https_malicious_ipa() {
     test_ipa(IpaSecurityModel::Malicious, true);
 }

From 64b17aa6f792d360b3afafe26e5b08bda566ea5e Mon Sep 17 00:00:00 2001
From: Christian Berkhoff
Date: Wed, 11 Sep 2024 10:15:02 -0700
Subject: [PATCH 009/191] Updating debian and using Rust image

---
 docker/helper.Dockerfile | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/docker/helper.Dockerfile b/docker/helper.Dockerfile
index 52c0806ab..fc113222d 100644
--- a/docker/helper.Dockerfile
+++
b/docker/helper.Dockerfile @@ -1,6 +1,6 @@ # syntax=docker/dockerfile:1 ARG SOURCES_DIR=/usr/src/ipa -FROM rust:bullseye AS builder +FROM rust:bookworm AS builder ARG SOURCES_DIR # Prepare helper binaries @@ -10,7 +10,7 @@ RUN set -eux; \ cargo build --bin helper --release --no-default-features --features "web-app real-world-infra compact-gate" # Copy them to the final image -FROM debian:bullseye-slim +FROM rust:slim-bookworm ENV HELPER_BIN_PATH=/usr/local/bin/ipa-helper ENV CONF_DIR=/etc/ipa ARG SOURCES_DIR From 2c0ca5ced044855dba601ee2a8b4677f55633e14 Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Wed, 11 Sep 2024 13:09:34 -0700 Subject: [PATCH 010/191] remove unneeded input-data-100.txt (#1271) --- input-data-100.txt | 100 --------------------------------------------- 1 file changed, 100 deletions(-) delete mode 100644 input-data-100.txt diff --git a/input-data-100.txt b/input-data-100.txt deleted file mode 100644 index 4ea0ee7c4..000000000 --- a/input-data-100.txt +++ /dev/null @@ -1,100 +0,0 @@ -600339,534942975307,0,5,0 -96422,191017627906,0,3,0 -507032,117803731851,0,10,0 -17448,304519167044,1,0,4 -224051,12251697120,0,17,0 -572331,534942975307,1,0,1 -204850,534942975307,0,12,0 -572399,865368699047,0,2,0 -595278,865368699047,1,0,4 -457115,191017627906,1,0,4 -279628,534942975307,0,7,0 -100525,925363717604,1,0,5 -565595,925363717604,0,11,0 -567404,865368699047,0,3,0 -140412,304519167044,1,0,5 -329551,925363717604,1,0,1 -524654,314908499604,0,8,0 -240982,850807271120,1,0,5 -603020,117803731851,0,1,0 -272156,865368699047,0,17,0 -227353,12251697120,0,5,0 -265919,925363717604,1,0,1 -547,12251697120,0,2,0 -342491,925363717604,1,0,1 -250600,304519167044,0,6,0 -252290,117803731851,0,18,0 -141260,850807271120,0,6,0 -248451,304519167044,0,16,0 -515699,191017627906,1,0,4 -312537,12251697120,1,0,2 -492188,283283408809,0,13,0 -451766,917537570026,0,7,0 -287218,822386586545,0,11,0 -67235,925363717604,1,0,5 -603886,917537570026,1,0,3 -213895,117803731851,0,11,0 -418303,534942975307,0,10,0 -210243,822386586545,0,9,0 -211179,117803731851,1,0,5 -568874,925363717604,0,0,0 -373535,925363717604,1,0,3 -232675,534942975307,1,0,5 -92636,191017627906,1,0,1 -398372,917537570026,0,6,0 -401827,534942975307,1,0,2 -155515,65168429090,1,0,1 -33026,304519167044,0,17,0 -493183,179797603392,1,0,1 -167758,179797603392,1,0,4 -522471,191017627906,0,11,0 -313610,925363717604,1,0,1 -176225,12251697120,0,16,0 -588107,925363717604,0,13,0 -280600,393203478859,0,10,0 -491601,179797603392,0,4,0 -445133,773905428637,1,0,3 -301999,12251697120,1,0,5 -65750,526858192111,0,19,0 -350976,12251697120,0,9,0 -67867,773905428637,1,0,2 -594037,191017627906,0,11,0 -261995,534942975307,1,0,3 -133066,288854012131,1,0,4 -40015,179797603392,1,0,5 -571126,288854012131,0,10,0 -514451,773905428637,0,8,0 -201640,288854012131,1,0,4 -71935,526858192111,1,0,2 -316596,773905428637,0,6,0 -246923,12251697120,1,0,3 -79789,773905428637,1,0,4 -47468,917537570026,0,17,0 -161925,773905428637,0,9,0 -225460,393203478859,1,0,4 -530756,640580450837,0,4,0 -94219,338037795442,1,0,4 -136211,179797603392,0,0,0 -559897,191017627906,1,0,1 -332026,179797603392,1,0,1 -35911,917537570026,1,0,5 -329450,191017627906,0,4,0 -102812,393203478859,0,11,0 -578374,917537570026,0,15,0 -156477,881719336823,0,0,0 -277455,179797603392,0,7,0 -186143,881719336823,1,0,3 -228562,393203478859,1,0,3 -346392,822386586545,1,0,3 -102532,881719336823,0,1,0 -589048,822386586545,1,0,1 -430856,288854012131,1,0,5 -408260,881719336823,0,16,0 -180588,477090731329,0,16,0 
-502918,288854012131,0,7,0 -392616,393203478859,1,0,1 -463878,22654468721,1,0,1 -85787,393203478859,1,0,5 -238574,288854012131,0,4,0 -22862,822386586545,0,19,0 -481629,288854012131,0,3,0 From 2096bb4b03dbf3d90a1c12b60fd7f6bcdfe496eb Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Wed, 11 Sep 2024 15:48:54 -0700 Subject: [PATCH 011/191] Encrypted Input Buffer (#1261) * add new query command for test queries, make encrypted inputs default failing test Update ipa-core/src/bin/report_collector.rs Co-authored-by: Alex Koshelev clean up EncryptedInput args use assert_eq instead of assert with == move enc-input-files to correct part of cli invocation make sure encrypted inputs runs with https, test both code paths in-memory-infra only needed for crypto tests, allows it to be compiled for integration tests * rebase bug fix * refactor encrypted file loading into dedicated struct * rename EncryptedOprfReportStreams struct and provide better docs --- ipa-core/Cargo.toml | 2 +- ipa-core/src/bin/report_collector.rs | 237 +++++++++++++++++++++------ ipa-core/src/cli/crypto.rs | 85 +++++----- ipa-core/src/cli/mod.rs | 7 +- ipa-core/src/cli/playbook/ipa.rs | 2 + ipa-core/src/cli/playbook/mod.rs | 2 +- ipa-core/src/report.rs | 75 +++++++++ ipa-core/tests/common/mod.rs | 50 +++++- ipa-core/tests/compact_gate.rs | 51 ++++-- ipa-core/tests/encrypted_input.rs | 90 +++++----- ipa-core/tests/helper_networks.rs | 6 +- 11 files changed, 430 insertions(+), 177 deletions(-) diff --git a/ipa-core/Cargo.toml b/ipa-core/Cargo.toml index 835ebc28c..ecb68cc33 100644 --- a/ipa-core/Cargo.toml +++ b/ipa-core/Cargo.toml @@ -196,7 +196,7 @@ bench = false [[bin]] name = "crypto_util" -required-features = ["cli", "test-fixture", "web-app", "in-memory-infra"] +required-features = ["cli", "test-fixture", "web-app"] bench = false [[bench]] diff --git a/ipa-core/src/bin/report_collector.rs b/ipa-core/src/bin/report_collector.rs index 58e2752d1..3f0c3ef5b 100644 --- a/ipa-core/src/bin/report_collector.rs +++ b/ipa-core/src/bin/report_collector.rs @@ -15,14 +15,17 @@ use hyper::http::uri::Scheme; use ipa_core::{ cli::{ noise::{apply, ApplyDpArgs}, - playbook::{make_clients, playbook_oprf_ipa, validate, validate_dp, InputSource}, + playbook::{ + make_clients, playbook_oprf_ipa, run_query_and_validate, validate, validate_dp, + InputSource, + }, CsvSerializer, IpaQueryResult, Verbosity, }, config::{KeyRegistries, NetworkConfig}, ff::{boolean_array::BA32, FieldType}, helpers::query::{DpMechanism, IpaQueryConfig, QueryConfig, QuerySize, QueryType}, net::MpcHelperClient, - report::DEFAULT_KEY_ID, + report::{EncryptedOprfReportStreams, DEFAULT_KEY_ID}, test_fixture::{ ipa::{ipa_in_the_clear, CappingOrder, IpaQueryStyle, IpaSecurityModel, TestRawDataRecord}, EventGenerator, EventGeneratorConfig, @@ -54,7 +57,7 @@ struct Args { input: CommandInput, /// The destination file for output. 
-    #[arg(long, value_name = "FILE")]
+    #[arg(long, value_name = "OUTPUT_FILE")]
     output_file: Option<PathBuf>,
 
     #[command(subcommand)]
@@ -97,11 +100,30 @@ enum ReportCollectorCommand {
     },
     /// Apply differential privacy noise to IPA inputs
     ApplyDpNoise(ApplyDpArgs),
-    /// Execute OPRF IPA in a semi-honest setting
+    /// Execute OPRF IPA in a semi-honest majority setting with known test data
+    /// and compare results against expectation
+    SemiHonestOprfIpaTest(IpaQueryConfig),
+    /// Execute OPRF IPA in an honest majority (one malicious helper) setting
+    /// with known test data and compare results against expectation
+    MaliciousOprfIpaTest(IpaQueryConfig),
+    /// Execute OPRF IPA in a semi-honest majority setting with unknown encrypted data
     #[command(visible_alias = "oprf-ipa")]
-    SemiHonestOprfIpa(IpaQueryConfig),
+    SemiHonestOprfIpa {
+        #[clap(flatten)]
+        encrypted_inputs: EncryptedInputs,
+
+        #[clap(flatten)]
+        ipa_query_config: IpaQueryConfig,
+    },
     /// Execute OPRF IPA in an honest majority (one malicious helper) setting
-    MaliciousOprfIpa(IpaQueryConfig),
+    /// with unknown encrypted data
+    MaliciousOprfIpa {
+        #[clap(flatten)]
+        encrypted_inputs: EncryptedInputs,
+
+        #[clap(flatten)]
+        ipa_query_config: IpaQueryConfig,
+    },
 }
 
 #[derive(Debug, clap::Args)]
@@ -113,6 +135,21 @@ struct GenInputArgs {
     breakdowns: u32,
 }
 
+#[derive(Debug, Parser)]
+struct EncryptedInputs {
+    /// The encrypted input for H1
+    #[arg(long, value_name = "H1_ENCRYPTED_INPUT_FILE")]
+    enc_input_file1: PathBuf,
+
+    /// The encrypted input for H2
+    #[arg(long, value_name = "H2_ENCRYPTED_INPUT_FILE")]
+    enc_input_file2: PathBuf,
+
+    /// The encrypted input for H3
+    #[arg(long, value_name = "H3_ENCRYPTED_INPUT_FILE")]
+    enc_input_file3: PathBuf,
+}
+
 #[tokio::main]
 async fn main() -> Result<(), Box<dyn Error>> {
     let args = Args::parse();
@@ -132,8 +169,8 @@ async fn main() -> Result<(), Box<dyn Error>> {
             gen_args,
         } => gen_inputs(count, seed, args.output_file, gen_args)?,
         ReportCollectorCommand::ApplyDpNoise(ref dp_args) => apply_dp_noise(&args, dp_args)?,
-        ReportCollectorCommand::SemiHonestOprfIpa(config) => {
-            ipa(
+        ReportCollectorCommand::SemiHonestOprfIpaTest(config) => {
+            ipa_test(
                 &args,
                 &network,
                 IpaSecurityModel::SemiHonest,
@@ -143,8 +180,8 @@ async fn main() -> Result<(), Box<dyn Error>> {
             )
             .await?
         }
-        ReportCollectorCommand::MaliciousOprfIpa(config) => {
-            ipa(
+        ReportCollectorCommand::MaliciousOprfIpaTest(config) => {
+            ipa_test(
                 &args,
                 &network,
                 IpaSecurityModel::Malicious,
@@ -154,6 +191,34 @@
             )
             .await?
         }
+        ReportCollectorCommand::MaliciousOprfIpa {
+            ref encrypted_inputs,
+            ipa_query_config,
+        } => {
+            ipa(
+                &args,
+                IpaSecurityModel::Malicious,
+                ipa_query_config,
+                &clients,
+                IpaQueryStyle::Oprf,
+                encrypted_inputs,
+            )
+            .await?
+        }
+        ReportCollectorCommand::SemiHonestOprfIpa {
+            ref encrypted_inputs,
+            ipa_query_config,
+        } => {
+            ipa(
+                &args,
+                IpaSecurityModel::SemiHonest,
+                ipa_query_config,
+                &clients,
+                IpaQueryStyle::Oprf,
+                encrypted_inputs,
+            )
+            .await?
+        }
     };
 
     Ok(())
 }
@@ -185,23 +250,126 @@ fn gen_inputs(
     Ok(())
 }
 
-async fn ipa(
-    args: &Args,
-    network: &NetworkConfig,
+/// Panics
+/// if (security_model, query_style) tuple is undefined
+fn get_query_type(
     security_model: IpaSecurityModel,
+    query_style: &IpaQueryStyle,
     ipa_query_config: IpaQueryConfig,
-    helper_clients: &[MpcHelperClient; 3],
-    query_style: IpaQueryStyle,
-) -> Result<(), Box<dyn Error>> {
-    let input = InputSource::from(&args.input);
-    let query_type = match (security_model, &query_style) {
+) -> QueryType {
+    match (security_model, query_style) {
         (IpaSecurityModel::SemiHonest, IpaQueryStyle::Oprf) => {
             QueryType::SemiHonestOprfIpa(ipa_query_config)
         }
         (IpaSecurityModel::Malicious, IpaQueryStyle::Oprf) => {
             QueryType::MaliciousOprfIpa(ipa_query_config)
         }
+    }
+}
+
+fn write_ipa_output_file(
+    path: &PathBuf,
+    query_result: &IpaQueryResult,
+) -> Result<(), Box<dyn Error>> {
+    // it will be sad to lose the results if file already exists.
+    let path = if Path::is_file(path) {
+        let mut new_file_name = thread_rng()
+            .sample_iter(&Alphanumeric)
+            .take(5)
+            .map(char::from)
+            .collect::<String>();
+        let file_name = path.file_stem().ok_or("not a file")?;
+
+        new_file_name.insert(0, '-');
+        new_file_name.insert_str(0, &file_name.to_string_lossy());
+        tracing::warn!(
+            "{} file exists, renaming to {:?}",
+            path.display(),
+            new_file_name
+        );
+
+        // it will not be 100% accurate until file_prefix API is stabilized
+        Cow::Owned(
+            path.with_file_name(&new_file_name)
+                .with_extension(path.extension().unwrap_or("".as_ref())),
+        )
+    } else {
+        Cow::Borrowed(path)
+    };
+    let mut file = File::options()
+        .write(true)
+        .create_new(true)
+        .open(path.deref())
+        .map_err(|e| format!("Failed to create output file {}: {e}", path.display()))?;
+
+    write!(file, "{}", serde_json::to_string_pretty(query_result)?)?;
+    Ok(())
+}
+
+async fn ipa(
+    args: &Args,
+    security_model: IpaSecurityModel,
+    ipa_query_config: IpaQueryConfig,
+    helper_clients: &[MpcHelperClient; 3],
+    query_style: IpaQueryStyle,
+    encrypted_inputs: &EncryptedInputs,
+) -> Result<(), Box<dyn Error>> {
+    let query_type = get_query_type(security_model, &query_style, ipa_query_config);
+
+    let files = [
+        &encrypted_inputs.enc_input_file1,
+        &encrypted_inputs.enc_input_file2,
+        &encrypted_inputs.enc_input_file3,
+    ];
+
+    let encrypted_oprf_report_files = EncryptedOprfReportStreams::from(files);
+
+    let query_config = QueryConfig {
+        size: QuerySize::try_from(encrypted_oprf_report_files.query_size).unwrap(),
+        field_type: FieldType::Fp32BitPrime,
+        query_type,
+    };
+
+    let query_id = helper_clients[0]
+        .create_query(query_config)
+        .await
+        .expect("Unable to create query!");
+
+    tracing::info!("Starting query for OPRF");
+    let actual = match query_style {
+        IpaQueryStyle::Oprf => {
+            // the value for histogram values (BA32) must be kept in sync with the server-side
+            // implementation, otherwise a runtime reconstruct error will be generated.
+            // see ipa-core/src/query/executor.rs
+            run_query_and_validate::<BA32>(
+                encrypted_oprf_report_files.streams,
+                encrypted_oprf_report_files.query_size,
+                helper_clients,
+                query_id,
+                ipa_query_config,
+            )
+            .await
+        }
+    };
+
+    if let Some(ref path) = args.output_file {
+        write_ipa_output_file(path, &actual)?;
+    } else {
+        println!("{}", serde_json::to_string_pretty(&actual)?);
+    }
+    Ok(())
+}
+
+async fn ipa_test(
+    args: &Args,
+    network: &NetworkConfig,
+    security_model: IpaSecurityModel,
+    ipa_query_config: IpaQueryConfig,
+    helper_clients: &[MpcHelperClient; 3],
+    query_style: IpaQueryStyle,
+) -> Result<(), Box<dyn Error>> {
+    let input = InputSource::from(&args.input);
+    let query_type = get_query_type(security_model, &query_style, ipa_query_config);
     let input_rows = input.iter::<TestRawDataRecord>().collect::<Vec<_>>();
     let query_config = QueryConfig {
@@ -255,38 +423,7 @@ async fn ipa(
     };
 
     if let Some(ref path) = args.output_file {
-        // it will be sad to lose the results if file already exists.
-        let path = if Path::is_file(path) {
-            let mut new_file_name = thread_rng()
-                .sample_iter(&Alphanumeric)
-                .take(5)
-                .map(char::from)
-                .collect::<String>();
-            let file_name = path.file_stem().ok_or("not a file")?;
-
-            new_file_name.insert(0, '-');
-            new_file_name.insert_str(0, &file_name.to_string_lossy());
-            tracing::warn!(
-                "{} file exists, renaming to {:?}",
-                path.display(),
-                new_file_name
-            );
-
-            // it will not be 100% accurate until file_prefix API is stabilized
-            Cow::Owned(
-                path.with_file_name(&new_file_name)
-                    .with_extension(path.extension().unwrap_or("".as_ref())),
-            )
-        } else {
-            Cow::Borrowed(path)
-        };
-        let mut file = File::options()
-            .write(true)
-            .create_new(true)
-            .open(path.deref())
-            .map_err(|e| format!("Failed to create output file {}: {e}", path.display()))?;
-
-        write!(file, "{}", serde_json::to_string_pretty(&actual)?)?;
+        write_ipa_output_file(path, &actual)?;
     }
 
     tracing::info!("{m:?}", m = ipa_query_config);
diff --git a/ipa-core/src/cli/crypto.rs b/ipa-core/src/cli/crypto.rs
index f7ccca4f8..1a7ccc826 100644
--- a/ipa-core/src/cli/crypto.rs
+++ b/ipa-core/src/cli/crypto.rs
@@ -238,7 +238,7 @@ pub async fn decrypt_and_reconstruct(args: DecryptArgs) -> Result<(), BoxError>
     Ok(())
 }
 
-#[cfg(test)]
+#[cfg(all(test, feature = "in-memory-infra"))]
 mod tests {
     use std::{
         fs::File,
        io::{BufRead, BufReader},
         path::Path,
         sync::Arc,
     };

     use crate::{
         cli::{
             crypto::{decrypt_and_reconstruct, encrypt, EncryptArgs},
             CsvSerializer,
         },
         ff::{boolean_array::BA16, U128Conversions},
-        helpers::{
-            query::{IpaQueryConfig, QuerySize},
-            BodyStream,
-        },
+        helpers::query::{IpaQueryConfig, QuerySize},
         hpke::{IpaPrivateKey, KeyRegistry, PrivateKeyOnly},
         query::OprfIpaQuery,
+        report::EncryptedOprfReportStreams,
         test_fixture::{
             ipa::TestRawDataRecord, join3v, EventGenerator, EventGeneratorConfig, Reconstruct,
             TestWorld,

     #[tokio::test]
     async fn encrypt_and_execute_query() {
         const EXPECTED: &[u128] = &[0, 8, 5];

         let encrypt_args =
             build_encrypt_args(input_file.path(), output_dir.path(), network_file.path());
         let _ = encrypt(&encrypt_args);
 
-        let enc1 = output_dir.path().join("helper1.enc");
-        let enc2 = output_dir.path().join("helper2.enc");
-        let enc3 = output_dir.path().join("helper3.enc");
-
-        let mut buffers: [_; 3] = std::array::from_fn(|_| Vec::new());
-        for (i, path) in [enc1, enc2, enc3].iter().enumerate() {
-            let file = File::open(path).unwrap();
-            let reader = BufReader::new(file);
for line in reader.lines() { - let line = line.unwrap(); - let encrypted_report_bytes = hex::decode(line.trim()).unwrap(); - println!("{}", encrypted_report_bytes.len()); - buffers[i].put_u16_le(encrypted_report_bytes.len().try_into().unwrap()); - buffers[i].put_slice(encrypted_report_bytes.as_slice()); - } - } + let files = [ + output_dir.path().join("helper1.enc"), + output_dir.path().join("helper2.enc"), + output_dir.path().join("helper3.enc"), + ]; + let encrypted_oprf_report_files = EncryptedOprfReportStreams::from(files); let world = TestWorld::default(); let contexts = world.contexts(); @@ -631,31 +620,35 @@ public_key = "cfdbaaff16b30aa8a4ab07eaad2cdd80458208a1317aefbb807e46dce596617e" ]; #[allow(clippy::large_futures)] - let results = join3v(buffers.into_iter().zip(contexts).zip(mk_private_keys).map( - |((buffer, ctx), mk_private_key)| { - let query_config = IpaQueryConfig { - per_user_credit_cap: 8, - attribution_window_seconds: None, - max_breakdown_key: 3, - with_dp: 0, - epsilon: 1.0, - plaintext_match_keys: false, - }; - let input = BodyStream::from(buffer); - - let private_registry = - Arc::new(KeyRegistry::::from_keys([PrivateKeyOnly( - IpaPrivateKey::from_bytes(&mk_private_key) - .expect("manually constructed for test"), - )])); - - OprfIpaQuery::<_, BA16, KeyRegistry>::new( - query_config, - private_registry, - ) - .execute(ctx, query_size, input) - }, - )) + let results = join3v( + encrypted_oprf_report_files + .streams + .into_iter() + .zip(contexts) + .zip(mk_private_keys) + .map(|((input, ctx), mk_private_key)| { + let query_config = IpaQueryConfig { + per_user_credit_cap: 8, + attribution_window_seconds: None, + max_breakdown_key: 3, + with_dp: 0, + epsilon: 1.0, + plaintext_match_keys: false, + }; + + let private_registry = + Arc::new(KeyRegistry::::from_keys([PrivateKeyOnly( + IpaPrivateKey::from_bytes(&mk_private_key) + .expect("manually constructed for test"), + )])); + + OprfIpaQuery::<_, BA16, KeyRegistry>::new( + query_config, + private_registry, + ) + .execute(ctx, query_size, input) + }), + ) .await; assert_eq!( diff --git a/ipa-core/src/cli/mod.rs b/ipa-core/src/cli/mod.rs index 28847f416..ccb2e37fd 100644 --- a/ipa-core/src/cli/mod.rs +++ b/ipa-core/src/cli/mod.rs @@ -1,11 +1,6 @@ #[cfg(feature = "web-app")] mod clientconf; -#[cfg(all( - feature = "test-fixture", - feature = "web-app", - feature = "cli", - feature = "in-memory-infra" -))] +#[cfg(all(feature = "test-fixture", feature = "web-app", feature = "cli",))] pub mod crypto; mod csv; mod ipa_output; diff --git a/ipa-core/src/cli/playbook/ipa.rs b/ipa-core/src/cli/playbook/ipa.rs index b2c8050c6..5f56911f0 100644 --- a/ipa-core/src/cli/playbook/ipa.rs +++ b/ipa-core/src/cli/playbook/ipa.rs @@ -95,6 +95,8 @@ where run_query_and_validate::(inputs, query_size, clients, query_id, query_config).await } +/// # Panics +/// if results are invalid #[allow(clippy::disallowed_methods)] // allow try_join_all pub async fn run_query_and_validate( inputs: [BodyStream; 3], diff --git a/ipa-core/src/cli/playbook/mod.rs b/ipa-core/src/cli/playbook/mod.rs index 74bf9484d..135fa3117 100644 --- a/ipa-core/src/cli/playbook/mod.rs +++ b/ipa-core/src/cli/playbook/mod.rs @@ -14,7 +14,7 @@ pub use input::InputSource; pub use multiply::secure_mul; use tokio::time::sleep; -pub use self::ipa::playbook_oprf_ipa; +pub use self::ipa::{playbook_oprf_ipa, run_query_and_validate}; use crate::{ config::{ClientConfig, NetworkConfig, PeerConfig}, ff::boolean_array::{BA20, BA3, BA8}, diff --git a/ipa-core/src/report.rs 
b/ipa-core/src/report.rs
index c2411ce28..a9da93454 100644
--- a/ipa-core/src/report.rs
+++ b/ipa-core/src/report.rs
@@ -1,7 +1,34 @@
+//! Provides report types which are aggregated by the IPA protocol
+//!
+//! The `OprfReport` is the primary data type which each helper uses to aggregate in the IPA
+//! protocol.
+//! From each Helper's POV, the Report Collector POSTs a length delimited byte
+//! stream, which is then processed as follows:
+//!
+//! `BodyStream` → `EncryptedOprfReport` → `OprfReport`
+//!
+//! From the Report Collector's POV, there are two potential paths:
+//! 1. In production, encrypted events are received from clients and accumulated out of band
+//! as 3 files of newline delimited hex encoded encrypted events.
+//! 2. For testing, simulated plaintext events are provided as a CSV.
+//!
+//! Path 1 is processed as follows:
+//!
+//! `files: [PathBuf; 3]` → `EncryptedOprfReportStreams` → `helpers::BodyStream`
+//!
+//! Path 2 is processed as follows:
+//!
+//! `cli::playbook::InputSource` (`PathBuf` or `stdin()`) →
+//! `test_fixture::ipa::TestRawDataRecord` → `OprfReport` → encrypted bytes
+//! (via `Oprf.delimited_encrypt_to`) → `helpers::BodyStream`
+
 use std::{
     fmt::{Display, Formatter},
+    fs::File,
+    io::{BufRead, BufReader},
     marker::PhantomData,
     ops::{Add, Deref},
+    path::PathBuf,
 };
 
 use bytes::{BufMut, Bytes};
@@ -13,6 +40,7 @@ use typenum::{Sum, Unsigned, U1, U16};
 use crate::{
     error::BoxError,
     ff::{boolean_array::BA64, Serializable},
+    helpers::BodyStream,
     hpke::{
         open_in_place, seal_in_place, CryptError, EncapsulationSize, Info, PrivateKeyRegistry,
         PublicKeyRegistry, TagSize,
@@ -159,6 +187,53 @@ pub enum InvalidReportError {
     Length(usize, usize),
 }
 
+/// A struct intended for the Report Collector to hold the streams of underlying
+/// `EncryptedOprfReports` represented as length delimited bytes. Each helper receives an
+/// individual stream, which is unpacked into `EncryptedOprfReports` and decrypted
+/// into `OprfReports`.
+pub struct EncryptedOprfReportStreams {
+    pub streams: [BodyStream; 3],
+    pub query_size: usize,
+}
+
+/// Builds an `EncryptedOprfReportStreams` struct from 3 files of
+/// `EncryptedOprfReports` formatted as newline delimited hex.
+impl From<[&PathBuf; 3]> for EncryptedOprfReportStreams {
+    fn from(files: [&PathBuf; 3]) -> Self {
+        let mut buffers: [_; 3] = std::array::from_fn(|_| Vec::new());
+        let mut query_sizes: [usize; 3] = [0, 0, 0];
+        for (i, path) in files.iter().enumerate() {
+            let file =
+                File::open(path).unwrap_or_else(|e| panic!("unable to open file {path:?}. {e}"));
+            let reader = BufReader::new(file);
+            for line in reader.lines() {
+                let encrypted_report_bytes = hex::decode(
+                    line.expect("Unable to read line. {file:?} is likely corrupt")
+                        .trim(),
+                )
+                .expect("Unable to read line. {file:?} is likely corrupt");
+                buffers[i].put_u16_le(
+                    encrypted_report_bytes
+                        .len()
+                        .try_into()
+                        .expect("Unable to read line. {file:?} is likely corrupt"),
+                );
+                buffers[i].put_slice(encrypted_report_bytes.as_slice());
+                query_sizes[i] += 1;
+            }
+        }
+        // Panic if input sizes are not the same.
+        // Panic instead of returning an Error as this is non-recoverable.
+        assert_eq!(query_sizes[0], query_sizes[1]);
+        assert_eq!(query_sizes[1], query_sizes[2]);
+
+        Self {
+            streams: buffers.map(BodyStream::from),
+            // all sizes are equal, so set the query size to the size of the first input
+            query_size: query_sizes[0],
+        }
+    }
+}
+
 // TODO: If we are parsing reports from CSV files, we may also want an owned version of EncryptedReport.
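A minimal usage sketch of the struct added above: the `load_helper_streams` function and its directory argument are hypothetical, while the `helperN.enc` file names, the `ipa_core::report` import path, and the `From<[&PathBuf; 3]>` impl all come from this patch series.

use std::path::{Path, PathBuf};

use ipa_core::report::EncryptedOprfReportStreams;

fn load_helper_streams(dir: &Path) -> EncryptedOprfReportStreams {
    // one newline-delimited hex file per helper, as written by `crypto_util encrypt`
    let files: [PathBuf; 3] = [
        dir.join("helper1.enc"),
        dir.join("helper2.enc"),
        dir.join("helper3.enc"),
    ];
    // `From<[&PathBuf; 3]>` repacks each file into a length-delimited `BodyStream`
    // and derives `query_size` by counting rows; it panics if the three row
    // counts disagree
    EncryptedOprfReportStreams::from([&files[0], &files[1], &files[2]])
}

The length-delimited framing (a little-endian `u16` length prefix before each report) matches what `BodyStream` expects downstream, so the helpers can unpack reports without re-parsing the hex files.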
/// A binary report as submitted by a report collector, containing encrypted `OprfReport` diff --git a/ipa-core/tests/common/mod.rs b/ipa-core/tests/common/mod.rs index d04d99efd..ca1d5e08a 100644 --- a/ipa-core/tests/common/mod.rs +++ b/ipa-core/tests/common/mod.rs @@ -24,6 +24,7 @@ pub mod tempdir; pub const HELPER_BIN: &str = env!("CARGO_BIN_EXE_helper"); pub const TEST_MPC_BIN: &str = env!("CARGO_BIN_EXE_test_mpc"); pub const TEST_RC_BIN: &str = env!("CARGO_BIN_EXE_report_collector"); +pub const CRYPTO_UTIL_BIN: &str = env!("CARGO_BIN_EXE_crypto_util"); pub trait UnwrapStatusExt { fn unwrap_status(self); @@ -216,17 +217,27 @@ pub fn test_network(https: bool) { T::execute(path, https); } -pub fn test_ipa(mode: IpaSecurityModel, https: bool) { +pub fn test_ipa(mode: IpaSecurityModel, https: bool, encrypted_inputs: bool) { test_ipa_with_config( mode, https, IpaQueryConfig { ..Default::default() }, + encrypted_inputs, ); } -pub fn test_ipa_with_config(mode: IpaSecurityModel, https: bool, config: IpaQueryConfig) { +pub fn test_ipa_with_config( + mode: IpaSecurityModel, + https: bool, + config: IpaQueryConfig, + encrypted_inputs: bool, +) { + if encrypted_inputs & !https { + panic!("encrypted_input requires https") + }; + const INPUT_SIZE: usize = 100; // set to true to always keep the temp dir after test finishes let dir = TempDir::new_delete_on_drop(); @@ -250,11 +261,25 @@ pub fn test_ipa_with_config(mode: IpaSecurityModel, https: bool, config: IpaQuer .stdin(Stdio::piped()); command.status().unwrap_status(); + if encrypted_inputs { + // Encrypt Input + let mut command = Command::new(CRYPTO_UTIL_BIN); + command + .arg("encrypt") + .args(["--input-file".as_ref(), inputs_file.as_os_str()]) + .args(["--output-dir".as_ref(), path.as_os_str()]) + .args(["--network".into(), dir.path().join("network.toml")]) + .stdin(Stdio::piped()); + command.status().unwrap_status(); + } + // Run IPA let mut command = Command::new(TEST_RC_BIN); + if !encrypted_inputs { + command.args(["--input-file".as_ref(), inputs_file.as_os_str()]); + } command .args(["--network".into(), dir.path().join("network.toml")]) - .args(["--input-file".as_ref(), inputs_file.as_os_str()]) .args(["--output-file".as_ref(), output_file.as_os_str()]) .args(["--wait", "2"]) .silent(); @@ -263,12 +288,23 @@ pub fn test_ipa_with_config(mode: IpaSecurityModel, https: bool, config: IpaQuer command.arg("--disable-https"); } - let protocol = match mode { - IpaSecurityModel::SemiHonest => "semi-honest-oprf-ipa", - IpaSecurityModel::Malicious => "malicious-oprf-ipa", + let protocol = match (mode, encrypted_inputs) { + (IpaSecurityModel::SemiHonest, true) => "semi-honest-oprf-ipa", + (IpaSecurityModel::SemiHonest, false) => "semi-honest-oprf-ipa-test", + (IpaSecurityModel::Malicious, true) => "malicious-oprf-ipa", + (IpaSecurityModel::Malicious, false) => "malicious-oprf-ipa-test", }; + command.arg(protocol); + if encrypted_inputs { + let enc1 = dir.path().join("helper1.enc"); + let enc2 = dir.path().join("helper2.enc"); + let enc3 = dir.path().join("helper3.enc"); + command + .args(["--enc-input-file1".as_ref(), enc1.as_os_str()]) + .args(["--enc-input-file2".as_ref(), enc2.as_os_str()]) + .args(["--enc-input-file3".as_ref(), enc3.as_os_str()]); + } command - .arg(protocol) .args(["--max-breakdown-key", &config.max_breakdown_key.to_string()]) .args([ "--per-user-credit-cap", diff --git a/ipa-core/tests/compact_gate.rs b/ipa-core/tests/compact_gate.rs index c67be42bc..a1a379648 100644 --- a/ipa-core/tests/compact_gate.rs +++ 
b/ipa-core/tests/compact_gate.rs
@@ -12,6 +12,7 @@ fn test_compact_gate>(
     mode: IpaSecurityModel,
     per_user_credit_cap: u32,
     attribution_window_seconds: I,
+    encrypted_input: bool,
 ) {
     let config = IpaQueryConfig {
         per_user_credit_cap,
@@ -20,31 +21,59 @@ fn test_compact_gate>(
         ..Default::default()
     };
 
-    test_ipa_with_config(mode, false, config);
+    // test https with encrypted input
+    // and http with plaintext input
+    test_ipa_with_config(mode, encrypted_input, config, encrypted_input);
 }
 
 #[test]
-fn compact_gate_cap_8_no_window_semi_honest() {
-    test_compact_gate(IpaSecurityModel::SemiHonest, 8, 0);
+fn compact_gate_cap_8_no_window_semi_honest_encrypted_input() {
+    test_compact_gate(IpaSecurityModel::SemiHonest, 8, 0, true);
+}
+
+#[test]
+fn compact_gate_cap_8_no_window_semi_honest_plaintext_input() {
+    test_compact_gate(IpaSecurityModel::SemiHonest, 8, 0, false);
+}
+
+#[test]
+#[ignore] // TODO
+fn compact_gate_cap_8_no_window_malicious_encrypted_input() {
+    test_compact_gate(IpaSecurityModel::Malicious, 8, 0, true);
 }
 
 #[test]
 #[ignore] // TODO
-fn compact_gate_cap_8_no_window_malicious() {
-    test_compact_gate(IpaSecurityModel::Malicious, 8, 0);
+fn compact_gate_cap_8_no_window_malicious_plaintext_input() {
+    test_compact_gate(IpaSecurityModel::Malicious, 8, 0, false);
+}
+
+#[test]
+fn compact_gate_cap_8_with_window_semi_honest_encrypted_input() {
+    test_compact_gate(IpaSecurityModel::SemiHonest, 8, 86400, true);
+}
+
+#[test]
+fn compact_gate_cap_8_with_window_semi_honest_plaintext_input() {
+    test_compact_gate(IpaSecurityModel::SemiHonest, 8, 86400, false);
+}
+
+#[test]
+fn compact_gate_cap_16_no_window_semi_honest_encrypted_input() {
+    test_compact_gate(IpaSecurityModel::SemiHonest, 16, 0, true);
 }
 
 #[test]
-fn compact_gate_cap_8_with_window_semi_honest() {
-    test_compact_gate(IpaSecurityModel::SemiHonest, 8, 86400);
+fn compact_gate_cap_16_no_window_semi_honest_plaintext_input() {
+    test_compact_gate(IpaSecurityModel::SemiHonest, 16, 0, false);
+}
+
+#[test]
+fn compact_gate_cap_16_with_window_semi_honest_encrypted_input() {
+    test_compact_gate(IpaSecurityModel::SemiHonest, 16, 86400, true);
}
 
 #[test]
-fn compact_gate_cap_16_no_window_semi_honest() {
-    test_compact_gate(IpaSecurityModel::SemiHonest, 16, 0);
+fn compact_gate_cap_16_with_window_semi_honest_plaintext_input() {
+    test_compact_gate(IpaSecurityModel::SemiHonest, 16, 86400, false);
 }
diff --git a/ipa-core/tests/encrypted_input.rs b/ipa-core/tests/encrypted_input.rs
index 23f5e1aa3..e25a2a309 100644
--- a/ipa-core/tests/encrypted_input.rs
+++ b/ipa-core/tests/encrypted_input.rs
@@ -6,14 +6,8 @@
 ))]
 mod tests {
 
-    use std::{
-        fs::File,
-        io::{BufRead, BufReader, Write},
-        path::Path,
-        sync::Arc,
-    };
+    use std::{io::Write, path::Path, sync::Arc};
 
-    use bytes::BufMut;
     use clap::Parser;
     use hpke::Deserializable;
     use ipa_core::{
@@ -22,12 +16,10 @@ mod tests {
             CsvSerializer,
         },
         ff::{boolean_array::BA16, U128Conversions},
-        helpers::{
-            query::{IpaQueryConfig, QuerySize},
-            BodyStream,
-        },
+        helpers::query::{IpaQueryConfig, QuerySize},
         hpke::{IpaPrivateKey, KeyRegistry, PrivateKeyOnly},
         query::OprfIpaQuery,
+        report::EncryptedOprfReportStreams,
         test_fixture::{ipa::TestRawDataRecord, join3v, Reconstruct, TestWorld},
     };
     use tempfile::{tempdir, NamedTempFile};
@@ -132,22 +124,12 @@ public_key = "b900be35da06106a83ed73c33f733e03e4ea5888b7ea4c912ab270b0b0f8381e"
         build_encrypt_args(input_file.path(), output_dir.path(),
network_file.path()); let _ = encrypt(&encrypt_args); - let enc1 = output_dir.path().join("helper1.enc"); - let enc2 = output_dir.path().join("helper2.enc"); - let enc3 = output_dir.path().join("helper3.enc"); - - let mut buffers: [_; 3] = std::array::from_fn(|_| Vec::new()); - for (i, path) in [enc1, enc2, enc3].iter().enumerate() { - let file = File::open(path).unwrap(); - let reader = BufReader::new(file); - for line in reader.lines() { - let line = line.unwrap(); - let encrypted_report_bytes = hex::decode(line.trim()).unwrap(); - println!("{}", encrypted_report_bytes.len()); - buffers[i].put_u16_le(encrypted_report_bytes.len().try_into().unwrap()); - buffers[i].put_slice(encrypted_report_bytes.as_slice()); - } - } + let files = [ + &output_dir.path().join("helper1.enc"), + &output_dir.path().join("helper2.enc"), + &output_dir.path().join("helper3.enc"), + ]; + let encrypted_oprf_report_streams = EncryptedOprfReportStreams::from(files); let world = TestWorld::default(); let contexts = world.contexts(); @@ -162,31 +144,35 @@ public_key = "b900be35da06106a83ed73c33f733e03e4ea5888b7ea4c912ab270b0b0f8381e" ]; #[allow(clippy::large_futures)] - let results = join3v(buffers.into_iter().zip(contexts).zip(mk_private_keys).map( - |((buffer, ctx), mk_private_key)| { - let query_config = IpaQueryConfig { - per_user_credit_cap: 8, - attribution_window_seconds: None, - max_breakdown_key: 3, - with_dp: 0, - epsilon: 1.0, - plaintext_match_keys: false, - }; - let input = BodyStream::from(buffer); - - let private_registry = - Arc::new(KeyRegistry::::from_keys([PrivateKeyOnly( - IpaPrivateKey::from_bytes(&mk_private_key) - .expect("manually constructed for test"), - )])); - - OprfIpaQuery::<_, BA16, KeyRegistry>::new( - query_config, - private_registry, - ) - .execute(ctx, query_size, input) - }, - )) + let results = join3v( + encrypted_oprf_report_streams + .streams + .into_iter() + .zip(contexts) + .zip(mk_private_keys) + .map(|((input, ctx), mk_private_key)| { + let query_config = IpaQueryConfig { + per_user_credit_cap: 8, + attribution_window_seconds: None, + max_breakdown_key: 3, + with_dp: 0, + epsilon: 1.0, + plaintext_match_keys: false, + }; + + let private_registry = + Arc::new(KeyRegistry::::from_keys([PrivateKeyOnly( + IpaPrivateKey::from_bytes(&mk_private_key) + .expect("manually constructed for test"), + )])); + + OprfIpaQuery::<_, BA16, KeyRegistry>::new( + query_config, + private_registry, + ) + .execute(ctx, query_size, input) + }), + ) .await; assert_eq!( diff --git a/ipa-core/tests/helper_networks.rs b/ipa-core/tests/helper_networks.rs index 8bbcc0622..f44e5229b 100644 --- a/ipa-core/tests/helper_networks.rs +++ b/ipa-core/tests/helper_networks.rs @@ -45,19 +45,19 @@ fn http_network_large_input() { #[test] #[cfg(all(test, web_test))] fn http_semi_honest_ipa() { - test_ipa(IpaSecurityModel::SemiHonest, false); + test_ipa(IpaSecurityModel::SemiHonest, false, false); } #[test] #[cfg(all(test, web_test))] fn https_semi_honest_ipa() { - test_ipa(IpaSecurityModel::SemiHonest, true); + test_ipa(IpaSecurityModel::SemiHonest, true, true); } #[test] #[cfg(all(test, web_test, not(feature = "compact-gate")))] // TODO: enable for compact gate fn https_malicious_ipa() { - test_ipa(IpaSecurityModel::Malicious, true); + test_ipa(IpaSecurityModel::Malicious, true, true); } /// Similar to [`network`] tests, but it uses keygen + confgen CLIs to generate helper client config From 8a04fa48c42b34a03ae3c4e4c092023dfdca0b4a Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Wed, 11 Sep 2024 17:25:42 -0700 
Subject: [PATCH 012/191] align variable and struct naming (#1272) --- ipa-core/src/bin/report_collector.rs | 8 ++++---- ipa-core/src/cli/crypto.rs | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/ipa-core/src/bin/report_collector.rs b/ipa-core/src/bin/report_collector.rs index 3f0c3ef5b..6c604c5ae 100644 --- a/ipa-core/src/bin/report_collector.rs +++ b/ipa-core/src/bin/report_collector.rs @@ -322,10 +322,10 @@ async fn ipa( &encrypted_inputs.enc_input_file3, ]; - let encrypted_oprf_report_files = EncryptedOprfReportStreams::from(files); + let encrypted_oprf_report_streams = EncryptedOprfReportStreams::from(files); let query_config = QueryConfig { - size: QuerySize::try_from(encrypted_oprf_report_files.query_size).unwrap(), + size: QuerySize::try_from(encrypted_oprf_report_streams.query_size).unwrap(), field_type: FieldType::Fp32BitPrime, query_type, }; @@ -342,8 +342,8 @@ async fn ipa( // implementation, otherwise a runtime reconstruct error will be generated. // see ipa-core/src/query/executor.rs run_query_and_validate::( - encrypted_oprf_report_files.streams, - encrypted_oprf_report_files.query_size, + encrypted_oprf_report_streams.streams, + encrypted_oprf_report_streams.query_size, helper_clients, query_id, ipa_query_config, diff --git a/ipa-core/src/cli/crypto.rs b/ipa-core/src/cli/crypto.rs index 1a7ccc826..2eb90e089 100644 --- a/ipa-core/src/cli/crypto.rs +++ b/ipa-core/src/cli/crypto.rs @@ -605,7 +605,7 @@ public_key = "cfdbaaff16b30aa8a4ab07eaad2cdd80458208a1317aefbb807e46dce596617e" output_dir.path().join("helper2.enc"), output_dir.path().join("helper3.enc"), ]; - let encrypted_oprf_report_files = EncryptedOprfReportStreams::from(files); + let encrypted_oprf_report_streams = EncryptedOprfReportStreams::from(files); let world = TestWorld::default(); let contexts = world.contexts(); @@ -621,7 +621,7 @@ public_key = "cfdbaaff16b30aa8a4ab07eaad2cdd80458208a1317aefbb807e46dce596617e" #[allow(clippy::large_futures)] let results = join3v( - encrypted_oprf_report_files + encrypted_oprf_report_streams .streams .into_iter() .zip(contexts) From 9bcdf4c2d626e82d0b0c8a1d5d9335a93a02bfcb Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Wed, 11 Sep 2024 17:26:35 -0700 Subject: [PATCH 013/191] remove unused report_collector apply-dp subcommand (#1274) --- ipa-core/src/bin/report_collector.rs | 52 ----------- ipa-core/src/cli/mod.rs | 2 - ipa-core/src/cli/noise.rs | 123 --------------------------- 3 files changed, 177 deletions(-) delete mode 100644 ipa-core/src/cli/noise.rs diff --git a/ipa-core/src/bin/report_collector.rs b/ipa-core/src/bin/report_collector.rs index 6c604c5ae..12e69bbd1 100644 --- a/ipa-core/src/bin/report_collector.rs +++ b/ipa-core/src/bin/report_collector.rs @@ -10,11 +10,9 @@ use std::{ }; use clap::{Parser, Subcommand}; -use comfy_table::{Cell, Table}; use hyper::http::uri::Scheme; use ipa_core::{ cli::{ - noise::{apply, ApplyDpArgs}, playbook::{ make_clients, playbook_oprf_ipa, run_query_and_validate, validate, validate_dp, InputSource, @@ -98,8 +96,6 @@ enum ReportCollectorCommand { #[clap(flatten)] gen_args: EventGeneratorConfig, }, - /// Apply differential privacy noise to IPA inputs - ApplyDpNoise(ApplyDpArgs), /// Execute OPRF IPA in a semi-honest majority setting with known test data /// and compare results against expectation SemiHonestOprfIpaTest(IpaQueryConfig), @@ -168,7 +164,6 @@ async fn main() -> Result<(), Box> { seed, gen_args, } => gen_inputs(count, seed, args.output_file, gen_args)?, - ReportCollectorCommand::ApplyDpNoise(ref 
dp_args) => apply_dp_noise(&args, dp_args)?, ReportCollectorCommand::SemiHonestOprfIpaTest(config) => { ipa_test( &args, @@ -447,50 +442,3 @@ async fn ipa_test( Ok(()) } - -fn apply_dp_noise(args: &Args, dp_args: &ApplyDpArgs) -> Result<(), Box> { - let IpaQueryResult { breakdowns, .. } = - serde_json::from_slice(&InputSource::from(&args.input).to_vec()?)?; - - let output = apply(&breakdowns, dp_args); - let mut table = Table::new(); - let header = std::iter::once("Epsilon".to_string()) - .chain(std::iter::once("Variance".to_string())) - .chain(std::iter::once("Mean".to_string())) - .chain((0..breakdowns.len()).map(|i| format!("{}", i + 1))) - .collect::>(); - table.set_header(header); - - // original values - table.add_row( - std::iter::repeat("-".to_string()) - .take(3) - .chain(breakdowns.iter().map(ToString::to_string)), - ); - - // reverse because smaller epsilon means more noise and I print the original values - // in the first row. - for epsilon in output.keys().rev() { - let noised_values = output.get(epsilon).unwrap(); - let mut row = vec![ - Cell::new(format!("{:.3}", epsilon)), - Cell::new(format!("{:.3}", noised_values.std)), - Cell::new(format!("{:.3}", noised_values.mean)), - ]; - - for agg in noised_values.breakdowns.iter() { - row.push(Cell::new(format!("{}", agg))); - } - - table.add_row(row); - } - - println!("{}", table); - - if let Some(file) = &args.output_file { - let mut file = File::create(file)?; - serde_json::to_writer_pretty(&mut file, &output)?; - } - - Ok(()) -} diff --git a/ipa-core/src/cli/mod.rs b/ipa-core/src/cli/mod.rs index ccb2e37fd..467425785 100644 --- a/ipa-core/src/cli/mod.rs +++ b/ipa-core/src/cli/mod.rs @@ -7,8 +7,6 @@ mod ipa_output; #[cfg(feature = "web-app")] mod keygen; mod metric_collector; -#[cfg(feature = "cli")] -pub mod noise; mod paths; #[cfg(all(feature = "test-fixture", feature = "web-app", feature = "cli"))] pub mod playbook; diff --git a/ipa-core/src/cli/noise.rs b/ipa-core/src/cli/noise.rs deleted file mode 100644 index f83f93174..000000000 --- a/ipa-core/src/cli/noise.rs +++ /dev/null @@ -1,123 +0,0 @@ -use std::{ - cmp::Ordering, - collections::BTreeMap, - fmt::{Debug, Display, Formatter}, -}; - -use clap::Args; -use rand::rngs::StdRng; -use rand_core::SeedableRng; -use serde::{Deserialize, Serialize, Serializer}; - -use crate::protocol::ipa_prf::oprf_padding::InsecureDiscreteDp; - -#[derive(Debug, Args)] -#[clap(about = "Apply differential privacy noise to the given input")] -pub struct ApplyDpArgs { - /// Various epsilon values to use inside the DP. - #[arg(long, short = 'e')] - epsilon: Vec, - - /// Delta parameter for (\epsilon, \delta) DP. - #[arg(long, short = 'd', default_value = "1e-7")] - delta: f64, - - /// Seed for the random number generator. - #[arg(long, short = 's')] - seed: Option, - - /// The sensitivity of the input or maximum contribution allowed per user to preserve privacy. - #[arg(long, short = 'c')] - cap: u32, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct NoisyOutput { - /// Aggregated breakdowns with noise applied. It is important to use unsigned values here - /// to avoid bias/mean skew - pub breakdowns: Box<[i64]>, - pub mean: f64, - pub std: f64, -} - -/// This exists to be able to use f64 as key inside a map. We don't have to deal with infinities or -/// NaN values for epsilons, so we can treat them as raw bytes for this purpose. 
-#[derive(Debug, Copy, Clone)] -pub struct EpsilonBits(f64); - -impl Serialize for EpsilonBits { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - serializer.serialize_str(&self.0.to_string()) - } -} - -impl From for EpsilonBits { - fn from(value: f64) -> Self { - assert!(value.is_finite()); - Self(value) - } -} - -// the following implementations are fine because NaN values are rejected from inside `From` - -impl PartialEq for EpsilonBits { - fn eq(&self, other: &Self) -> bool { - self.0.to_bits().eq(&other.0.to_bits()) - } -} - -impl Eq for EpsilonBits {} - -impl PartialOrd for EpsilonBits { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for EpsilonBits { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.0.partial_cmp(&other.0).unwrap() - } -} - -impl Display for EpsilonBits { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - Display::fmt(&self.0, f) - } -} - -/// Apply DP noise to the given input. -/// -/// ## Panics -/// If DP parameters are not valid. -pub fn apply>(input: I, args: &ApplyDpArgs) -> BTreeMap { - let mut rng = args - .seed - .map_or_else(StdRng::from_entropy, StdRng::seed_from_u64); - let mut result = BTreeMap::new(); - for &epsilon in &args.epsilon { - let discrete_dp = - InsecureDiscreteDp::new(epsilon, args.delta, f64::from(args.cap)).unwrap(); - let mut v = input - .as_ref() - .iter() - .copied() - .map(i64::from) - .collect::>(); - discrete_dp.apply(v.as_mut_slice(), &mut rng); - - result.insert( - epsilon.into(), - NoisyOutput { - breakdowns: v.into_boxed_slice(), - mean: discrete_dp.mean(), - std: discrete_dp.std(), - }, - ); - } - - result -} From 5eb94a2c542f5ee2543e701032119a33637d5a47 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Thu, 12 Sep 2024 12:32:42 -0700 Subject: [PATCH 014/191] Fix a typo --- ipa-core/src/bin/report_collector.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ipa-core/src/bin/report_collector.rs b/ipa-core/src/bin/report_collector.rs index 3f0c3ef5b..094975b52 100644 --- a/ipa-core/src/bin/report_collector.rs +++ b/ipa-core/src/bin/report_collector.rs @@ -105,7 +105,7 @@ enum ReportCollectorCommand { SemiHonestOprfIpaTest(IpaQueryConfig), /// Execute OPRF IPA in an honest majority (one malicious helper) setting /// with known test data and compare results against expectation - MalciousOprfIpaTest(IpaQueryConfig), + MaliciousOprfIpaTest(IpaQueryConfig), /// Execute OPRF IPA in a semi-honest majority setting with unknown encrypted data #[command(visible_alias = "oprf-ipa")] SemiHonestOprfIpa { @@ -180,7 +180,7 @@ async fn main() -> Result<(), Box> { ) .await? 
} - ReportCollectorCommand::MalciousOprfIpaTest(config) => { + ReportCollectorCommand::MaliciousOprfIpaTest(config) => { ipa_test( &args, &network, From 9a3fb320e70bf854e85b0207e4ee83fffc06b9cc Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Thu, 12 Sep 2024 13:18:40 -0700 Subject: [PATCH 015/191] remove duplicated crypto_util test, update CI to run CLI tests (#1273) --- .github/workflows/check.yml | 6 +- ipa-core/src/cli/crypto.rs | 10 +- ipa-core/tests/encrypted_input.rs | 186 ------------------------------ pre-commit | 5 +- scripts/coverage-ci | 2 +- 5 files changed, 7 insertions(+), 202 deletions(-) delete mode 100644 ipa-core/tests/encrypted_input.rs diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index 8d117fa32..58d37d514 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -68,7 +68,7 @@ jobs: run: cargo build --tests - name: Run tests - run: cargo test + run: cargo test --features "cli test-fixture" - name: Run tests with multithreading feature enabled run: cargo test --features "multi-threading" @@ -76,9 +76,6 @@ jobs: - name: Run Web Tests run: cargo test -p ipa-core --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate" - - name: Run Integration Tests - run: cargo test --test encrypted_input --features "cli test-fixture web-app in-memory-infra" - release: name: Release builds and tests runs-on: ubuntu-latest @@ -236,4 +233,3 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} file: ipa.cov fail_ci_if_error: false - diff --git a/ipa-core/src/cli/crypto.rs b/ipa-core/src/cli/crypto.rs index 2eb90e089..b66b6ee24 100644 --- a/ipa-core/src/cli/crypto.rs +++ b/ipa-core/src/cli/crypto.rs @@ -247,7 +247,6 @@ mod tests { sync::Arc, }; - use bytes::BufMut; use clap::Parser; use hpke::Deserializable; use rand::thread_rng; @@ -262,7 +261,7 @@ mod tests { helpers::query::{IpaQueryConfig, QuerySize}, hpke::{IpaPrivateKey, KeyRegistry, PrivateKeyOnly}, query::OprfIpaQuery, - report::EncryptedOprfReportSteams, + report::EncryptedOprfReportStreams, test_fixture::{ ipa::TestRawDataRecord, join3v, EventGenerator, EventGeneratorConfig, Reconstruct, TestWorld, @@ -538,7 +537,6 @@ public_key = "cfdbaaff16b30aa8a4ab07eaad2cdd80458208a1317aefbb807e46dce596617e" #[tokio::test] async fn encrypt_and_execute_query() { - panic!("is this run?"); const EXPECTED: &[u128] = &[0, 8, 5]; let records: Vec = vec![ @@ -601,9 +599,9 @@ public_key = "cfdbaaff16b30aa8a4ab07eaad2cdd80458208a1317aefbb807e46dce596617e" let _ = encrypt(&encrypt_args); let files = [ - output_dir.path().join("helper1.enc"), - output_dir.path().join("helper2.enc"), - output_dir.path().join("helper3.enc"), + &output_dir.path().join("helper1.enc"), + &output_dir.path().join("helper2.enc"), + &output_dir.path().join("helper3.enc"), ]; let encrypted_oprf_report_streams = EncryptedOprfReportStreams::from(files); diff --git a/ipa-core/tests/encrypted_input.rs b/ipa-core/tests/encrypted_input.rs deleted file mode 100644 index e25a2a309..000000000 --- a/ipa-core/tests/encrypted_input.rs +++ /dev/null @@ -1,186 +0,0 @@ -#[cfg(all( - feature = "test-fixture", - feature = "web-app", - feature = "cli", - feature = "in-memory-infra" -))] -mod tests { - - use std::{io::Write, path::Path, sync::Arc}; - - use clap::Parser; - use hpke::Deserializable; - use ipa_core::{ - cli::{ - crypto::{encrypt, EncryptArgs}, - CsvSerializer, - }, - ff::{boolean_array::BA16, U128Conversions}, - helpers::query::{IpaQueryConfig, QuerySize}, - hpke::{IpaPrivateKey, KeyRegistry, PrivateKeyOnly}, - 
query::OprfIpaQuery, - report::EncryptedOprfReportStreams, - test_fixture::{ipa::TestRawDataRecord, join3v, Reconstruct, TestWorld}, - }; - use tempfile::{tempdir, NamedTempFile}; - - fn build_encrypt_args( - input_file: &Path, - output_dir: &Path, - network_file: &Path, - ) -> EncryptArgs { - EncryptArgs::try_parse_from([ - "test_encrypt", - "--input-file", - input_file.to_str().unwrap(), - "--output-dir", - output_dir.to_str().unwrap(), - "--network", - network_file.to_str().unwrap(), - ]) - .unwrap() - } - - fn write_network_file() -> NamedTempFile { - let network_data = r#" -[[peers]] -url = "helper1.test" -[peers.hpke] -public_key = "92a6fb666c37c008defd74abf3204ebea685742eab8347b08e2f7c759893947a" -[[peers]] -url = "helper2.test" -[peers.hpke] -public_key = "cfdbaaff16b30aa8a4ab07eaad2cdd80458208a1317aefbb807e46dce596617e" -[[peers]] -url = "helper3.test" -[peers.hpke] -public_key = "b900be35da06106a83ed73c33f733e03e4ea5888b7ea4c912ab270b0b0f8381e" -"#; - let mut network = NamedTempFile::new().unwrap(); - writeln!(network.as_file_mut(), "{network_data}").unwrap(); - network - } - - #[tokio::test] - async fn encrypt_and_execute_query() { - const EXPECTED: &[u128] = &[0, 8, 5]; - - let records: Vec = vec![ - TestRawDataRecord { - timestamp: 0, - user_id: 12345, - is_trigger_report: false, - breakdown_key: 2, - trigger_value: 0, - }, - TestRawDataRecord { - timestamp: 4, - user_id: 68362, - is_trigger_report: false, - breakdown_key: 1, - trigger_value: 0, - }, - TestRawDataRecord { - timestamp: 10, - user_id: 12345, - is_trigger_report: true, - breakdown_key: 0, - trigger_value: 5, - }, - TestRawDataRecord { - timestamp: 12, - user_id: 68362, - is_trigger_report: true, - breakdown_key: 0, - trigger_value: 2, - }, - TestRawDataRecord { - timestamp: 20, - user_id: 68362, - is_trigger_report: false, - breakdown_key: 1, - trigger_value: 0, - }, - TestRawDataRecord { - timestamp: 30, - user_id: 68362, - is_trigger_report: true, - breakdown_key: 1, - trigger_value: 7, - }, - ]; - let query_size = QuerySize::try_from(records.len()).unwrap(); - let mut input_file = NamedTempFile::new().unwrap(); - - for event in records { - let _ = event.to_csv(input_file.as_file_mut()); - writeln!(input_file.as_file()).unwrap(); - } - input_file.as_file_mut().flush().unwrap(); - - let output_dir = tempdir().unwrap(); - let network_file = write_network_file(); - let encrypt_args = - build_encrypt_args(input_file.path(), output_dir.path(), network_file.path()); - let _ = encrypt(&encrypt_args); - - let files = [ - &output_dir.path().join("helper1.enc"), - &output_dir.path().join("helper2.enc"), - &output_dir.path().join("helper3.enc"), - ]; - let encrypted_oprf_report_streams = EncryptedOprfReportStreams::from(files); - - let world = TestWorld::default(); - let contexts = world.contexts(); - - let mk_private_keys = vec![ - hex::decode("53d58e022981f2edbf55fec1b45dbabd08a3442cb7b7c598839de5d7a5888bff") - .expect("manually provided for test"), - hex::decode("3a0a993a3cfc7e8d381addac586f37de50c2a14b1a6356d71e94ca2afaeb2569") - .expect("manually provided for test"), - hex::decode("1fb5c5274bf85fbe6c7935684ef05499f6cfb89ac21640c28330135cc0e8a0f7") - .expect("manually provided for test"), - ]; - - #[allow(clippy::large_futures)] - let results = join3v( - encrypted_oprf_report_streams - .streams - .into_iter() - .zip(contexts) - .zip(mk_private_keys) - .map(|((input, ctx), mk_private_key)| { - let query_config = IpaQueryConfig { - per_user_credit_cap: 8, - attribution_window_seconds: None, - max_breakdown_key: 3, - 
with_dp: 0, - epsilon: 1.0, - plaintext_match_keys: false, - }; - - let private_registry = - Arc::new(KeyRegistry::::from_keys([PrivateKeyOnly( - IpaPrivateKey::from_bytes(&mk_private_key) - .expect("manually constructed for test"), - )])); - - OprfIpaQuery::<_, BA16, KeyRegistry>::new( - query_config, - private_registry, - ) - .execute(ctx, query_size, input) - }), - ) - .await; - - assert_eq!( - results.reconstruct()[0..3] - .iter() - .map(U128Conversions::as_u128) - .collect::>(), - EXPECTED - ); - } -} diff --git a/pre-commit b/pre-commit index 308164245..713b614d2 100755 --- a/pre-commit +++ b/pre-commit @@ -97,7 +97,7 @@ check "Clippy web checks" \ # The tests here need to be kept in sync with scripts/coverage-ci. check "Tests" \ - cargo test + cargo test --features="cli test-fixture" check "Web tests" \ cargo test -p ipa-core --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate" @@ -108,9 +108,6 @@ check "Web tests (descriptive gate)" \ check "Concurrency tests" \ cargo test -p ipa-core --release --features "shuttle multi-threading" -check "Encrypted Input Tests" \ - cargo test --test encrypted_input --features "cli test-fixture web-app in-memory-infra" - check "IPA benchmark" \ cargo bench --bench oneshot_ipa --no-default-features --features="enable-benches compact-gate" -- -n 62 -c 16 diff --git a/scripts/coverage-ci b/scripts/coverage-ci index b34bc8920..9bb5e87ea 100755 --- a/scripts/coverage-ci +++ b/scripts/coverage-ci @@ -9,7 +9,7 @@ cargo llvm-cov clean --workspace cargo build --all-targets # Need to be kept in sync manually with tests we run inside check.yml. -cargo test +cargo test --features "cli test-fixture" # descriptive-gate does not require a feature flag. for gate in "compact-gate" ""; do From ad7fa8f0142fb91ddbb60256fea829feec95e280 Mon Sep 17 00:00:00 2001 From: Christian Berkhoff Date: Thu, 12 Sep 2024 15:05:17 -0700 Subject: [PATCH 016/191] Updated GH actions --- .github/workflows/docker.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 12c6d40bb..6a8b555d1 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -34,7 +34,9 @@ jobs: type=sha - name: "Setup Docker Buildx" - uses: docker/setup-buildx-action@v2 + uses: docker/setup-buildx-action@v3 + with: + platforms: linux/amd64 - name: "Login to GitHub Container Registry" uses: docker/login-action@v2 @@ -44,7 +46,7 @@ jobs: password: ${{ secrets.GITHUB_TOKEN }} - name: "Build and Publish Helper Image" - uses: docker/build-push-action@v4 + uses: docker/build-push-action@v6 with: context: . file: ./docker/ci/helper.Dockerfile From 182e8e4bcdb515d5a13ab869f6fa365817e6882c Mon Sep 17 00:00:00 2001 From: Christian Berkhoff Date: Thu, 12 Sep 2024 15:17:33 -0700 Subject: [PATCH 017/191] Removing confusing Docker CI config --- .github/workflows/docker.yml | 2 +- docker/ci/helper.Dockerfile | 13 ------------- 2 files changed, 1 insertion(+), 14 deletions(-) delete mode 100644 docker/ci/helper.Dockerfile diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 6a8b555d1..8b47bf627 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -49,7 +49,7 @@ jobs: uses: docker/build-push-action@v6 with: context: . 
- file: ./docker/ci/helper.Dockerfile + file: ./docker/helper.Dockerfile push: true tags: ${{ steps.meta.outputs.tags }} labels: ${{ steps.meta.outputs.labels }} diff --git a/docker/ci/helper.Dockerfile b/docker/ci/helper.Dockerfile deleted file mode 100644 index 7f7b4d376..000000000 --- a/docker/ci/helper.Dockerfile +++ /dev/null @@ -1,13 +0,0 @@ -# syntax=docker/dockerfile:1 -FROM rust:latest as builder - -COPY . /ipa/ -RUN cd /ipa && \ - cargo build --bin helper --release --no-default-features \ - --features "web-app real-world-infra compact-gate" - -# Copy them to the final image -FROM debian:bullseye-slim - -COPY --from=builder /ipa/target/release/helper /bin/ipa-helper -ENTRYPOINT ["/bin/ipa-helper"] From cd85bd1f69e472ffeb7152fa978c3f9097823945 Mon Sep 17 00:00:00 2001 From: danielmasny Date: Mon, 16 Sep 2024 13:43:30 -0700 Subject: [PATCH 018/191] improving malicious shuffle as proposed by Andy --- .../src/protocol/ipa_prf/shuffle/malicious.rs | 143 +++++++----------- 1 file changed, 53 insertions(+), 90 deletions(-) diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs index 4c515c97a..d698393ce 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs @@ -51,12 +51,13 @@ where for<'a> &'a B: Add<&'a B, Output = B>, Standard: Distribution, { + // assert lengths + assert_eq!(S::BITS + 32, B::BITS); // compute amount of MAC keys let amount_of_keys: usize = (usize::try_from(S::BITS).unwrap() + 31) / 32; // // generate MAC keys let keys = (0..amount_of_keys) - .map(|i| ctx.prss().generate_fields(RecordId::from(i))) - .map(|(left, right)| AdditiveShare::new(left, right)) + .map(|i| ctx.prss().generate(RecordId::from(i))) .collect::>>(); // compute and append tags to rows @@ -80,6 +81,7 @@ where .await?; // truncate tags from output_shares + // verify_shuffle ensures that truncate_tags yields the correct rows Ok(truncate_tags(&shuffled_shares)) } @@ -92,34 +94,37 @@ where S: BooleanArray, B: BooleanArray, { - let tag_offset = usize::try_from((S::BITS + 7) / 8).unwrap(); shares_and_tags .iter() .map(|row_with_tag| { AdditiveShare::new( - split_row_and_tag(row_with_tag.left(), tag_offset).0, - split_row_and_tag(row_with_tag.right(), tag_offset).0, + split_row_and_tag(row_with_tag.left()).0, + split_row_and_tag(row_with_tag.right()).0, ) }) .collect() } /// This function splits a row with tag into -/// a row without tag and a tag +/// a row without tag and a tag. +/// +/// When `row_with_tag` does not have the correct format, +/// i.e. deserialization returns an error, +/// the output row and tag will be the default values. /// /// ## Panics /// Panics when the lengths are incorrect: /// `S` in bytes needs to be equal to `tag_offset`. /// `B` in bytes needs to be equal to `tag_offset + 4`. 
-fn split_row_and_tag<S: BooleanArray, B: BooleanArray>(
-    row_with_tag: B,
-    tag_offset: usize,
-) -> (S, Gf32Bit) {
+fn split_row_and_tag<S: BooleanArray, B: BooleanArray>(row_with_tag: B) -> (S, Gf32Bit) {
+    let tag_offset = usize::try_from((S::BITS + 7) / 8).unwrap();
     let mut buf = GenericArray::default();
     row_with_tag.serialize(&mut buf);
     (
-        S::deserialize(GenericArray::from_slice(&buf.as_slice()[0..tag_offset])).unwrap(),
-        Gf32Bit::deserialize(GenericArray::from_slice(&buf.as_slice()[tag_offset..])).unwrap(),
+        S::deserialize(GenericArray::from_slice(&buf.as_slice()[0..tag_offset]))
+            .unwrap_or_default(),
+        Gf32Bit::deserialize(GenericArray::from_slice(&buf.as_slice()[tag_offset..]))
+            .unwrap_or_default(),
     )
 }
 
@@ -138,7 +143,11 @@ async fn verify_shuffle(
     let k_ctx = ctx
         .narrow(&OPRFShuffleStep::RevealMACKey)
         .set_total_records(TotalRecords::specified(key_shares.len())?);
-    let keys = reveal_keys(&k_ctx, key_shares).await?;
+    let keys = reveal_keys(&k_ctx, key_shares)
+        .await?
+        .iter()
+        .map(Gf32Bit::from_array)
+        .collect::<Vec<_>>();
 
     // verify messages and shares
     match ctx.role() {
@@ -166,15 +175,15 @@ async fn verify_shuffle(
 /// or `hash_c_h3 != hash_a_xor_b`.
 async fn h1_verify(
     ctx: C,
-    keys: &[StdArray],
+    keys: &[Gf32Bit],
     share_a_and_b: &[AdditiveShare],
     x1: Vec,
 ) -> Result<(), Error> {
     // compute hashes
     // compute hash for x1
-    let hash_x1 = compute_row_hash::(keys, x1);
+    let hash_x1 = compute_and_hash_tags::(keys, x1);
     // compute hash for A xor B
-    let hash_a_xor_b = compute_row_hash::(
+    let hash_a_xor_b = compute_and_hash_tags::(
         keys,
         share_a_and_b
             .iter()
@@ -233,15 +242,15 @@ async fn h1_verify(
 /// `hash_x2 != hash_y2`.
 async fn h2_verify(
     ctx: C,
-    keys: &[StdArray],
+    keys: &[Gf32Bit],
     share_b_and_c: &[AdditiveShare],
     x2: Vec,
 ) -> Result<(), Error> {
     // compute hashes
     // compute hash for x2
-    let hash_x2 = compute_row_hash::(keys, x2);
+    let hash_x2 = compute_and_hash_tags::(keys, x2);
     // compute hash for C
-    let hash_c = compute_row_hash::(
+    let hash_c = compute_and_hash_tags::(
         keys,
         share_b_and_c.iter().map(ReplicatedSecretSharing::right),
     );
@@ -281,18 +290,18 @@ async fn h2_verify(
 /// Propagates network errors.
 async fn h3_verify(
     ctx: C,
-    keys: &[StdArray],
+    keys: &[Gf32Bit],
     share_c_and_a: &[AdditiveShare],
     y1: Vec,
     y2: Vec,
 ) -> Result<(), Error> {
     // compute hashes
     // compute hash for y1
-    let hash_y1 = compute_row_hash::(keys, y1);
+    let hash_y1 = compute_and_hash_tags::(keys, y1);
     // compute hash for y2
-    let hash_y2 = compute_row_hash::(keys, y2);
+    let hash_y2 = compute_and_hash_tags::(keys, y2);
     // compute hash for C
-    let hash_c = compute_row_hash::(
+    let hash_c = compute_and_hash_tags::(
         keys,
         share_c_and_a.iter().map(ReplicatedSecretSharing::left),
     );
@@ -323,26 +332,26 @@ async fn h3_verify(
 ///
 /// ## Panics
 /// Panics when conversion from `BooleanArray` to `Vec<Gf32Bit>` fails.
-fn compute_row_hash<S, B, I>(keys: &[StdArray], row_iterator: I) -> Hash
+fn compute_and_hash_tags<S, B, I>(keys: &[Gf32Bit], row_iterator: I) -> Hash
 where
     S: BooleanArray,
     B: BooleanArray,
     I: IntoIterator<Item = B>,
 {
-    let tag_offset = usize::try_from((B::BITS + 7) / 8).unwrap() - 4;
-
     let iterator = row_iterator.into_iter().map(|row_with_tag| {
-        let (row, tag) = split_row_and_tag(row_with_tag, tag_offset);
+        // when `split_row_and_tag` returns the default value, the verification will fail
+        // except with probability 2^-security_parameter, i.e. 2^-32
+        let (row, tag) = split_row_and_tag(row_with_tag);
         <S as TryInto<Vec<Gf32Bit>>>::try_into(row)
             .unwrap()
             .into_iter()
             .chain(iter::once(tag))
     });
 
-    compute_hash(iterator.map(|row_iterator| {
-        row_iterator
+    compute_hash(iterator.map(|row_entry_iterator| {
+        row_entry_iterator
             .zip(keys)
             .fold(Gf32Bit::ZERO, |acc, (row_entry, key)| {
-                acc + row_entry * *key.first()
+                acc + row_entry * *key
             })
     }))
 }
@@ -362,6 +371,7 @@ async fn reveal_keys(
     // reveal MAC keys
     let keys = ctx
         .parallel_join(key_shares.iter().enumerate().map(|(i, key)| async move {
+            // uses malicious_reveal directly, since malicious_shuffle always needs the malicious reveal
            malicious_reveal(ctx.clone(), RecordId::from(i), None, key).await
         }))
         .await?
@@ -399,8 +409,7 @@ where
     S: BooleanArray,
     B: BooleanArray,
     I: IntoIterator<Item = AdditiveShare<S>>,
-    I::IntoIter: ExactSizeIterator,
-    <I as IntoIterator>::IntoIter: Send,
+    I::IntoIter: ExactSizeIterator + Send,
 {
     let row_iterator = rows.into_iter();
     let length = row_iterator.len();
@@ -468,7 +477,7 @@ mod tests {
     use crate::{
         ff::{
             boolean_array::{BA112, BA144, BA20, BA32, BA64},
-            Serializable,
+            Serializable, U128Conversions,
         },
         helpers::in_memory_config::{MaliciousHelper, MaliciousHelperContext},
         protocol::ipa_prf::shuffle::base::shuffle,
@@ -514,13 +523,13 @@ mod tests {
 
         assert_eq!(record, result_ba);
 
-        let tag = <BA112 as TryInto<Vec<Gf32Bit>>>::try_into(record)
+        let tag = Vec::<Gf32Bit>::try_from(record)
             .unwrap()
             .iter()
             .zip(keys)
             .fold(Gf32Bit::ZERO, |acc, (entry, key)| acc + *entry * key);
 
-        let tag_mpc = <BA32 as TryInto<Vec<Gf32Bit>>>::try_into(BA32::deserialize_from_slice(
+        let tag_mpc = Vec::<Gf32Bit>::try_from(BA32::deserialize_from_slice(
             &result[0].as_raw_slice()[14..18],
         ))
         .unwrap();
@@ -536,69 +545,23 @@ mod tests {
         run(|| async {
             let world = TestWorld::default();
             let mut rng = thread_rng();
-            // using Gf32Bit here since it implements cmp such that vec can later be sorted
             let mut records = (0..RECORD_AMOUNT)
-                .map(|_| rng.gen())
-                .collect::<Vec<Gf32Bit>>();
-
-            let records_boolean_array = records
-                .iter()
-                .map(|row| {
-                    let mut buf = GenericArray::default();
-                    row.serialize(&mut buf);
-                    BA32::deserialize(&buf).unwrap()
-                })
-                .collect::<Vec<_>>();
-
-            let result = world
-                .semi_honest(
-                    records_boolean_array.into_iter(),
-                    |ctx, records| async move {
-                        malicious_shuffle::<_, BA32, BA64, _>(ctx, records)
-                            .await
-                            .unwrap()
-                    },
-                )
-                .await
-                .reconstruct();
-
-            let mut result_galois = result
-                .iter()
-                .map(|row| {
-                    let mut buf = GenericArray::default();
-                    row.serialize(&mut buf);
-                    Gf32Bit::deserialize(&buf).unwrap()
-                })
-                .collect::<Vec<_>>();
-
-            records.sort();
-            result_galois.sort();
-
-            assert_eq!(records, result_galois);
-        });
-    }
-
-    /// This tests checks that the shuffling of `BA112`
-    /// does not return an error
-    /// nor panic.
- #[test] - fn shuffle_ba112_succeeds() { - const RECORD_AMOUNT: usize = 10; - run(|| async { - let world = TestWorld::default(); - let mut rng = thread_rng(); - - let records = (0..RECORD_AMOUNT) .map(|_| rng.gen()) .collect::>(); - world - .semi_honest(records.into_iter(), |ctx, records| async move { + let mut result = world + .semi_honest(records.clone().into_iter(), |ctx, records| async move { malicious_shuffle::<_, BA112, BA144, _>(ctx, records) .await .unwrap() }) - .await; + .await + .reconstruct(); + + records.sort_by_key(BA112::as_u128); + result.sort_by_key(BA112::as_u128); + + assert_eq!(records, result); }); } From 95112ed44003989f9399156cf9d564c4287951b5 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Tue, 17 Sep 2024 11:51:35 -0700 Subject: [PATCH 019/191] Detect out-of-bounds step indices --- ipa-core/src/protocol/basics/reveal.rs | 2 +- ipa-core/src/protocol/boolean/and.rs | 2 +- ipa-core/src/protocol/boolean/or.rs | 2 +- ipa-core/src/protocol/boolean/step.rs | 61 ++---------- .../src/protocol/ipa_prf/aggregation/mod.rs | 6 +- .../src/protocol/ipa_prf/aggregation/step.rs | 12 +-- .../src/protocol/ipa_prf/prf_sharding/mod.rs | 29 ++++++ .../src/protocol/ipa_prf/prf_sharding/step.rs | 12 +-- ipa-core/src/protocol/ipa_prf/quicksort.rs | 4 +- ipa-step-derive/src/lib.rs | 96 ++++++++++++++++++- ipa-step-derive/src/variant.rs | 57 ++++++++++- ipa-step-test/src/lib.rs | 10 ++ 12 files changed, 207 insertions(+), 86 deletions(-) diff --git a/ipa-core/src/protocol/basics/reveal.rs b/ipa-core/src/protocol/basics/reveal.rs index 51d5b2891..75867046f 100644 --- a/ipa-core/src/protocol/basics/reveal.rs +++ b/ipa-core/src/protocol/basics/reveal.rs @@ -98,7 +98,7 @@ where ctx.parallel_join(zip(&**self, repeat(ctx.clone())).enumerate().map( |(i, (bit, ctx))| async move { generic_reveal( - ctx.narrow(&TwoHundredFiftySixBitOpStep::Bit(i)), + ctx.narrow(&TwoHundredFiftySixBitOpStep::from(i)), record_id, excluded, bit, diff --git a/ipa-core/src/protocol/boolean/and.rs b/ipa-core/src/protocol/boolean/and.rs index c05f5cdf8..9afbfd9af 100644 --- a/ipa-core/src/protocol/boolean/and.rs +++ b/ipa-core/src/protocol/boolean/and.rs @@ -51,7 +51,7 @@ where BitDecomposed::try_from( ctx.parallel_join(zip(a.iter(), b).enumerate().map(|(i, (a, b))| { - let ctx = ctx.narrow(&EightBitStep::Bit(i)); + let ctx = ctx.narrow(&EightBitStep::from(i)); a.multiply(b, ctx, record_id) })) .await?, diff --git a/ipa-core/src/protocol/boolean/or.rs b/ipa-core/src/protocol/boolean/or.rs index c8aa611c9..176cbd239 100644 --- a/ipa-core/src/protocol/boolean/or.rs +++ b/ipa-core/src/protocol/boolean/or.rs @@ -52,7 +52,7 @@ where BitDecomposed::try_from( ctx.parallel_join(zip(a.iter(), b).enumerate().map(|(i, (a, b))| { - let ctx = ctx.narrow(&SixteenBitStep::Bit(i)); + let ctx = ctx.narrow(&SixteenBitStep::from(i)); async move { let ab = a.multiply(b, ctx, record_id).await?; Ok::<_, Error>(-ab + a + b) diff --git a/ipa-core/src/protocol/boolean/step.rs b/ipa-core/src/protocol/boolean/step.rs index 869f92726..0128cb037 100644 --- a/ipa-core/src/protocol/boolean/step.rs +++ b/ipa-core/src/protocol/boolean/step.rs @@ -1,63 +1,22 @@ use ipa_step_derive::CompactStep; #[derive(CompactStep)] -pub enum EightBitStep { - #[step(count = 8)] - Bit(usize), -} - -impl From for EightBitStep { - fn from(v: usize) -> Self { - Self::Bit(v) - } -} +#[step(count = 8, name = "bit")] +pub struct EightBitStep(usize); #[derive(CompactStep)] -pub enum SixteenBitStep { - #[step(count = 16)] - Bit(usize), -} - -impl From for SixteenBitStep { - fn 
from(v: usize) -> Self { - Self::Bit(v) - } -} +#[step(count = 16, name = "bit")] +pub struct SixteenBitStep(usize); #[derive(CompactStep)] -pub enum ThirtyTwoBitStep { - #[step(count = 32)] - Bit(usize), -} - -impl From for ThirtyTwoBitStep { - fn from(v: usize) -> Self { - Self::Bit(v) - } -} +#[step(count = 32, name = "bit")] +pub struct ThirtyTwoBitStep(usize); #[derive(CompactStep)] -pub enum TwoHundredFiftySixBitOpStep { - #[step(count = 256)] - Bit(usize), -} - -impl From for TwoHundredFiftySixBitOpStep { - fn from(v: usize) -> Self { - Self::Bit(v) - } -} +#[step(count = 256, name = "bit")] +pub struct TwoHundredFiftySixBitOpStep(usize); #[cfg(test)] #[derive(CompactStep)] -pub enum DefaultBitStep { - #[step(count = 256)] - Bit(usize), -} - -#[cfg(test)] -impl From for DefaultBitStep { - fn from(v: usize) -> Self { - Self::Bit(v) - } -} +#[step(count = 256, name = "bit")] +pub struct DefaultBitStep(usize); diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs b/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs index 69663f5d5..f7bc026cc 100644 --- a/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs @@ -234,8 +234,8 @@ where let stream = aggregation_input.by_ref().take(chunk); let validator = ctx.clone().dzkp_validator( MaliciousProtocolSteps { - protocol: &Step::AggregateChunk(chunk_counter), - validate: &Step::AggregateChunkValidate(chunk_counter), + protocol: &Step::aggregate_chunk(chunk_counter), + validate: &Step::aggregate_chunk_validate(chunk_counter), }, agg_proof_chunk, ); @@ -333,7 +333,7 @@ where // number of outputs (`next_num_rows`) gets rounded up. If calculating an explicit total // records, that would get rounded down. let par_agg_ctx = ctx - .narrow(&AggregateChunkStep::Aggregate(depth)) + .narrow(&AggregateChunkStep::from(depth)) .set_total_records(TotalRecords::Indeterminate); let next_num_rows = (num_rows + 1) / 2; aggregated_stream = Box::pin( diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/step.rs b/ipa-core/src/protocol/ipa_prf/aggregation/step.rs index 2b1995fd2..65fff707d 100644 --- a/ipa-core/src/protocol/ipa_prf/aggregation/step.rs +++ b/ipa-core/src/protocol/ipa_prf/aggregation/step.rs @@ -26,17 +26,9 @@ pub(crate) enum AggregationStep { #[step(count = 512, child = crate::protocol::boolean::step::EightBitStep, name = "b")] pub struct BucketStep(usize); -impl From for BucketStep { - fn from(v: usize) -> Self { - Self(v) - } -} - #[derive(CompactStep)] -pub(crate) enum AggregateChunkStep { - #[step(count = 32, child = AggregateValuesStep)] - Aggregate(usize), -} +#[step(count = 32, child = AggregateValuesStep, name = "depth")] +pub(crate) struct AggregateChunkStep(usize); #[derive(CompactStep)] pub(crate) enum AggregateValuesStep { diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs index 69b645055..66e3a89e6 100644 --- a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs @@ -899,6 +899,7 @@ pub mod tests { boolean_array::{BooleanArray, BA16, BA20, BA3, BA5, BA8}, Field, U128Conversions, }, + helpers::repeat_n, protocol::ipa_prf::prf_sharding::attribute_cap_aggregate, rand::Rng, secret_sharing::{ @@ -909,6 +910,7 @@ pub mod tests { test_fixture::{Reconstruct, Runner, TestWorld}, }; + #[derive(Clone)] struct PreShardedAndSortedOPRFTestInput { prf_of_match_key: u64, is_trigger_bit: Boolean, @@ -1102,6 +1104,7 @@ pub mod tests { ); }); } + #[test] fn 
semi_honest_aggregation_capping_attribution_with_attribution_window() { const ATTRIBUTION_WINDOW_SECONDS: u32 = 200; @@ -1162,6 +1165,32 @@ pub mod tests { }); } + #[test] + #[should_panic(expected = "v < usize::try_from(64usize).unwrap()")] + fn attribution_too_many_records_per_user() { + run(|| async move { + let world = TestWorld::default(); + + let records: Vec> = + repeat_n(oprf_test_input(123, false, 17, 0), 65).collect(); + + let histogram = repeat_n(1, 65).collect::>(); + let histogram_ref = histogram.as_slice(); + + world + .malicious(records.into_iter(), |ctx, input_rows| async move { + attribute_cap_aggregate::<_, BA5, BA3, BA16, BA20, 5, 32>( + ctx, + input_rows, + None, + histogram_ref, + ) + .await + .unwrap() + }) + .await; + }); + } #[test] fn capping_bugfix() { const HISTOGRAM: [usize; 10] = [5, 5, 5, 5, 5, 5, 5, 2, 1, 1]; diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/step.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/step.rs index e32de5a40..710b0a7e3 100644 --- a/ipa-core/src/protocol/ipa_prf/prf_sharding/step.rs +++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/step.rs @@ -1,16 +1,8 @@ use ipa_step_derive::CompactStep; #[derive(CompactStep)] -pub enum UserNthRowStep { - #[step(count = 64, child = AttributionPerRowStep)] - Row(usize), -} - -impl From for UserNthRowStep { - fn from(v: usize) -> Self { - Self::Row(v) - } -} +#[step(count = 64, child = AttributionPerRowStep, name = "row")] +pub struct UserNthRowStep(usize); #[derive(CompactStep)] pub(crate) enum AttributionStep { diff --git a/ipa-core/src/protocol/ipa_prf/quicksort.rs b/ipa-core/src/protocol/ipa_prf/quicksort.rs index c198dbe90..16c3f1c8d 100644 --- a/ipa-core/src/protocol/ipa_prf/quicksort.rs +++ b/ipa-core/src/protocol/ipa_prf/quicksort.rs @@ -171,8 +171,8 @@ where .expect("num_comparisons_needed should not be zero"); let v = ctx.set_total_records(total_records).dzkp_validator( MaliciousProtocolSteps { - protocol: &Step::QuicksortPass(quicksort_pass), - validate: &Step::QuicksortPassValidate(quicksort_pass), + protocol: &Step::quicksort_pass(quicksort_pass), + validate: &Step::quicksort_pass_validate(quicksort_pass), }, // TODO: use something like this when validating in chunks // `TARGET_PROOF_SIZE / usize::try_from(K::BITS).unwrap() / SORT_CHUNK`` diff --git a/ipa-step-derive/src/lib.rs b/ipa-step-derive/src/lib.rs index 0ffdd7b1c..e86ddb934 100644 --- a/ipa-step-derive/src/lib.rs +++ b/ipa-step-derive/src/lib.rs @@ -402,6 +402,13 @@ mod test { "e! { impl ::ipa_step::Step for ManyArms {} + impl ManyArms { + pub fn arm(v: u8) -> Self { + assert!(v < u8::try_from(3usize).unwrap()); + Self::Arm(v) + } + } + #[allow( clippy::useless_conversion, clippy::unnecessary_fallible_conversions, @@ -424,7 +431,8 @@ mod test { const STEP_COUNT: ::ipa_step::CompactGateIndex = 3; fn base_index (& self) -> ::ipa_step::CompactGateIndex { match self { - Self::Arm (i) => ::ipa_step::CompactGateIndex::try_from(*i).unwrap(), + Self::Arm (i) if *i < u8::try_from(3usize).unwrap() => ::ipa_step::CompactGateIndex::try_from(*i).unwrap(), + _ => panic!("Index out of range in ManyArms. Consider using bounds-checked step constructors."), } } fn step_string(i: ::ipa_step::CompactGateIndex) -> String { @@ -451,6 +459,13 @@ mod test { "e! 
{ impl ::ipa_step::Step for ManyArms {} + impl ManyArms { + pub fn arm(v: u8) -> Self { + assert!(v < u8::try_from(3usize).unwrap()); + Self::Arm(v) + } + } + #[allow( clippy::useless_conversion, clippy::unnecessary_fallible_conversions, @@ -473,7 +488,8 @@ mod test { const STEP_COUNT: ::ipa_step::CompactGateIndex = 3; fn base_index (& self) -> ::ipa_step::CompactGateIndex { match self { - Self::Arm (i) => ::ipa_step::CompactGateIndex::try_from(*i).unwrap(), + Self::Arm (i) if *i < u8::try_from(3usize).unwrap() => ::ipa_step::CompactGateIndex::try_from(*i).unwrap(), + _ => panic!("Index out of range in ManyArms. Consider using bounds-checked step constructors."), } } fn step_string(i: ::ipa_step::CompactGateIndex) -> String { @@ -642,6 +658,12 @@ mod test { "e! { impl ::ipa_step::Step for Parent {} + impl Parent { + pub fn offspring(v: u8) -> Self { + assert!(v < u8::try_from(5usize).unwrap()); + Self::Offspring(v) + } + } #[allow( clippy::useless_conversion, @@ -667,7 +689,8 @@ mod test { const STEP_COUNT: ::ipa_step::CompactGateIndex = (::STEP_COUNT + 1) * 5; fn base_index(&self) -> ::ipa_step::CompactGateIndex { match self { - Self::Offspring(i) => (::STEP_COUNT + 1) * ::ipa_step::CompactGateIndex::try_from(*i).unwrap(), + Self::Offspring(i) if *i < u8::try_from(5usize).unwrap() => (::STEP_COUNT + 1) * ::ipa_step::CompactGateIndex::try_from(*i).unwrap(), + _ => panic!("Index out of range in Parent. Consider using bounds-checked step constructors."), } } fn step_string(i: ::ipa_step::CompactGateIndex) -> String { @@ -726,6 +749,13 @@ mod test { "e! { impl ::ipa_step::Step for AllArms {} + impl AllArms { + pub fn int(v: usize) -> Self { + assert!(v < usize::try_from(3usize).unwrap()); + Self::Int(v) + } + } + #[allow( clippy::useless_conversion, clippy::unnecessary_fallible_conversions, @@ -752,9 +782,10 @@ mod test { fn base_index(&self) -> ::ipa_step::CompactGateIndex { match self { Self::Empty => 0, - Self::Int(i) => ::ipa_step::CompactGateIndex::try_from(*i).unwrap() + 1, + Self::Int(i) if *i < usize::try_from(3usize).unwrap() => ::ipa_step::CompactGateIndex::try_from(*i).unwrap() + 1, Self::Child => 4, Self::Final => <::some::other::StepEnum as ::ipa_step::CompactStep>::STEP_COUNT + 5, + _ => panic!("Index out of range in AllArms. Consider using bounds-checked step constructors."), } } fn step_string(i: ::ipa_step::CompactGateIndex) -> String { @@ -854,6 +885,63 @@ mod test { ); } + #[test] + fn struct_int() { + derive_success( + quote! { + #[derive(CompactStep)] + #[step(count = 3)] + struct StructInt(u8); + }, + "e! { + impl ::ipa_step::Step for StructInt {} + + impl From for StructInt { + fn from(v: u8) -> Self { + assert!(v < u8::try_from(3usize).unwrap()); + Self(v) + } + } + + #[allow( + clippy::useless_conversion, + clippy::unnecessary_fallible_conversions, + )] + impl ::std::convert::AsRef for StructInt { + fn as_ref(&self) -> &str { + const STRUCT_INT_NAMES: [&str; 3] = ["struct_int0" , "struct_int1" , "struct_int2"]; + match self { + Self(i) => STRUCT_INT_NAMES[usize::try_from(*i).unwrap()], + } + } + } + + #[allow( + clippy::useless_conversion, + clippy::unnecessary_fallible_conversions, + clippy::identity_op, + )] + impl ::ipa_step::CompactStep for StructInt { + const STEP_COUNT: ::ipa_step::CompactGateIndex = 3; + + fn base_index(&self) -> ::ipa_step::CompactGateIndex { + match self { + Self(i) if *i < u8::try_from(3usize).unwrap() => ::ipa_step::CompactGateIndex::try_from(*i).unwrap(), + _ => panic!("Index out of range in StructInt. 
Consider using bounds-checked step constructors."), + } + } + + fn step_string(i: ::ipa_step::CompactGateIndex) -> String { + match i { + _ if i < 3 => Self(u8::try_from(i - (0)).unwrap()).as_ref().to_owned(), + _ => panic!("step {i} is not valid for {t}", t = ::std::any::type_name::()), + } + } + } + }, + ); + } + #[test] fn struct_missing_count() { derive_failure( diff --git a/ipa-step-derive/src/variant.rs b/ipa-step-derive/src/variant.rs index 1b72bdfe1..71ceccdc7 100644 --- a/ipa-step-derive/src/variant.rs +++ b/ipa-step-derive/src/variant.rs @@ -214,6 +214,8 @@ pub struct Generator { arm_count: ExtendedSum, // This tracks the index of each item. index_arms: TokenStream, + // This tracks integer variant constructors. + int_variant_constructors: TokenStream, // This tracks the arrays of names that are used for integer variants. name_arrays: TokenStream, // This tracks the arms of the `AsRef` match implementation. @@ -339,6 +341,16 @@ impl Generator { quote!(Self) }; + if is_variant { + let constructor = format_ident!("{}", step_ident.to_string().to_snake_case()); + self.int_variant_constructors.extend(quote! { + pub fn #constructor(v: #step_integer) -> Self { + assert!(v < #step_integer::try_from(#step_count).unwrap()); + Self::#step_ident(v) + } + }); + } + // Construct some nice names for each integer value in the range. let array_name = format_ident!("{}_NAMES", step_ident.to_string().to_shouting_case()); let skip_zeros = match *step_count - 1 { @@ -363,7 +375,7 @@ impl Generator { let idx = self.arm_count.clone() + quote!((<#child as ::ipa_step::CompactStep>::STEP_COUNT + 1) * ::ipa_step::CompactGateIndex::try_from(*i).unwrap()); self.index_arms.extend(quote! { - #arm(i) => #idx, + #arm(i) if *i < #step_integer::try_from(#step_count).unwrap() => #idx, }); // With `step_count` variations present, each has a name. @@ -404,7 +416,7 @@ impl Generator { let idx = self.arm_count.clone() + quote!(::ipa_step::CompactGateIndex::try_from(*i).unwrap()); self.index_arms.extend(quote! { - #arm(i) => #idx, + #arm(i) if *i < #step_integer::try_from(#step_count).unwrap() => #idx, }); let range_end = arm_count.clone() + *step_count; @@ -415,6 +427,7 @@ impl Generator { } } + #[allow(clippy::too_many_lines)] pub fn generate(mut self, ident: &Ident, attr: &VariantAttribute) -> TokenStream { self.add_outer(attr); @@ -422,6 +435,41 @@ impl Generator { impl ::ipa_step::Step for #ident {} }; + // Generate a bounds-checking `impl From` if this is an integer unit struct step. + if let &Some((count, ref type_path)) = &attr.integer { + result.extend(quote! { + impl From<#type_path> for #ident { + fn from(v: #type_path) -> Self { + assert!(v < #type_path::try_from(#count).unwrap()); + Self(v) + } + } + }); + } + + // Generate bounds-checking variant constructors if there are integer variants. + if !self.int_variant_constructors.is_empty() { + let constructors = self.int_variant_constructors; + result.extend(quote! { + impl #ident { + #constructors + } + }); + } + + let index_arm_wild = if self.name_arrays.is_empty() { + quote!() + } else { + // Note that the current `AsRef` impl indexes into an array of the valid step names, so + // will panic if used here to generate the message. + let panic_msg = format!( + "Index out of range in {ident}. Consider using bounds-checked step constructors.", + ); + quote! 
{ + _ => panic!(#panic_msg), + } + }; + assert_eq!(self.index_arms.is_empty(), self.as_ref_arms.is_empty()); let (index_arms, as_ref_arms) = if self.index_arms.is_empty() { let n = attr.name(); @@ -431,7 +479,10 @@ impl Generator { let as_ref_arms = self.as_ref_arms; ( quote! { - match self { #index_arms } + match self { + #index_arms + #index_arm_wild + } }, quote! { match self { #as_ref_arms } diff --git a/ipa-step-test/src/lib.rs b/ipa-step-test/src/lib.rs index c2c9abfed..e9294186c 100644 --- a/ipa-step-test/src/lib.rs +++ b/ipa-step-test/src/lib.rs @@ -55,6 +55,16 @@ mod tests { _ = ComplexGate::from("/two2/one").narrow(&BasicStep::Two); } + /// Attempts to narrow with an out-of-range index should panic + /// (rather than produce an incorrect output gate). + #[test] + #[should_panic( + expected = "Index out of range in ComplexStep. Consider using bounds-checked step constructors." + )] + fn index_out_of_range() { + _ = ComplexGate::default().narrow(&ComplexStep::Two(10)); + } + /// Test that the alpha and beta gates work. #[test] fn alpha_and_beta() { From 0221147fbb2e5851fecbf7ff6dfd59206a04ab63 Mon Sep 17 00:00:00 2001 From: Christian Berkhoff Date: Tue, 17 Sep 2024 16:46:33 -0700 Subject: [PATCH 020/191] Using new public version of Axum Server --- .gitignore | 1 + ipa-core/Cargo.toml | 4 +--- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index 8aa41d239..674cf39f2 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ /in-market-test/hpke/bin /in-market-test/hpke/lib /in-market-test/hpke/pyvenv.cfg +input-data-*.txt \ No newline at end of file diff --git a/ipa-core/Cargo.toml b/ipa-core/Cargo.toml index ecb68cc33..819eecdb7 100644 --- a/ipa-core/Cargo.toml +++ b/ipa-core/Cargo.toml @@ -85,9 +85,7 @@ async-scoped = { version = "0.9.0", features = ["use-tokio"], optional = true } axum = { version = "0.7.5", optional = true, features = ["http2", "macros"] } # The following is a temporary version until we can stabilize the build on a higher version # of axum, rustls and the http stack. 
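A note on the bounds-checked step constructors introduced in the step-derive changes above (the axum-server dependency change resumes below). Gate indices for integer steps are computed arithmetically, so an out-of-range index would not fail loudly; it would silently alias another step's gate. Below is a rough, hypothetical sketch of the hazard, assuming a step declared as `#[step(count = 5, child = Child)] Offspring(u8)`; only the `(STEP_COUNT + 1) * i` formula and the panic message wording are taken from the generated code shown above, the function itself is illustrative and not part of the patch series.

```rust
// Illustrative only, not code from the patch series.
// For `#[step(count = 5, child = Child)] Offspring(u8)`, the derive emits
//     base_index(Offspring(i)) = (Child::STEP_COUNT + 1) * i,
// so i == 5 (one past the end) lands exactly where the next sibling's gate
// indices begin: an unchecked index yields a valid-looking but wrong gate.
fn offspring_base_index(i: u32, child_step_count: u32) -> u32 {
    // The generated guard added by this patch, with its message format:
    assert!(
        i < 5,
        "Step index {i} out of bounds for Parent::Offspring with count 5."
    );
    (child_step_count + 1) * i
}
```

This is why the tests above expect a panic rather than a wrapped index: a wrong gate would corrupt the protocol transcript much later and be far harder to diagnose.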
-axum-server = { git = "https://github.com/cberkhoff/axum-server/", branch = "0.6.1", version = "0.6.1", optional = true, features = [ - "tls-rustls", -] } +axum-server = { version = "0.7.1", optional = true, features = ["tls-rustls"] } base64 = { version = "0.21.2", optional = true } bitvec = "1.0" bytes = "1.4" From fec61d0874693e083949f4b136bf15e3226814ec Mon Sep 17 00:00:00 2001 From: danielmasny Date: Tue, 17 Sep 2024 17:30:09 -0700 Subject: [PATCH 021/191] creating shuffle trait with different shuffle implementations for malicious and semi honest + plumbing it in the executor --- ipa-core/src/protocol/ipa_prf/mod.rs | 4 +- ipa-core/src/protocol/ipa_prf/shuffle/base.rs | 23 ++++- .../src/protocol/ipa_prf/shuffle/malicious.rs | 11 ++- ipa-core/src/protocol/ipa_prf/shuffle/mod.rs | 87 +++++++++++++++++-- ipa-core/src/query/runner/oprf_ipa.rs | 6 +- 5 files changed, 109 insertions(+), 22 deletions(-) diff --git a/ipa-core/src/protocol/ipa_prf/mod.rs b/ipa-core/src/protocol/ipa_prf/mod.rs index 994713911..55d8e7414 100644 --- a/ipa-core/src/protocol/ipa_prf/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/mod.rs @@ -96,7 +96,7 @@ use crate::{ protocol::{ context::Validator, dp::dp_for_histogram, - ipa_prf::{oprf_padding::PaddingParameters, prf_eval::PrfSharing}, + ipa_prf::{oprf_padding::PaddingParameters, prf_eval::PrfSharing, shuffle::Shuffle}, }, secret_sharing::replicated::semi_honest::AdditiveShare, }; @@ -228,7 +228,7 @@ pub async fn oprf_ipa<'ctx, C, BK, TV, HV, TS, const SS_BITS: usize, const B: us dp_padding_params: PaddingParameters, ) -> Result>, Error> where - C: UpgradableContext + 'ctx, + C: UpgradableContext + 'ctx + Shuffle, BK: BreakdownKey, TV: BooleanArray + U128Conversions, HV: BooleanArray + U128Conversions, diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/base.rs b/ipa-core/src/protocol/ipa_prf/shuffle/base.rs index a34477d69..a48cbd128 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/base.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/base.rs @@ -18,7 +18,22 @@ use crate::{ /// # Errors /// Will propagate errors from transport and a few typecasts -pub async fn shuffle( +pub async fn semi_honest_shuffle(ctx: C, shares: I) -> Result>, Error> +where + C: Context, + I: IntoIterator>, + I::IntoIter: ExactSizeIterator, + S: SharedValue + Add, + for<'a> &'a S: Add, + for<'a> &'a S: Add<&'a S, Output = S>, + Standard: Distribution, +{ + Ok(shuffle_protocol(ctx, shares).await?.0) +} + +/// # Errors +/// Will propagate errors from transport and a few typecasts +pub async fn shuffle_protocol( ctx: C, shares: I, ) -> Result<(Vec>, IntermediateShuffleMessages), Error> @@ -430,7 +445,7 @@ where pub mod tests { use rand::{thread_rng, Rng}; - use super::shuffle; + use super::shuffle_protocol; use crate::{ ff::{Gf40Bit, U128Conversions}, secret_sharing::replicated::ReplicatedSecretSharing, @@ -453,7 +468,7 @@ pub mod tests { // Stable seed is used to get predictable shuffle results. 
let mut actual = TestWorld::new_with(TestWorldConfig::default().with_seed(123)) .semi_honest(records.clone().into_iter(), |ctx, shares| async move { - shuffle(ctx, shares).await.unwrap().0 + shuffle_protocol(ctx, shares).await.unwrap().0 }) .await .reconstruct(); @@ -484,7 +499,7 @@ pub mod tests { let [h1, h2, h3] = world .semi_honest(records.clone().into_iter(), |ctx, records| async move { - shuffle(ctx, records).await + shuffle_protocol(ctx, records).await }) .await; diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs index d698393ce..218a1c195 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs @@ -18,7 +18,9 @@ use crate::{ protocol::{ basics::{malicious_reveal, mul::semi_honest_multiply}, context::Context, - ipa_prf::shuffle::{base::IntermediateShuffleMessages, shuffle, step::OPRFShuffleStep}, + ipa_prf::shuffle::{ + base::IntermediateShuffleMessages, shuffle_protocol, step::OPRFShuffleStep, + }, prss::SharedRandomness, RecordId, }, @@ -65,7 +67,7 @@ where compute_and_add_tags(ctx.narrow(&OPRFShuffleStep::GenerateTags), &keys, shares).await?; // shuffle - let (shuffled_shares, messages) = shuffle( + let (shuffled_shares, messages) = shuffle_protocol( ctx.narrow(&OPRFShuffleStep::ShuffleProtocol), shares_and_tags, ) @@ -480,7 +482,7 @@ mod tests { Serializable, U128Conversions, }, helpers::in_memory_config::{MaliciousHelper, MaliciousHelperContext}, - protocol::ipa_prf::shuffle::base::shuffle, + protocol::ipa_prf::shuffle::base::shuffle_protocol, secret_sharing::SharedValue, test_executor::run, test_fixture::{Reconstruct, Runner, TestWorld, TestWorldConfig}, @@ -590,7 +592,8 @@ mod tests { // trivial shares of Gf32Bit::ONE let key_shares = vec![AdditiveShare::new(Gf32Bit::ONE, Gf32Bit::ONE); 1]; // run shuffle - let (shares, messages) = shuffle(ctx.narrow("shuffle"), rows).await.unwrap(); + let (shares, messages) = + shuffle_protocol(ctx.narrow("shuffle"), rows).await.unwrap(); // verify it verify_shuffle::<_, BA32, BA64>( ctx.narrow("verify"), diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs b/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs index 2908bf066..4119878d4 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs @@ -1,8 +1,8 @@ -use std::ops::Add; +use std::{future::Future, ops::Add}; -use rand::distributions::Standard; +use rand::distributions::{Distribution, Standard}; -use self::base::shuffle; +use self::base::shuffle_protocol; use super::{ boolean_ops::{expand_shared_array_in_place, extract_from_shared_array}, prf_sharding::SecretSharedAttributionOutputs, @@ -11,30 +11,99 @@ use crate::{ error::Error, ff::{ boolean::Boolean, - boolean_array::{BooleanArray, BA112, BA64}, + boolean_array::{BooleanArray, BA112, BA144, BA64}, ArrayAccess, }, - protocol::{context::Context, ipa_prf::OPRFIPAInputRow}, + protocol::{ + context::{Context, MaliciousContext, SemiHonestContext}, + ipa_prf::{ + shuffle::{base::semi_honest_shuffle, malicious::malicious_shuffle}, + OPRFIPAInputRow, + }, + }, secret_sharing::{ replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing}, SharedValue, }, + sharding::ShardBinding, }; pub mod base; -#[allow(dead_code)] pub mod malicious; #[cfg(descriptive_gate)] mod sharded; pub(crate) mod step; +pub trait Shuffle: Context { + fn shuffle( + self, + shares: I, + ) -> impl Future>, Error>> + Send + where + S: BooleanArray, + B: BooleanArray, + I: IntoIterator> + 
Send, + I::IntoIter: ExactSizeIterator, + ::IntoIter: Send, + for<'a> &'a S: Add, + for<'a> &'a S: Add<&'a S, Output = S>, + for<'a> &'a B: Add, + for<'a> &'a B: Add<&'a B, Output = B>, + Standard: Distribution, + Standard: Distribution; +} + +impl<'b, T: ShardBinding> Shuffle for SemiHonestContext<'b, T> { + fn shuffle( + self, + shares: I, + ) -> impl Future>, Error>> + Send + where + S: BooleanArray, + B: BooleanArray, + I: IntoIterator> + Send, + I::IntoIter: ExactSizeIterator, + ::IntoIter: Send, + for<'a> &'a S: Add, + for<'a> &'a S: Add<&'a S, Output = S>, + for<'a> &'a B: Add, + for<'a> &'a B: Add<&'a B, Output = B>, + Standard: Distribution, + Standard: Distribution, + { + semi_honest_shuffle::<_, I, S>(self, shares) + } +} + +impl<'b> Shuffle for MaliciousContext<'b> { + fn shuffle( + self, + shares: I, + ) -> impl Future>, Error>> + Send + where + S: BooleanArray, + B: BooleanArray, + I: IntoIterator> + Send, + I::IntoIter: ExactSizeIterator, + ::IntoIter: Send, + for<'a> &'a S: Add, + for<'a> &'a S: Add<&'a S, Output = S>, + for<'a> &'a B: Add, + for<'a> &'a B: Add<&'a B, Output = B>, + Standard: Distribution, + Standard: Distribution, + { + malicious_shuffle::<_, S, B, I>(self, shares) + } +} + #[tracing::instrument(name = "shuffle_inputs", skip_all)] pub async fn shuffle_inputs( ctx: C, input: Vec>, ) -> Result>, Error> where - C: Context, + C: Context + Shuffle, BK: BooleanArray, TV: BooleanArray, TS: BooleanArray, @@ -44,7 +113,7 @@ where .map(|item| oprfreport_to_shuffle_input::(&item)) .collect::>(); - let (shuffled, _) = shuffle(ctx, shuffle_input).await?; + let shuffled = ctx.shuffle::(shuffle_input).await?; Ok(shuffled .into_iter() @@ -71,7 +140,7 @@ where .map(|item| attribution_outputs_to_shuffle_input::(&item)) .collect::>(); - let (shuffled, _) = shuffle(ctx, shuffle_input).await?; + let (shuffled, _) = shuffle_protocol(ctx, shuffle_input).await?; Ok(shuffled .into_iter() diff --git a/ipa-core/src/query/runner/oprf_ipa.rs b/ipa-core/src/query/runner/oprf_ipa.rs index 99c59b727..b2ec0cffc 100644 --- a/ipa-core/src/query/runner/oprf_ipa.rs +++ b/ipa-core/src/query/runner/oprf_ipa.rs @@ -21,8 +21,8 @@ use crate::{ basics::{BooleanArrayMul, Reveal, ShareKnownValue}, context::{DZKPUpgraded, MacUpgraded, UpgradableContext}, ipa_prf::{ - oprf_ipa, oprf_padding::PaddingParameters, prf_eval::PrfSharing, OPRFIPAInputRow, - AGG_CHUNK, CONV_CHUNK, PRF_CHUNK, SORT_CHUNK, + oprf_ipa, oprf_padding::PaddingParameters, prf_eval::PrfSharing, shuffle::Shuffle, + OPRFIPAInputRow, AGG_CHUNK, CONV_CHUNK, PRF_CHUNK, SORT_CHUNK, }, prss::FromPrss, step::ProtocolStep::IpaPrf, @@ -55,7 +55,7 @@ impl OprfIpaQuery { #[allow(clippy::too_many_lines)] impl OprfIpaQuery where - C: UpgradableContext, + C: UpgradableContext + Shuffle, HV: BooleanArray + U128Conversions, R: PrivateKeyRegistry, Replicated: Serializable + ShareKnownValue, From 9bf659bba8c3b535b990aed8bfa71b91060c25d2 Mon Sep 17 00:00:00 2001 From: danielmasny Date: Tue, 17 Sep 2024 17:55:53 -0700 Subject: [PATCH 022/191] fix compact gate for malicious shuffle --- ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs | 6 +----- ipa-core/src/protocol/ipa_prf/shuffle/step.rs | 1 - 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs index 218a1c195..b2e4001cf 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs @@ -67,11 +67,7 @@ where 
compute_and_add_tags(ctx.narrow(&OPRFShuffleStep::GenerateTags), &keys, shares).await?; // shuffle - let (shuffled_shares, messages) = shuffle_protocol( - ctx.narrow(&OPRFShuffleStep::ShuffleProtocol), - shares_and_tags, - ) - .await?; + let (shuffled_shares, messages) = shuffle_protocol(ctx.clone(), shares_and_tags).await?; // verify the shuffle verify_shuffle::<_, S, B>( diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/step.rs b/ipa-core/src/protocol/ipa_prf/shuffle/step.rs index cc492dcd5..9c86b8821 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/step.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/step.rs @@ -10,7 +10,6 @@ pub(crate) enum OPRFShuffleStep { TransferX2, TransferY1, GenerateTags, - ShuffleProtocol, VerifyShuffle, RevealMACKey, HashesH3toH1, From 2c4deb7c9306ea6a1c47bda1ef8609dfb679f874 Mon Sep 17 00:00:00 2001 From: danielmasny Date: Tue, 17 Sep 2024 18:10:01 -0700 Subject: [PATCH 023/191] more fixing by adding additional step node during malicious shuffle verification --- .../src/protocol/ipa_prf/shuffle/malicious.rs | 18 ++++++++++-------- ipa-core/src/protocol/ipa_prf/shuffle/step.rs | 5 +++++ 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs index b2e4001cf..b03363288 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs @@ -19,7 +19,9 @@ use crate::{ basics::{malicious_reveal, mul::semi_honest_multiply}, context::Context, ipa_prf::shuffle::{ - base::IntermediateShuffleMessages, shuffle_protocol, step::OPRFShuffleStep, + base::IntermediateShuffleMessages, + shuffle_protocol, + step::{OPRFShuffleStep, VerifyShuffleStep}, }, prss::SharedRandomness, RecordId, @@ -139,7 +141,7 @@ async fn verify_shuffle( ) -> Result<(), Error> { // reveal keys let k_ctx = ctx - .narrow(&OPRFShuffleStep::RevealMACKey) + .narrow(&VerifyShuffleStep::RevealMACKey) .set_total_records(TotalRecords::specified(key_shares.len())?); let keys = reveal_keys(&k_ctx, key_shares) .await? 
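A note on the verification that the following hunks re-namespace under `VerifyShuffleStep`. After the shuffle, the MAC keys are revealed (the `RevealMACKey` step above), and each helper hashes the tagged rows it handled, exchanging the hashes along the `HashesH3toH1`, `HashH2toH1`, and `HashH3toH2` channels. Below is a minimal sketch of the tag being checked, with assumed semantics: the real code operates on replicated shares of `Gf32Bit` and uses the crate's own hashing, neither of which is shown here, and `gf32_mul` stands in for carry-less field multiplication.

```rust
// Minimal sketch, assuming each row is split into 32-bit chunks x_i and the
// revealed MAC keys are k_i. Addition in GF(2^32) is XOR, so the tag is the
// inner product of chunks and keys. The protocol shuffles (row, tag) pairs,
// then helpers exchange hashes of what they saw (H3->H1 twice, H2->H1 once,
// H3->H2 once, per the step names in this patch) and compare.
fn mac_tag(row_chunks: &[u32], keys: &[u32], gf32_mul: impl Fn(u32, u32) -> u32) -> u32 {
    row_chunks
        .iter()
        .zip(keys)
        .fold(0, |acc, (&x, &k)| acc ^ gf32_mul(x, k))
}
```

Roughly speaking, once the keys are public, a helper that altered a row during the shuffle would also have needed a matching tag, which it could not have forged without knowing the keys in advance.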
@@ -190,10 +192,10 @@ async fn h1_verify( // setup channels let h3_ctx = ctx - .narrow(&OPRFShuffleStep::HashesH3toH1) + .narrow(&VerifyShuffleStep::HashesH3toH1) .set_total_records(TotalRecords::specified(2)?); let h2_ctx = ctx - .narrow(&OPRFShuffleStep::HashH2toH1) + .narrow(&VerifyShuffleStep::HashH2toH1) .set_total_records(TotalRecords::ONE); let channel_h3 = &h3_ctx.recv_channel::(ctx.role().peer(Direction::Left)); let channel_h2 = &h2_ctx.recv_channel::(ctx.role().peer(Direction::Right)); @@ -255,10 +257,10 @@ async fn h2_verify( // setup channels let h1_ctx = ctx - .narrow(&OPRFShuffleStep::HashH2toH1) + .narrow(&VerifyShuffleStep::HashH2toH1) .set_total_records(TotalRecords::specified(1)?); let h3_ctx = ctx - .narrow(&OPRFShuffleStep::HashH3toH2) + .narrow(&VerifyShuffleStep::HashH3toH2) .set_total_records(TotalRecords::specified(1)?); let channel_h1 = &h1_ctx.send_channel::(ctx.role().peer(Direction::Left)); let channel_h3 = &h3_ctx.recv_channel::(ctx.role().peer(Direction::Right)); @@ -306,10 +308,10 @@ async fn h3_verify( // setup channels let h1_ctx = ctx - .narrow(&OPRFShuffleStep::HashesH3toH1) + .narrow(&VerifyShuffleStep::HashesH3toH1) .set_total_records(TotalRecords::specified(2)?); let h2_ctx = ctx - .narrow(&OPRFShuffleStep::HashH3toH2) + .narrow(&VerifyShuffleStep::HashH3toH2) .set_total_records(TotalRecords::specified(1)?); let channel_h1 = &h1_ctx.send_channel::(ctx.role().peer(Direction::Right)); let channel_h2 = &h2_ctx.send_channel::(ctx.role().peer(Direction::Left)); diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/step.rs b/ipa-core/src/protocol/ipa_prf/shuffle/step.rs index 9c86b8821..126996574 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/step.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/step.rs @@ -10,7 +10,12 @@ pub(crate) enum OPRFShuffleStep { TransferX2, TransferY1, GenerateTags, + #[step(child = crate::protocol::ipa_prf::shuffle::step::VerifyShuffleStep)] VerifyShuffle, +} + +#[derive(CompactStep)] +pub(crate) enum VerifyShuffleStep { RevealMACKey, HashesH3toH1, HashH2toH1, From da00669cab543ad414d462ab8e1b3f9bde1be19e Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Wed, 18 Sep 2024 11:30:38 -0700 Subject: [PATCH 024/191] Sharded shuffle query type In order for us to test sharded circuits through the HTTP stack, we need a protocol that exercises both shard-to-shard and helper-to-helper path. Shuffle seems like a good fit for that, so this change just plumbs the new query type through the stack --- ipa-core/src/bin/test_mpc.rs | 10 ++++++++++ ipa-core/src/helpers/transport/query/mod.rs | 5 +++++ ipa-core/src/net/http_serde.rs | 1 + ipa-core/src/query/executor.rs | 6 ++++++ 4 files changed, 22 insertions(+) diff --git a/ipa-core/src/bin/test_mpc.rs b/ipa-core/src/bin/test_mpc.rs index 74e2e7284..9da4afbb2 100644 --- a/ipa-core/src/bin/test_mpc.rs +++ b/ipa-core/src/bin/test_mpc.rs @@ -85,6 +85,11 @@ enum TestAction { /// All helpers add their shares locally and set the resulting share to be the /// sum. No communication is required to run the circuit. AddInPrimeField, + /// A test protocol for sharded MPCs. The goal here is to use + /// both shard-to-shard and helper-to-helper communication channels. + /// This is exactly what shuffle does and that's why it is picked + /// for this purpose. 
+ ShardedShuffle, } #[tokio::main] @@ -102,6 +107,7 @@ async fn main() -> Result<(), Box> { match args.action { TestAction::Multiply => multiply(&args, &clients).await, TestAction::AddInPrimeField => add(&args, &clients).await, + TestAction::ShardedShuffle => sharded_shuffle(&args, &clients).await, }; Ok(()) @@ -159,3 +165,7 @@ async fn add(args: &Args, helper_clients: &[MpcHelperClient; 3]) { FieldType::Fp32BitPrime => add_in_field::(args, helper_clients).await, }; } + +async fn sharded_shuffle(_args: &Args, _helper_clients: &[MpcHelperClient; 3]) { + unimplemented!() +} diff --git a/ipa-core/src/helpers/transport/query/mod.rs b/ipa-core/src/helpers/transport/query/mod.rs index f509185a5..3cb655173 100644 --- a/ipa-core/src/helpers/transport/query/mod.rs +++ b/ipa-core/src/helpers/transport/query/mod.rs @@ -198,6 +198,8 @@ pub enum QueryType { TestMultiply, #[cfg(any(test, feature = "test-fixture", feature = "cli"))] TestAddInPrimeField, + #[cfg(any(test, feature = "test-fixture", feature = "cli"))] + TestShardedShuffle, SemiHonestOprfIpa(IpaQueryConfig), MaliciousOprfIpa(IpaQueryConfig), } @@ -206,6 +208,7 @@ impl QueryType { /// TODO: strum pub const TEST_MULTIPLY_STR: &'static str = "test-multiply"; pub const TEST_ADD_STR: &'static str = "test-add"; + pub const TEST_SHARDED_SHUFFLE_STR: &'static str = "test-sharded-shuffle"; pub const SEMI_HONEST_OPRF_IPA_STR: &'static str = "semi-honest-oprf-ipa"; pub const MALICIOUS_OPRF_IPA_STR: &'static str = "malicious-oprf-ipa"; } @@ -218,6 +221,8 @@ impl AsRef for QueryType { QueryType::TestMultiply => Self::TEST_MULTIPLY_STR, #[cfg(any(test, feature = "cli", feature = "test-fixture"))] QueryType::TestAddInPrimeField => Self::TEST_ADD_STR, + #[cfg(any(test, feature = "cli", feature = "test-fixture"))] + QueryType::TestShardedShuffle => Self::TEST_SHARDED_SHUFFLE_STR, QueryType::SemiHonestOprfIpa(_) => Self::SEMI_HONEST_OPRF_IPA_STR, QueryType::MaliciousOprfIpa(_) => Self::MALICIOUS_OPRF_IPA_STR, } diff --git a/ipa-core/src/net/http_serde.rs b/ipa-core/src/net/http_serde.rs index 687fd2f19..cef850ae8 100644 --- a/ipa-core/src/net/http_serde.rs +++ b/ipa-core/src/net/http_serde.rs @@ -152,6 +152,7 @@ pub mod query { match self.query_type { #[cfg(any(test, feature = "test-fixture", feature = "cli"))] QueryType::TestMultiply | QueryType::TestAddInPrimeField => Ok(()), + QueryType::TestShardedShuffle => Ok(()), QueryType::SemiHonestOprfIpa(config) | QueryType::MaliciousOprfIpa(config) => { write!( f, diff --git a/ipa-core/src/query/executor.rs b/ipa-core/src/query/executor.rs index c24ab8b5f..a3e7b866d 100644 --- a/ipa-core/src/query/executor.rs +++ b/ipa-core/src/query/executor.rs @@ -94,6 +94,12 @@ pub fn execute( Box::pin(execute_test_multiply::(prss, gateway, input)) }) } + #[cfg(any(test, feature = "cli", feature = "test-fixture"))] + (QueryType::TestShardedShuffle, _) => { + do_query(config, gateway, input, |_prss, _gateway, _config, _input| { + unimplemented!() + }) + } #[cfg(any(test, feature = "weak-field"))] (QueryType::TestAddInPrimeField, FieldType::Fp31) => { do_query(config, gateway, input, |prss, gateway, _config, input| { From 84605cb2dc8b4350ab90f5d99d394e65604a633d Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Wed, 18 Sep 2024 13:20:23 -0700 Subject: [PATCH 025/191] New in the clear implementation for Hybrid --- ipa-core/src/test_fixture/hybrid.rs | 127 ++++++++++++++++++++++++++++ ipa-core/src/test_fixture/mod.rs | 1 + 2 files changed, 128 insertions(+) create mode 100644 ipa-core/src/test_fixture/hybrid.rs diff --git 
a/ipa-core/src/test_fixture/hybrid.rs b/ipa-core/src/test_fixture/hybrid.rs
new file mode 100644
index 000000000..7c1cb7d73
--- /dev/null
+++ b/ipa-core/src/test_fixture/hybrid.rs
@@ -0,0 +1,127 @@
+use std::collections::{HashMap, HashSet};
+
+#[derive(Debug, Clone, PartialEq, PartialOrd, Eq)]
+pub enum TestHybridRecord {
+    TestImpression { match_key: u64, breakdown_key: u32 },
+    TestConversion { match_key: u64, value: u32 },
+}
+
+pub fn hybrid_in_the_clear(input_rows: &[TestHybridRecord], max_breakdown: usize) -> Vec<u32> {
+    let mut conversion_match_keys = HashSet::<u64>::new();
+    let mut impression_match_keys = HashSet::<u64>::new();
+
+    for input in input_rows {
+        match input {
+            TestHybridRecord::TestImpression { match_key, .. } => {
+                impression_match_keys.insert(*match_key);
+            }
+            TestHybridRecord::TestConversion { match_key, .. } => {
+                conversion_match_keys.insert(*match_key);
+            }
+        }
+    }
+
+    let mut attributed_conversions = HashMap::<u64, (u32, u32)>::new();
+
+    for input in input_rows {
+        match input {
+            TestHybridRecord::TestImpression {
+                match_key,
+                breakdown_key,
+            } => {
+                if let Some(_) = conversion_match_keys.get(match_key) {
+                    attributed_conversions
+                        .entry(*match_key)
+                        .and_modify(|e| e.0 = *breakdown_key)
+                        .or_insert((*breakdown_key, 0));
+                }
+            }
+            TestHybridRecord::TestConversion { match_key, value } => {
+                if let Some(_) = impression_match_keys.get(match_key) {
+                    attributed_conversions
+                        .entry(*match_key)
+                        .and_modify(|e| e.1 += value)
+                        .or_insert((0, *value));
+                }
+            }
+        }
+    }
+
+    let mut output = vec![0; max_breakdown];
+    for (_, (breakdown_key, value)) in attributed_conversions {
+        output[usize::try_from(breakdown_key).unwrap()] += value;
+    }
+
+    return output;
+}
+
+#[cfg(all(test, unit_test))]
+mod tests {
+    use super::TestHybridRecord;
+    use crate::test_fixture::hybrid::hybrid_in_the_clear;
+
+    #[test]
+    fn basic() {
+        let test_data = vec![
+            TestHybridRecord::TestImpression {
+                match_key: 12345,
+                breakdown_key: 2,
+            },
+            TestHybridRecord::TestImpression {
+                match_key: 23456,
+                breakdown_key: 4,
+            },
+            TestHybridRecord::TestConversion {
+                match_key: 23456,
+                value: 25,
+            }, // attributed
+            TestHybridRecord::TestImpression {
+                match_key: 34567,
+                breakdown_key: 1,
+            },
+            TestHybridRecord::TestImpression {
+                match_key: 45678,
+                breakdown_key: 3,
+            },
+            TestHybridRecord::TestConversion {
+                match_key: 45678,
+                value: 13,
+            }, // attributed
+            TestHybridRecord::TestImpression {
+                match_key: 56789,
+                breakdown_key: 5,
+            },
+            TestHybridRecord::TestConversion {
+                match_key: 67890,
+                value: 14,
+            }, // NOT attributed
+            TestHybridRecord::TestImpression {
+                match_key: 78901,
+                breakdown_key: 2,
+            },
+            TestHybridRecord::TestConversion {
+                match_key: 78901,
+                value: 12,
+            }, // attributed
+            TestHybridRecord::TestConversion {
+                match_key: 78901,
+                value: 31,
+            }, // attributed
+            TestHybridRecord::TestImpression {
+                match_key: 89012,
+                breakdown_key: 4,
+            },
+            TestHybridRecord::TestConversion {
+                match_key: 89012,
+                value: 8,
+            }, // attributed
+        ];
+        let expected = vec![
+            0, 0, 43, // 12 + 31
+            13, 33, // 25 + 8
+            0,
+        ];
+        let result = hybrid_in_the_clear(&test_data, 6);
+        assert_eq!(result, expected);
+    }
+}
diff --git a/ipa-core/src/test_fixture/mod.rs b/ipa-core/src/test_fixture/mod.rs
index 999f327d6..f49e739c8 100644
--- a/ipa-core/src/test_fixture/mod.rs
+++ b/ipa-core/src/test_fixture/mod.rs
@@ -12,6 +12,7 @@ mod app;
 pub mod circuit;
 mod event_gen;
 pub mod ipa;
+pub mod hybrid;
 pub mod logging;
 pub mod metrics;
 pub(crate) mod step;

From 4b83d0b81851d2f77e48112a2b1d749198dbd341 Mon Sep 17 00:00:00 2001
From: Ben Savage
Date: Wed, 18 Sep 2024 13:50:45 -0700
Subject: [PATCH 026/191] A few cleanups and fixes

---
 ipa-core/src/test_fixture/hybrid.rs | 45 ++++++++++++++++++++++-------
 1 file changed, 34 insertions(+), 11 deletions(-)

diff --git a/ipa-core/src/test_fixture/hybrid.rs b/ipa-core/src/test_fixture/hybrid.rs
index 7c1cb7d73..478b624c9 100644
--- a/ipa-core/src/test_fixture/hybrid.rs
+++ b/ipa-core/src/test_fixture/hybrid.rs
@@ -6,6 +6,23 @@ pub enum TestHybridRecord {
     TestConversion { match_key: u64, value: u32 },
 }
 
+struct HashmapEntry {
+    breakdown_key: u32,
+    total_value: u32,
+}
+
+impl HashmapEntry {
+    pub fn new(breakdown_key: u32, value: u32) -> Self {
+        Self {
+            breakdown_key,
+            total_value: value,
+        }
+    }
+}
+
+/// # Panics
+/// It won't, so long as you can convert a u32 to a usize
+#[must_use]
 pub fn hybrid_in_the_clear(input_rows: &[TestHybridRecord], max_breakdown: usize) -> Vec<u32> {
-    let mut conversion_match_keys = HashSet::<u64>::new();
-    let mut impression_match_keys = HashSet::<u64>::new();
+    let mut conversion_match_keys = HashSet::new();
+    let mut impression_match_keys = HashSet::new();
 
     for input in input_rows {
         match input {
@@ -21,7 +38,8 @@
         }
     }
 
-    let mut attributed_conversions = HashMap::<u64, (u32, u32)>::new();
+    // The key is the "match key" and the value stores both the breakdown and total attributed value
+    let mut attributed_conversions = HashMap::<u64, HashmapEntry>::new();
 
     for input in input_rows {
         match input {
@@ -29,40 +47,42 @@
                 match_key,
                 breakdown_key,
             } => {
-                if let Some(_) = conversion_match_keys.get(match_key) {
+                if conversion_match_keys.contains(match_key) {
                     attributed_conversions
                         .entry(*match_key)
-                        .and_modify(|e| e.0 = *breakdown_key)
-                        .or_insert((*breakdown_key, 0));
+                        .and_modify(|e| e.breakdown_key = *breakdown_key)
+                        .or_insert(HashmapEntry::new(*breakdown_key, 0));
                 }
             }
             TestHybridRecord::TestConversion { match_key, value } => {
-                if let Some(_) = impression_match_keys.get(match_key) {
+                if impression_match_keys.contains(match_key) {
                     attributed_conversions
                         .entry(*match_key)
-                        .and_modify(|e| e.1 += value)
-                        .or_insert((0, *value));
+                        .and_modify(|e| e.total_value += value)
+                        .or_insert(HashmapEntry::new(0, *value));
                 }
             }
         }
     }
 
     let mut output = vec![0; max_breakdown];
-    for (_, (breakdown_key, value)) in attributed_conversions {
-        output[usize::try_from(breakdown_key).unwrap()] += value;
+    for (_, entry) in attributed_conversions {
+        output[usize::try_from(entry.breakdown_key).unwrap()] += entry.total_value;
     }
 
-    return output;
+    output
 }
 
 #[cfg(all(test, unit_test))]
 mod tests {
+    use rand::{seq::SliceRandom, thread_rng};
+
     use super::TestHybridRecord;
     use crate::test_fixture::hybrid::hybrid_in_the_clear;
 
     #[test]
     fn basic() {
-        let test_data = vec![
+        let mut test_data = vec![
             TestHybridRecord::TestImpression {
                 match_key: 12345,
                 breakdown_key: 2,
             },
@@ -116,6 +136,9 @@
                 value: 8,
             }, // attributed
         ];
+
+        let mut rng = thread_rng();
+        test_data.shuffle(&mut rng);
         let expected = vec![
             0, 0, 43, // 12 + 31
             13, 33, // 25 + 8

From afbce30d165942139b263e75a43e201f8ee19909 Mon Sep 17 00:00:00 2001
From: Ben Savage
Date: Wed, 18 Sep 2024 14:03:33 -0700
Subject: [PATCH 027/191] formatting

---
 ipa-core/src/test_fixture/mod.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ipa-core/src/test_fixture/mod.rs b/ipa-core/src/test_fixture/mod.rs
index f49e739c8..2ae05994f 100644
--- a/ipa-core/src/test_fixture/mod.rs
+++ b/ipa-core/src/test_fixture/mod.rs
@@ -11,8 +11,8 @@ mod app;
 #[cfg(feature = "in-memory-infra")]
 pub mod circuit;
 mod
event_gen; -pub mod ipa; pub mod hybrid; +pub mod ipa; pub mod logging; pub mod metrics; pub(crate) mod step; From 1b123e0fc8d61b6c1b93b5031bda6b264db08abb Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Wed, 18 Sep 2024 13:56:13 -0700 Subject: [PATCH 028/191] Improve error messages --- .../src/protocol/ipa_prf/prf_sharding/mod.rs | 2 +- ipa-step-derive/src/lib.rs | 37 +++++++++---- ipa-step-derive/src/variant.rs | 54 +++++++++++-------- ipa-step-test/src/lib.rs | 2 +- 4 files changed, 59 insertions(+), 36 deletions(-) diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs index 66e3a89e6..c80581ae4 100644 --- a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs @@ -1166,7 +1166,7 @@ pub mod tests { } #[test] - #[should_panic(expected = "v < usize::try_from(64usize).unwrap()")] + #[should_panic(expected = "Step index 64 out of bounds for UserNthRowStep with count 64.")] fn attribution_too_many_records_per_user() { run(|| async move { let world = TestWorld::default(); diff --git a/ipa-step-derive/src/lib.rs b/ipa-step-derive/src/lib.rs index e86ddb934..ea5d19268 100644 --- a/ipa-step-derive/src/lib.rs +++ b/ipa-step-derive/src/lib.rs @@ -118,7 +118,7 @@ fn derive_step_impl(ast: &DeriveInput) -> Result { let mut g = Generator::default(); let attr = match &ast.data { Data::Enum(data) => { - for v in VariantAttribute::parse_variants(data)? { + for v in VariantAttribute::parse_variants(ident, data)? { g.add_variant(&v); } VariantAttribute::parse_outer(ident, &ast.attrs, None)? @@ -404,7 +404,10 @@ mod test { impl ManyArms { pub fn arm(v: u8) -> Self { - assert!(v < u8::try_from(3usize).unwrap()); + assert!( + v < u8::try_from(3usize).unwrap(), + "Step index {v} out of bounds for ManyArms::Arm with count 3.", + ); Self::Arm(v) } } @@ -432,7 +435,7 @@ mod test { fn base_index (& self) -> ::ipa_step::CompactGateIndex { match self { Self::Arm (i) if *i < u8::try_from(3usize).unwrap() => ::ipa_step::CompactGateIndex::try_from(*i).unwrap(), - _ => panic!("Index out of range in ManyArms. Consider using bounds-checked step constructors."), + Self::Arm (i) => panic!("Step index {i} out of bounds for ManyArms::Arm with count 3. Consider using bounds-checked step constructors."), } } fn step_string(i: ::ipa_step::CompactGateIndex) -> String { @@ -461,7 +464,10 @@ mod test { impl ManyArms { pub fn arm(v: u8) -> Self { - assert!(v < u8::try_from(3usize).unwrap()); + assert!( + v < u8::try_from(3usize).unwrap(), + "Step index {v} out of bounds for ManyArms::Arm with count 3.", + ); Self::Arm(v) } } @@ -489,7 +495,7 @@ mod test { fn base_index (& self) -> ::ipa_step::CompactGateIndex { match self { Self::Arm (i) if *i < u8::try_from(3usize).unwrap() => ::ipa_step::CompactGateIndex::try_from(*i).unwrap(), - _ => panic!("Index out of range in ManyArms. Consider using bounds-checked step constructors."), + Self::Arm (i) => panic!("Step index {i} out of bounds for ManyArms::Arm with count 3. 
Consider using bounds-checked step constructors."), } } fn step_string(i: ::ipa_step::CompactGateIndex) -> String { @@ -660,7 +666,10 @@ mod test { impl Parent { pub fn offspring(v: u8) -> Self { - assert!(v < u8::try_from(5usize).unwrap()); + assert!( + v < u8::try_from(5usize).unwrap(), + "Step index {v} out of bounds for Parent::Offspring with count 5.", + ); Self::Offspring(v) } } @@ -690,7 +699,7 @@ mod test { fn base_index(&self) -> ::ipa_step::CompactGateIndex { match self { Self::Offspring(i) if *i < u8::try_from(5usize).unwrap() => (::STEP_COUNT + 1) * ::ipa_step::CompactGateIndex::try_from(*i).unwrap(), - _ => panic!("Index out of range in Parent. Consider using bounds-checked step constructors."), + Self::Offspring(i) => panic!("Step index {i} out of bounds for Parent::Offspring with count 5. Consider using bounds-checked step constructors."), } } fn step_string(i: ::ipa_step::CompactGateIndex) -> String { @@ -751,7 +760,10 @@ mod test { impl AllArms { pub fn int(v: usize) -> Self { - assert!(v < usize::try_from(3usize).unwrap()); + assert!( + v < usize::try_from(3usize).unwrap(), + "Step index {v} out of bounds for AllArms::Int with count 3.", + ); Self::Int(v) } } @@ -783,9 +795,9 @@ mod test { match self { Self::Empty => 0, Self::Int(i) if *i < usize::try_from(3usize).unwrap() => ::ipa_step::CompactGateIndex::try_from(*i).unwrap() + 1, + Self::Int(i) => panic!("Step index {i} out of bounds for AllArms::Int with count 3. Consider using bounds-checked step constructors."), Self::Child => 4, Self::Final => <::some::other::StepEnum as ::ipa_step::CompactStep>::STEP_COUNT + 5, - _ => panic!("Index out of range in AllArms. Consider using bounds-checked step constructors."), } } fn step_string(i: ::ipa_step::CompactGateIndex) -> String { @@ -898,7 +910,10 @@ mod test { impl From for StructInt { fn from(v: u8) -> Self { - assert!(v < u8::try_from(3usize).unwrap()); + assert!( + v < u8::try_from(3usize).unwrap(), + "Step index {v} out of bounds for StructInt with count 3.", + ); Self(v) } } @@ -927,7 +942,7 @@ mod test { fn base_index(&self) -> ::ipa_step::CompactGateIndex { match self { Self(i) if *i < u8::try_from(3usize).unwrap() => ::ipa_step::CompactGateIndex::try_from(*i).unwrap(), - _ => panic!("Index out of range in StructInt. Consider using bounds-checked step constructors."), + Self(i) => panic!("Step index {i} out of bounds for StructInt with count 3. Consider using bounds-checked step constructors."), } } diff --git a/ipa-step-derive/src/variant.rs b/ipa-step-derive/src/variant.rs index 71ceccdc7..792aa8fd1 100644 --- a/ipa-step-derive/src/variant.rs +++ b/ipa-step-derive/src/variant.rs @@ -9,6 +9,7 @@ use syn::{ use crate::{sum::ExtendedSum, IntoSpan}; struct VariantAttrParser<'a> { + full_name: String, ident: &'a Ident, name: Option, count: Option, @@ -17,8 +18,9 @@ struct VariantAttrParser<'a> { } impl<'a> VariantAttrParser<'a> { - fn new(ident: &'a Ident) -> Self { + fn new(full_name: String, ident: &'a Ident) -> Self { Self { + full_name, ident, name: None, count: None, @@ -161,6 +163,7 @@ impl<'a> VariantAttrParser<'a> { ) } else { Ok(VariantAttribute { + full_name: self.full_name, ident: self.ident.clone(), name: self .name @@ -173,6 +176,7 @@ impl<'a> VariantAttrParser<'a> { } pub struct VariantAttribute { + full_name: String, ident: Ident, name: String, integer: Option<(usize, TypePath)>, @@ -188,10 +192,11 @@ impl VariantAttribute { } /// Parse a set of attributes out from a representation of an enum. 
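For context on the `variant.rs` hunk below: it threads a `full_name` of the form `Enum::Variant` through the attribute parser so that the generated guards can name the offending step precisely, as the test expectations above demand. A small, self-contained illustration of the resulting message format follows; the values are hypothetical, only the format strings come from this patch.

```rust
fn main() {
    // As in the hunk below: full_name = format!("{}::{}", enum_ident, v.ident)
    let full_name = format!("{}::{}", "ManyArms", "Arm");
    let (v, count) = (3, 3);
    let msg = format!("Step index {v} out of bounds for {full_name} with count {count}.");
    assert_eq!(msg, "Step index 3 out of bounds for ManyArms::Arm with count 3.");
}
```

Compare this with the old expectation seen earlier in the series, where the panic surfaced only the raw assert condition ("v < usize::try_from(64usize).unwrap()"), which said nothing about which step overflowed.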
- pub fn parse_variants(data: &DataEnum) -> Result, syn::Error> { + pub fn parse_variants(enum_ident: &Ident, data: &DataEnum) -> Result, syn::Error> { let mut steps = Vec::with_capacity(data.variants.len()); for v in &data.variants { - steps.push(VariantAttrParser::new(&v.ident).parse_variant(v)?); + let full_name = format!("{}::{}", enum_ident, v.ident); + steps.push(VariantAttrParser::new(full_name, &v.ident).parse_variant(v)?); } Ok(steps) } @@ -202,7 +207,7 @@ impl VariantAttribute { attrs: &[Attribute], fields: Option<&Fields>, ) -> Result { - VariantAttrParser::new(ident).parse_outer(attrs, fields) + VariantAttrParser::new(ident.to_string(), ident).parse_outer(attrs, fields) } } @@ -256,6 +261,7 @@ impl Generator { fn add_empty(&mut self, v: &VariantAttribute, is_variant: bool) { // Unpack so that we can use `quote!()`. let VariantAttribute { + full_name: _, ident: step_ident, name: step_name, integer: None, @@ -325,6 +331,7 @@ impl Generator { fn add_int(&mut self, v: &VariantAttribute, is_variant: bool) { // Unpack so that we can use `quote!()`. let VariantAttribute { + full_name: step_full_name, ident: step_ident, name: step_name, integer: Some((step_count, step_integer)), @@ -343,9 +350,15 @@ impl Generator { if is_variant { let constructor = format_ident!("{}", step_ident.to_string().to_snake_case()); + let out_of_bounds_msg = format!( + "Step index {{v}} out of bounds for {step_full_name} with count {step_count}." + ); self.int_variant_constructors.extend(quote! { pub fn #constructor(v: #step_integer) -> Self { - assert!(v < #step_integer::try_from(#step_count).unwrap()); + assert!( + v < #step_integer::try_from(#step_count).unwrap(), + #out_of_bounds_msg, + ); Self::#step_ident(v) } }); @@ -374,8 +387,11 @@ impl Generator { if let Some(child) = step_child { let idx = self.arm_count.clone() + quote!((<#child as ::ipa_step::CompactStep>::STEP_COUNT + 1) * ::ipa_step::CompactGateIndex::try_from(*i).unwrap()); + let out_of_bounds_msg = + format!("Step index {{i}} out of bounds for {step_full_name} with count {step_count}. Consider using bounds-checked step constructors."); self.index_arms.extend(quote! { #arm(i) if *i < #step_integer::try_from(#step_count).unwrap() => #idx, + #arm(i) => panic!(#out_of_bounds_msg), }); // With `step_count` variations present, each has a name. @@ -415,8 +431,11 @@ impl Generator { } else { let idx = self.arm_count.clone() + quote!(::ipa_step::CompactGateIndex::try_from(*i).unwrap()); + let out_of_bounds_msg = + format!("Step index {{i}} out of bounds for {step_full_name} with count {step_count}. Consider using bounds-checked step constructors."); self.index_arms.extend(quote! { #arm(i) if *i < #step_integer::try_from(#step_count).unwrap() => #idx, + #arm(i) => panic!(#out_of_bounds_msg), }); let range_end = arm_count.clone() + *step_count; @@ -437,10 +456,15 @@ impl Generator { // Generate a bounds-checking `impl From` if this is an integer unit struct step. if let &Some((count, ref type_path)) = &attr.integer { + let out_of_bounds_msg = + format!("Step index {{v}} out of bounds for {ident} with count {count}."); result.extend(quote! 
{ impl From<#type_path> for #ident { fn from(v: #type_path) -> Self { - assert!(v < #type_path::try_from(#count).unwrap()); + assert!( + v < #type_path::try_from(#count).unwrap(), + #out_of_bounds_msg, + ); Self(v) } } @@ -457,19 +481,6 @@ impl Generator { }); } - let index_arm_wild = if self.name_arrays.is_empty() { - quote!() - } else { - // Note that the current `AsRef` impl indexes into an array of the valid step names, so - // will panic if used here to generate the message. - let panic_msg = format!( - "Index out of range in {ident}. Consider using bounds-checked step constructors.", - ); - quote! { - _ => panic!(#panic_msg), - } - }; - assert_eq!(self.index_arms.is_empty(), self.as_ref_arms.is_empty()); let (index_arms, as_ref_arms) = if self.index_arms.is_empty() { let n = attr.name(); @@ -479,10 +490,7 @@ impl Generator { let as_ref_arms = self.as_ref_arms; ( quote! { - match self { - #index_arms - #index_arm_wild - } + match self { #index_arms } }, quote! { match self { #as_ref_arms } diff --git a/ipa-step-test/src/lib.rs b/ipa-step-test/src/lib.rs index e9294186c..84eab760d 100644 --- a/ipa-step-test/src/lib.rs +++ b/ipa-step-test/src/lib.rs @@ -59,7 +59,7 @@ mod tests { /// (rather than produce an incorrect output gate). #[test] #[should_panic( - expected = "Index out of range in ComplexStep. Consider using bounds-checked step constructors." + expected = "Step index 10 out of bounds for ComplexStep::Two with count 10. Consider using bounds-checked step constructors." )] fn index_out_of_range() { _ = ComplexGate::default().narrow(&ComplexStep::Two(10)); From a26d58e2788cc5184e7779516e377452b7b82f54 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Wed, 18 Sep 2024 15:07:24 -0700 Subject: [PATCH 029/191] Hybrid query params --- .../src/helpers/transport/query/hybrid.rs | 32 ++++++++++++++++ ipa-core/src/helpers/transport/query/mod.rs | 6 +++ ipa-core/src/net/http_serde.rs | 16 ++++++++ ipa-core/src/query/executor.rs | 26 ++++++++++--- ipa-core/src/query/runner/hybrid.rs | 37 +++++++++++++++++++ ipa-core/src/query/runner/mod.rs | 2 + 6 files changed, 113 insertions(+), 6 deletions(-) create mode 100644 ipa-core/src/helpers/transport/query/hybrid.rs create mode 100644 ipa-core/src/query/runner/hybrid.rs diff --git a/ipa-core/src/helpers/transport/query/hybrid.rs b/ipa-core/src/helpers/transport/query/hybrid.rs new file mode 100644 index 000000000..2b6906d28 --- /dev/null +++ b/ipa-core/src/helpers/transport/query/hybrid.rs @@ -0,0 +1,32 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq)] +#[cfg_attr(feature = "clap", derive(clap::Args))] +pub struct HybridQueryParams { + #[cfg_attr(feature = "clap", arg(long, default_value = "8"))] + pub per_user_credit_cap: u32, + #[cfg_attr(feature = "clap", arg(long, default_value = "5"))] + pub max_breakdown_key: u32, + #[cfg_attr(feature = "clap", arg(short = 'd', long, default_value = "1"))] + pub with_dp: u32, + #[cfg_attr(feature = "clap", arg(short = 'e', long, default_value = "5.0"))] + pub epsilon: f64, + #[cfg_attr(feature = "clap", arg(long))] + #[serde(default)] + pub plaintext_match_keys: bool, +} + +#[cfg(test)] +impl Eq for HybridQueryParams {} + +impl Default for HybridQueryParams { + fn default() -> Self { + Self { + per_user_credit_cap: 8, + max_breakdown_key: 20, + with_dp: 1, + epsilon: 0.10, + plaintext_match_keys: false, + } + } +} diff --git a/ipa-core/src/helpers/transport/query/mod.rs b/ipa-core/src/helpers/transport/query/mod.rs index 3cb655173..ac70209b3 100644 
--- a/ipa-core/src/helpers/transport/query/mod.rs +++ b/ipa-core/src/helpers/transport/query/mod.rs @@ -1,8 +1,11 @@ +mod hybrid; + use std::{ fmt::{Debug, Display, Formatter}, num::NonZeroU32, }; +pub use hybrid::HybridQueryParams; use serde::{Deserialize, Deserializer, Serialize}; use crate::{ @@ -202,6 +205,7 @@ pub enum QueryType { TestShardedShuffle, SemiHonestOprfIpa(IpaQueryConfig), MaliciousOprfIpa(IpaQueryConfig), + SemiHonestHybrid(HybridQueryParams), } impl QueryType { @@ -211,6 +215,7 @@ impl QueryType { pub const TEST_SHARDED_SHUFFLE_STR: &'static str = "test-sharded-shuffle"; pub const SEMI_HONEST_OPRF_IPA_STR: &'static str = "semi-honest-oprf-ipa"; pub const MALICIOUS_OPRF_IPA_STR: &'static str = "malicious-oprf-ipa"; + pub const SEMI_HONEST_HYBRID_STR: &'static str = "semi-honest-hybrid"; } /// TODO: should this `AsRef` impl (used for `Substep`) take into account config of IPA? @@ -225,6 +230,7 @@ impl AsRef for QueryType { QueryType::TestShardedShuffle => Self::TEST_SHARDED_SHUFFLE_STR, QueryType::SemiHonestOprfIpa(_) => Self::SEMI_HONEST_OPRF_IPA_STR, QueryType::MaliciousOprfIpa(_) => Self::MALICIOUS_OPRF_IPA_STR, + QueryType::SemiHonestHybrid(_) => Self::SEMI_HONEST_HYBRID_STR, } } } diff --git a/ipa-core/src/net/http_serde.rs b/ipa-core/src/net/http_serde.rs index cef850ae8..98fae4dc0 100644 --- a/ipa-core/src/net/http_serde.rs +++ b/ipa-core/src/net/http_serde.rs @@ -171,6 +171,22 @@ pub mod query { write!(f, "&attribution_window_seconds={}", window.get())?; } + Ok(()) + } + QueryType::SemiHonestHybrid(config) => { + write!( + f, + "&per_user_credit_cap={}&max_breakdown_key={}&with_dp={}&epsilon={}", + config.per_user_credit_cap, + config.max_breakdown_key, + config.with_dp, + config.epsilon, + )?; + + if config.plaintext_match_keys { + write!(f, "&plaintext_match_keys=true")?; + } + Ok(()) } } diff --git a/ipa-core/src/query/executor.rs b/ipa-core/src/query/executor.rs index a3e7b866d..b3e197e4d 100644 --- a/ipa-core/src/query/executor.rs +++ b/ipa-core/src/query/executor.rs @@ -44,7 +44,7 @@ use crate::{ Gate, }, query::{ - runner::{OprfIpaQuery, QueryResult}, + runner::{HybridQuery, OprfIpaQuery, QueryResult}, state::RunningQuery, }, sync::Arc, @@ -95,11 +95,12 @@ pub fn execute( }) } #[cfg(any(test, feature = "cli", feature = "test-fixture"))] - (QueryType::TestShardedShuffle, _) => { - do_query(config, gateway, input, |_prss, _gateway, _config, _input| { - unimplemented!() - }) - } + (QueryType::TestShardedShuffle, _) => do_query( + config, + gateway, + input, + |_prss, _gateway, _config, _input| unimplemented!(), + ), #[cfg(any(test, feature = "weak-field"))] (QueryType::TestAddInPrimeField, FieldType::Fp31) => { do_query(config, gateway, input, |prss, gateway, _config, input| { @@ -144,6 +145,19 @@ pub fn execute( ) }, ), + (QueryType::SemiHonestHybrid(query_params), _) => do_query( + config, + gateway, + input, + move |prss, gateway, config, input| { + let ctx = SemiHonestContext::new(prss, gateway); + Box::pin( + HybridQuery::<_, BA32, R>::new(query_params, key_registry) + .execute(ctx, config.size, input) + .then(|res| ready(res.map(|out| Box::new(out) as Box))), + ) + }, + ), } } diff --git a/ipa-core/src/query/runner/hybrid.rs b/ipa-core/src/query/runner/hybrid.rs new file mode 100644 index 000000000..73abcaf75 --- /dev/null +++ b/ipa-core/src/query/runner/hybrid.rs @@ -0,0 +1,37 @@ +use std::{marker::PhantomData, sync::Arc}; + +use crate::{ + error::Error, + helpers::{ + query::{HybridQueryParams, QuerySize}, + BodyStream, + }, + hpke::PrivateKeyRegistry, 
+ secret_sharing::{replicated::semi_honest::AdditiveShare as ReplicatedShare, SharedValue}, +}; + +pub struct Query { + _config: HybridQueryParams, + _key_registry: Arc, + phantom_data: PhantomData<(C, HV)>, +} + +impl Query { + pub fn new(query_params: HybridQueryParams, key_registry: Arc) -> Self { + Self { + _config: query_params, + _key_registry: key_registry, + phantom_data: PhantomData, + } + } + + #[tracing::instrument("hybrid_query", skip_all, fields(sz=%query_size))] + pub async fn execute( + self, + _ctx: C, + query_size: QuerySize, + _input_stream: BodyStream, + ) -> Result>, Error> { + unimplemented!() + } +} diff --git a/ipa-core/src/query/runner/mod.rs b/ipa-core/src/query/runner/mod.rs index 4c7240cbb..9e5935c20 100644 --- a/ipa-core/src/query/runner/mod.rs +++ b/ipa-core/src/query/runner/mod.rs @@ -1,11 +1,13 @@ #[cfg(any(test, feature = "cli", feature = "test-fixture"))] mod add_in_prime_field; +mod hybrid; mod oprf_ipa; #[cfg(any(test, feature = "cli", feature = "test-fixture"))] mod test_multiply; #[cfg(any(test, feature = "cli", feature = "test-fixture"))] pub(super) use add_in_prime_field::execute as test_add_in_prime_field; +pub use hybrid::Query as HybridQuery; #[cfg(any(test, feature = "cli", feature = "test-fixture"))] pub(super) use test_multiply::execute_test_multiply; From f614c785e36799546bcf1382498e5e41b1a451f9 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Wed, 18 Sep 2024 15:26:58 -0700 Subject: [PATCH 030/191] Small syntax changes --- ipa-core/src/test_fixture/hybrid.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ipa-core/src/test_fixture/hybrid.rs b/ipa-core/src/test_fixture/hybrid.rs index 478b624c9..63ecf73e5 100644 --- a/ipa-core/src/test_fixture/hybrid.rs +++ b/ipa-core/src/test_fixture/hybrid.rs @@ -24,8 +24,8 @@ impl HashmapEntry { /// It won't, so long as you can convert a u32 to a usize #[must_use] pub fn hybrid_in_the_clear(input_rows: &[TestHybridRecord], max_breakdown: usize) -> Vec { - let mut conversion_match_keys = HashSet::::new(); - let mut impression_match_keys = HashSet::::new(); + let mut conversion_match_keys = HashSet::new(); + let mut impression_match_keys = HashSet::new(); for input in input_rows { match input { @@ -39,7 +39,7 @@ pub fn hybrid_in_the_clear(input_rows: &[TestHybridRecord], max_breakdown: usize } // The key is the "match key" and the value stores both the breakdown and total attributed value - let mut attributed_conversions = HashMap::::new(); + let mut attributed_conversions = HashMap::new(); for input in input_rows { match input { @@ -48,10 +48,10 @@ pub fn hybrid_in_the_clear(input_rows: &[TestHybridRecord], max_breakdown: usize breakdown_key, } => { if conversion_match_keys.contains(match_key) { - attributed_conversions + let v = attributed_conversions .entry(*match_key) - .and_modify(|e| e.breakdown_key = *breakdown_key) .or_insert(HashmapEntry::new(*breakdown_key, 0)); + v.breakdown_key = *breakdown_key; } } TestHybridRecord::TestConversion { match_key, value } => { From 0aa1e5032795ac9bac5b5028fe79f94e45672530 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Wed, 18 Sep 2024 15:35:15 -0700 Subject: [PATCH 031/191] Fix compile errors in release --- ipa-core/src/net/http_serde.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/ipa-core/src/net/http_serde.rs b/ipa-core/src/net/http_serde.rs index 98fae4dc0..927ae4a4d 100644 --- a/ipa-core/src/net/http_serde.rs +++ b/ipa-core/src/net/http_serde.rs @@ -152,6 +152,7 @@ pub mod query { match self.query_type { #[cfg(any(test, 
feature = "test-fixture", feature = "cli"))] QueryType::TestMultiply | QueryType::TestAddInPrimeField => Ok(()), + #[cfg(any(test, feature = "test-fixture", feature = "cli"))] QueryType::TestShardedShuffle => Ok(()), QueryType::SemiHonestOprfIpa(config) | QueryType::MaliciousOprfIpa(config) => { write!( From 6125d8cc643acf08da42d951547ed0f4255b2b54 Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Wed, 18 Sep 2024 16:08:32 -0700 Subject: [PATCH 032/191] Creating an event gen script for Hybrid --- ipa-core/src/test_fixture/hybrid_event_gen.rs | 209 ++++++++++++++++++ ipa-core/src/test_fixture/mod.rs | 1 + 2 files changed, 210 insertions(+) create mode 100644 ipa-core/src/test_fixture/hybrid_event_gen.rs diff --git a/ipa-core/src/test_fixture/hybrid_event_gen.rs b/ipa-core/src/test_fixture/hybrid_event_gen.rs new file mode 100644 index 000000000..2cbe57820 --- /dev/null +++ b/ipa-core/src/test_fixture/hybrid_event_gen.rs @@ -0,0 +1,209 @@ +use std::num::{NonZeroU32, NonZeroU64}; + +use rand::Rng; + +use super::hybrid::TestHybridRecord; + +#[derive(Debug, Copy, Clone)] +#[cfg_attr(feature = "clap", derive(clap::ValueEnum))] +pub enum ConversionDistribution { + Default, + LotsOfConversionsPerImpression, + OnlyImpressions, + OnlyConversions, +} + +#[derive(Debug, Clone)] +#[cfg_attr(feature = "clap", derive(clap::Args))] +pub struct Config { + #[cfg_attr(feature = "clap", arg(long, default_value = "1000000000000"))] + pub num_events: NonZeroU64, + #[cfg_attr(feature = "clap", arg(long, default_value = "5"))] + pub max_conversion_value: NonZeroU32, + #[cfg_attr(feature = "clap", arg(long, default_value = "20"))] + pub max_breakdown_key: NonZeroU32, + #[cfg_attr(feature = "clap", arg(long, default_value = "10"))] + pub max_convs_per_imp: NonZeroU32, + /// Indicates the distribution of impression to conversion reports. + #[cfg_attr(feature = "clap", arg(value_enum, long, default_value_t = ConversionDistribution::Default))] + pub conversion_distribution: ConversionDistribution, +} + +impl Default for Config { + fn default() -> Self { + Self::new(1_000, 5, 20, 10) + } +} + +impl Config { + /// Creates a new instance of [`Self`] + /// + /// ## Panics + /// If any argument is 0. + #[must_use] + pub fn new( + num_events: u64, + max_conversion_value: u32, + max_breakdown_key: u32, + max_convs_per_imp: u32, + ) -> Self { + Self { + num_events: NonZeroU64::try_from(num_events).unwrap(), + max_conversion_value: NonZeroU32::try_from(max_conversion_value).unwrap(), + max_breakdown_key: NonZeroU32::try_from(max_breakdown_key).unwrap(), + max_convs_per_imp: NonZeroU32::try_from(max_convs_per_imp).unwrap(), + conversion_distribution: ConversionDistribution::Default, + } + } +} + +pub struct EventGenerator { + config: Config, + rng: R, + in_flight: Vec, +} + +impl EventGenerator { + pub fn with_default_config(rng: R) -> Self { + Self::with_config(rng, Config::default()) + } + + /// # Panics + /// If the configuration is not valid. 
+ pub fn with_config(rng: R, config: Config) -> Self { + let max_capacity = usize::try_from(config.max_convs_per_imp.get() + 1).unwrap(); + Self { + config, + rng, + in_flight: Vec::with_capacity(max_capacity), + } + } + + fn gen_batch(&mut self) { + match self.config.conversion_distribution { + ConversionDistribution::OnlyImpressions => { + let match_key = self.rng.gen::(); + let imp = self.gen_impression(match_key); + self.in_flight.push(imp); + } + ConversionDistribution::OnlyConversions => { + let match_key = self.rng.gen::(); + let conv = self.gen_conversion(match_key); + self.in_flight.push(conv); + } + ConversionDistribution::Default => { + let match_key = self.rng.gen::(); + match self.rng.gen::() { + // 10% chance of unmatched conversion + 0..=25 => { + let conv = self.gen_conversion(match_key); + self.in_flight.push(conv); + } + + // 70% chance of unmatched impression + 26..=206 => { + let imp = self.gen_impression(match_key); + self.in_flight.push(imp); + } + + // 20% chance of impression with at least one conversion + _ => { + let imp = self.gen_impression(match_key); + let conv = self.gen_conversion(match_key); + self.in_flight.push(imp); + self.in_flight.push(conv); + let mut conv_count = 1; + // long-tailed distribution of # of conversions per impression + // 15.6% chance of adding each subsequent conversion + // will not exceed the configured maximum number of conversions per impression + while conv_count < self.config.max_convs_per_imp.get() + && self.rng.gen::() < 40 + { + let conv = self.gen_conversion(match_key); + self.in_flight.push(conv); + conv_count += 1; + } + } + } + } + ConversionDistribution::LotsOfConversionsPerImpression => { + let match_key = self.rng.gen::(); + match self.rng.gen::() { + // 40% chance of unmatched conversion + 0..=102 => { + let conv = self.gen_conversion(match_key); + self.in_flight.push(conv); + } + + // 30% chance of unmatched impression + 103..=180 => { + let imp = self.gen_impression(match_key); + self.in_flight.push(imp); + } + + // 30% chance of impression with at least one conversion + _ => { + let imp = self.gen_impression(match_key); + let conv = self.gen_conversion(match_key); + self.in_flight.push(imp); + self.in_flight.push(conv); + let mut conv_count = 1; + // long-tailed distribution of # of conversions per impression + // 80% chance of adding each subsequent conversion + // will not exceed the configured maximum number of conversions per impression + while conv_count < self.config.max_convs_per_imp.get() + && self.rng.gen::() < 205 + { + let conv = self.gen_conversion(match_key); + self.in_flight.push(conv); + conv_count += 1; + } + } + } + } + } + } + + fn gen_conversion(&mut self, match_key: u64) -> TestHybridRecord { + TestHybridRecord::TestConversion { + match_key, + value: self + .rng + .gen_range(1..self.config.max_conversion_value.get()), + } + } + + fn gen_impression(&mut self, match_key: u64) -> TestHybridRecord { + TestHybridRecord::TestImpression { + match_key, + breakdown_key: self.rng.gen_range(0..self.config.max_breakdown_key.get()), + } + } +} + +impl Iterator for EventGenerator { + type Item = TestHybridRecord; + + fn next(&mut self) -> Option { + if self.in_flight.is_empty() { + self.gen_batch(); + } + self.in_flight.pop() + } +} + +#[cfg(all(test, unit_test))] +mod tests { + use rand::thread_rng; + + use super::*; + + #[test] + fn iter() { + let gen = EventGenerator::with_default_config(thread_rng()); + assert_eq!(10, gen.take(10).collect::>().len()); + + let gen = 
EventGenerator::with_default_config(thread_rng()); + assert_eq!(59, gen.take(59).collect::>().len()); + } +} diff --git a/ipa-core/src/test_fixture/mod.rs b/ipa-core/src/test_fixture/mod.rs index 2ae05994f..54df493ee 100644 --- a/ipa-core/src/test_fixture/mod.rs +++ b/ipa-core/src/test_fixture/mod.rs @@ -12,6 +12,7 @@ mod app; pub mod circuit; mod event_gen; pub mod hybrid; +mod hybrid_event_gen; pub mod ipa; pub mod logging; pub mod metrics; From cdc423077c03ec83f5f057734707cb81f58a76b3 Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Wed, 18 Sep 2024 16:31:45 -0700 Subject: [PATCH 033/191] refactoring --- ipa-core/src/test_fixture/hybrid_event_gen.rs | 111 ++++++------------ 1 file changed, 39 insertions(+), 72 deletions(-) diff --git a/ipa-core/src/test_fixture/hybrid_event_gen.rs b/ipa-core/src/test_fixture/hybrid_event_gen.rs index 2cbe57820..73dd7a898 100644 --- a/ipa-core/src/test_fixture/hybrid_event_gen.rs +++ b/ipa-core/src/test_fixture/hybrid_event_gen.rs @@ -64,12 +64,14 @@ pub struct EventGenerator { } impl EventGenerator { + #[must_use] pub fn with_default_config(rng: R) -> Self { Self::with_config(rng, Config::default()) } /// # Panics /// If the configuration is not valid. + #[must_use] pub fn with_config(rng: R, config: Config) -> Self { let max_capacity = usize::try_from(config.max_convs_per_imp.get() + 1).unwrap(); Self { @@ -82,84 +84,49 @@ impl EventGenerator { fn gen_batch(&mut self) { match self.config.conversion_distribution { ConversionDistribution::OnlyImpressions => { - let match_key = self.rng.gen::(); - let imp = self.gen_impression(match_key); - self.in_flight.push(imp); + self.gen_batch_with_params(0.0, 1.0, 0.0); } ConversionDistribution::OnlyConversions => { - let match_key = self.rng.gen::(); - let conv = self.gen_conversion(match_key); - self.in_flight.push(conv); + self.gen_batch_with_params(1.0, 0.0, 0.0); } ConversionDistribution::Default => { - let match_key = self.rng.gen::(); - match self.rng.gen::() { - // 10% chance of unmatched conversion - 0..=25 => { - let conv = self.gen_conversion(match_key); - self.in_flight.push(conv); - } - - // 70% chance of unmatched impression - 26..=206 => { - let imp = self.gen_impression(match_key); - self.in_flight.push(imp); - } - - // 20% chance of impression with at least one conversion - _ => { - let imp = self.gen_impression(match_key); - let conv = self.gen_conversion(match_key); - self.in_flight.push(imp); - self.in_flight.push(conv); - let mut conv_count = 1; - // long-tailed distribution of # of conversions per impression - // 15.6% chance of adding each subsequent conversion - // will not exceed the configured maximum number of conversions per impression - while conv_count < self.config.max_convs_per_imp.get() - && self.rng.gen::() < 40 - { - let conv = self.gen_conversion(match_key); - self.in_flight.push(conv); - conv_count += 1; - } - } - } + self.gen_batch_with_params(0.1, 0.7, 0.15); } ConversionDistribution::LotsOfConversionsPerImpression => { - let match_key = self.rng.gen::(); - match self.rng.gen::() { - // 40% chance of unmatched conversion - 0..=102 => { - let conv = self.gen_conversion(match_key); - self.in_flight.push(conv); - } - - // 30% chance of unmatched impression - 103..=180 => { - let imp = self.gen_impression(match_key); - self.in_flight.push(imp); - } - - // 30% chance of impression with at least one conversion - _ => { - let imp = self.gen_impression(match_key); - let conv = self.gen_conversion(match_key); - self.in_flight.push(imp); - self.in_flight.push(conv); - let mut conv_count 
= 1; - // long-tailed distribution of # of conversions per impression - // 80% chance of adding each subsequent conversion - // will not exceed the configured maximum number of conversions per impression - while conv_count < self.config.max_convs_per_imp.get() - && self.rng.gen::() < 205 - { - let conv = self.gen_conversion(match_key); - self.in_flight.push(conv); - conv_count += 1; - } - } - } + self.gen_batch_with_params(0.3, 0.4, 0.8); + } + } + } + + fn gen_batch_with_params( + &mut self, + unmatched_conversions: f32, + unmatched_impressions: f32, + subsequent_conversion_prob: f32, + ) { + assert!(unmatched_conversions + unmatched_impressions <= 1.0); + let match_key = self.rng.gen::(); + let rand = self.rng.gen::(); + if rand < unmatched_conversions { + let conv = self.gen_conversion(match_key); + self.in_flight.push(conv); + } else if rand < unmatched_conversions + unmatched_impressions { + let imp = self.gen_impression(match_key); + self.in_flight.push(imp); + } else { + let imp = self.gen_impression(match_key); + let conv = self.gen_conversion(match_key); + self.in_flight.push(imp); + self.in_flight.push(conv); + let mut conv_count = 1; + // long-tailed distribution of # of conversions per impression + // will not exceed the configured maximum number of conversions per impression + while conv_count < self.config.max_convs_per_imp.get() + && self.rng.gen::() < subsequent_conversion_prob + { + let conv = self.gen_conversion(match_key); + self.in_flight.push(conv); + conv_count += 1; } } } From 1374efdef6d32871e3bf857131cd70e5ec30aa95 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Wed, 18 Sep 2024 22:41:33 -0700 Subject: [PATCH 034/191] Add a test that reproduces saturation panic It required quite a bit of plumbing to make the compact gate stuff work --- ipa-core/src/lib.rs | 3 +- ipa-core/src/protocol/basics/mod.rs | 5 + ipa-core/src/protocol/ipa_prf/mod.rs | 106 ++++++++++++++++++ .../src/protocol/ipa_prf/prf_sharding/mod.rs | 4 +- ipa-core/src/protocol/step.rs | 8 +- ipa-core/src/secret_sharing/into_shares.rs | 2 +- .../src/secret_sharing/vector/transpose.rs | 2 + ipa-core/src/test_fixture/circuit.rs | 2 +- ipa-core/src/test_fixture/mod.rs | 2 +- ipa-core/src/test_fixture/step.rs | 2 +- 10 files changed, 126 insertions(+), 10 deletions(-) diff --git a/ipa-core/src/lib.rs b/ipa-core/src/lib.rs index 8c04fac7a..a4f625a4c 100644 --- a/ipa-core/src/lib.rs +++ b/ipa-core/src/lib.rs @@ -118,7 +118,7 @@ pub(crate) mod test_executor { } } -#[cfg(all(test, unit_test, not(feature = "shuttle")))] +#[cfg(all(test, not(feature = "shuttle")))] pub(crate) mod test_executor { use std::future::Future; @@ -137,6 +137,7 @@ pub(crate) mod test_executor { .block_on(f()) } + #[allow(dead_code)] pub fn run(f: F) -> T where F: Fn() -> Fut + Send + Sync + 'static, diff --git a/ipa-core/src/protocol/basics/mod.rs b/ipa-core/src/protocol/basics/mod.rs index d0315b8cf..278aded76 100644 --- a/ipa-core/src/protocol/basics/mod.rs +++ b/ipa-core/src/protocol/basics/mod.rs @@ -149,6 +149,11 @@ impl<'a, B: ShardBinding> BooleanProtocols BooleanProtocols, 3> + for AdditiveShare +{ +} + impl<'a, B: ShardBinding> BooleanProtocols, 32> for AdditiveShare { diff --git a/ipa-core/src/protocol/ipa_prf/mod.rs b/ipa-core/src/protocol/ipa_prf/mod.rs index 55d8e7414..a14da50fe 100644 --- a/ipa-core/src/protocol/ipa_prf/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/mod.rs @@ -721,3 +721,109 @@ pub mod tests { }); } } + +#[cfg(all(test, all(feature = "compact-gate", feature = "in-memory-infra")))] +mod compact_gate_tests { + + 
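    // What the test below exercises (annotation, not part of the diff): one
    // user contributes three trigger reports of 255 each, so its attributed
    // total (765) no longer fits in the 8-bit histogram value and saturating
    // addition has to clamp it to 255 -- hence the 255 entries in EXPECTED.
    // Under compact gates this addition used to panic; the following commit
    // fixes it by switching `SaturatedAdditionStep` to a wider bit step.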
use ipa_step::StepNarrow; + + use crate::{ + ff::{ + boolean_array::{BA20, BA5, BA8}, + U128Conversions, + }, + helpers::query::DpMechanism, + protocol::{ + ipa_prf::{oprf_ipa, oprf_padding::PaddingParameters}, + step::{ProtocolGate, ProtocolStep}, + }, + test_executor::run, + test_fixture::{ipa::TestRawDataRecord, Reconstruct, Runner, TestWorld, TestWorldConfig}, + }; + + #[test] + fn saturated_agg() { + const EXPECTED: &[u128] = &[0, 255, 255, 0, 0, 0, 0, 0]; + + run(|| async { + let world = TestWorld::new_with(TestWorldConfig { + initial_gate: Some(ProtocolGate::default().narrow(&ProtocolStep::IpaPrf)), + ..Default::default() + }); + + let records: Vec = vec![ + TestRawDataRecord { + timestamp: 0, + user_id: 12345, + is_trigger_report: false, + breakdown_key: 1, + trigger_value: 0, + }, + TestRawDataRecord { + timestamp: 5, + user_id: 12345, + is_trigger_report: false, + breakdown_key: 2, + trigger_value: 0, + }, + TestRawDataRecord { + timestamp: 10, + user_id: 12345, + is_trigger_report: true, + breakdown_key: 0, + trigger_value: 255, + }, + TestRawDataRecord { + timestamp: 20, + user_id: 12345, + is_trigger_report: true, + breakdown_key: 0, + trigger_value: 255, + }, + TestRawDataRecord { + timestamp: 30, + user_id: 12345, + is_trigger_report: true, + breakdown_key: 0, + trigger_value: 255, + }, + TestRawDataRecord { + timestamp: 0, + user_id: 68362, + is_trigger_report: false, + breakdown_key: 1, + trigger_value: 0, + }, + TestRawDataRecord { + timestamp: 20, + user_id: 68362, + is_trigger_report: true, + breakdown_key: 1, + trigger_value: 255, + }, + ]; + let dp_params = DpMechanism::NoDp; + let padding_params = PaddingParameters::relaxed(); + + let mut result: Vec<_> = world + .semi_honest(records.into_iter(), |ctx, input_rows| async move { + oprf_ipa::<_, BA5, BA8, BA8, BA20, 5, 32>( + ctx, + input_rows, + None, + dp_params, + padding_params, + ) + .await + .unwrap() + }) + .await + .reconstruct(); + result.truncate(EXPECTED.len()); + assert_eq!( + result.iter().map(|&v| v.as_u128()).collect::>(), + EXPECTED, + ); + }); + } +} diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs index c80581ae4..16e9fea14 100644 --- a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs @@ -320,14 +320,14 @@ pub struct AttributionOutputs { pub type SecretSharedAttributionOutputs = AttributionOutputs, Replicated>; -#[cfg(all(test, any(unit_test, feature = "shuttle")))] +#[cfg(test)] #[derive(Debug, Clone, Ord, PartialEq, PartialOrd, Eq)] pub struct AttributionOutputsTestInput { pub bk: BK, pub tv: TV, } -#[cfg(all(test, any(unit_test, feature = "shuttle")))] +#[cfg(test)] impl crate::secret_sharing::IntoShares<(Replicated, Replicated)> for AttributionOutputsTestInput where diff --git a/ipa-core/src/protocol/step.rs b/ipa-core/src/protocol/step.rs index 604052222..a9125c84f 100644 --- a/ipa-core/src/protocol/step.rs +++ b/ipa-core/src/protocol/step.rs @@ -9,9 +9,11 @@ pub enum ProtocolStep { IpaPrf, Multiply, PrimeFieldAddition, - #[cfg(any(test, feature = "test-fixture"))] - #[step(count = 10, child = crate::test_fixture::step::TestExecutionStep)] - Test(usize), + /// Steps used in unit tests are grouped under this one. 
Ideally it should be + /// gated behind test configuration, but it does not work with build.rs that + /// does not enable any features when creating protocol gate file + #[step(child = crate::test_fixture::step::TestExecutionStep)] + Test, /// This step includes all the steps that are currently not linked into a top-level protocol. /// diff --git a/ipa-core/src/secret_sharing/into_shares.rs b/ipa-core/src/secret_sharing/into_shares.rs index ddd83ec3c..a7bde0764 100644 --- a/ipa-core/src/secret_sharing/into_shares.rs +++ b/ipa-core/src/secret_sharing/into_shares.rs @@ -41,7 +41,7 @@ where } } -#[cfg(all(test, unit_test))] +#[cfg(test)] impl IntoShares> for Result where U: IntoShares, diff --git a/ipa-core/src/secret_sharing/vector/transpose.rs b/ipa-core/src/secret_sharing/vector/transpose.rs index a9afa8492..65bbacb53 100644 --- a/ipa-core/src/secret_sharing/vector/transpose.rs +++ b/ipa-core/src/secret_sharing/vector/transpose.rs @@ -614,6 +614,8 @@ impl_transpose_shares_bool_to_ba!(BA32, 32, 256, test_transpose_shares_bool_to_b impl_transpose_shares_bool_to_ba_small!(BA8, 8, 32, test_transpose_shares_bool_to_ba_8x32); // added to support HV = BA32 to hold results when adding Binomial noise impl_transpose_shares_bool_to_ba_small!(BA32, 32, 32, test_transpose_shares_bool_to_ba_32x32); +// // Usage: IPA test case for saturated additions +// impl_transpose_shares_bool_to_ba_small!(BA3, 3, 32, test_transpose_shares_bool_to_ba_3x32); // Usage: Aggregation output tests impl_transpose_shares_bool_to_ba_small!(BA8, 8, 8, test_transpose_shares_bool_to_ba_8x8); diff --git a/ipa-core/src/test_fixture/circuit.rs b/ipa-core/src/test_fixture/circuit.rs index 2e353d3c3..03d7aeff1 100644 --- a/ipa-core/src/test_fixture/circuit.rs +++ b/ipa-core/src/test_fixture/circuit.rs @@ -82,7 +82,7 @@ pub async fn arithmetic( active, ..Default::default() }, - initial_gate: Some(Gate::default().narrow(&ProtocolStep::Test(0))), + initial_gate: Some(Gate::default().narrow(&ProtocolStep::Test)), ..Default::default() }; let world = TestWorld::new_with(config); diff --git a/ipa-core/src/test_fixture/mod.rs b/ipa-core/src/test_fixture/mod.rs index 2ae05994f..c5ebc446a 100644 --- a/ipa-core/src/test_fixture/mod.rs +++ b/ipa-core/src/test_fixture/mod.rs @@ -15,7 +15,7 @@ pub mod hybrid; pub mod ipa; pub mod logging; pub mod metrics; -pub(crate) mod step; +pub mod step; #[cfg(feature = "in-memory-infra")] mod test_gate; diff --git a/ipa-core/src/test_fixture/step.rs b/ipa-core/src/test_fixture/step.rs index a0881c2a6..bb90df607 100644 --- a/ipa-core/src/test_fixture/step.rs +++ b/ipa-core/src/test_fixture/step.rs @@ -2,7 +2,7 @@ use ipa_step_derive::CompactStep; /// Provides a unique per-iteration context in tests. 
#[derive(CompactStep)] -pub(crate) enum TestExecutionStep { +pub enum TestExecutionStep { #[step(count = 999)] Iter(usize), } From 2facdcbd0266c01d73a8a31910ca1874d87d0973 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Wed, 18 Sep 2024 23:17:42 -0700 Subject: [PATCH 035/191] Fix the panic inside addition by using wider compact gate step --- ipa-core/src/protocol/boolean/or.rs | 10 +++++++--- .../ipa_prf/boolean_ops/addition_sequential.rs | 2 +- ipa-core/src/protocol/ipa_prf/boolean_ops/step.rs | 6 +++--- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/ipa-core/src/protocol/boolean/or.rs b/ipa-core/src/protocol/boolean/or.rs index 176cbd239..da07c35cd 100644 --- a/ipa-core/src/protocol/boolean/or.rs +++ b/ipa-core/src/protocol/boolean/or.rs @@ -1,9 +1,11 @@ use std::iter::zip; +use ipa_step::StepNarrow; + use crate::{ error::Error, ff::{boolean::Boolean, Field}, - protocol::{basics::SecureMul, boolean::step::SixteenBitStep, context::Context, RecordId}, + protocol::{basics::SecureMul, boolean::NBitStep, context::Context, Gate, RecordId}, secret_sharing::{ replicated::semi_honest::AdditiveShare, BitDecomposed, FieldSimd, Linear as LinearSecretSharing, @@ -34,7 +36,7 @@ pub async fn or + SecureMul>( // // Supplying an iterator saves constructing a complete copy of the argument // in memory when it is a uniform constant. -pub async fn bool_or<'a, C, BI, const N: usize>( +pub async fn bool_or<'a, C, S, BI, const N: usize>( ctx: C, record_id: RecordId, a: &BitDecomposed>, @@ -42,17 +44,19 @@ pub async fn bool_or<'a, C, BI, const N: usize>( ) -> Result>, Error> where C: Context, + S: NBitStep, BI: IntoIterator, ::IntoIter: ExactSizeIterator> + Send, Boolean: FieldSimd, AdditiveShare: SecureMul, + Gate: StepNarrow, { let b = b.into_iter(); assert_eq!(a.len(), b.len()); BitDecomposed::try_from( ctx.parallel_join(zip(a.iter(), b).enumerate().map(|(i, (a, b))| { - let ctx = ctx.narrow(&SixteenBitStep::from(i)); + let ctx = ctx.narrow(&S::from(i)); async move { let ab = a.multiply(b, ctx, record_id).await?; Ok::<_, Error>(-ab + a + b) diff --git a/ipa-core/src/protocol/ipa_prf/boolean_ops/addition_sequential.rs b/ipa-core/src/protocol/ipa_prf/boolean_ops/addition_sequential.rs index eff85fe2d..1c55fa578 100644 --- a/ipa-core/src/protocol/ipa_prf/boolean_ops/addition_sequential.rs +++ b/ipa-core/src/protocol/ipa_prf/boolean_ops/addition_sequential.rs @@ -73,7 +73,7 @@ where .await?; // if carry==1 then {all ones} else {result} - bool_or( + bool_or::<_, S, _, N>( ctx.narrow::(&Step::Select), record_id, &result, diff --git a/ipa-core/src/protocol/ipa_prf/boolean_ops/step.rs b/ipa-core/src/protocol/ipa_prf/boolean_ops/step.rs index c27a8db44..056f5889a 100644 --- a/ipa-core/src/protocol/ipa_prf/boolean_ops/step.rs +++ b/ipa-core/src/protocol/ipa_prf/boolean_ops/step.rs @@ -1,15 +1,15 @@ use ipa_step_derive::CompactStep; /// FIXME: This step is not generic enough to be used in the `saturated_addition` protocol. -/// It constrains the input to be at most 2 bytes and it will panic in runtime if it is greater +/// It constrains the input to be at most 4 bytes and it will panic in runtime if it is greater /// than that. The issue is that compact gate requires concrete type to be put as child. 
/// If we ever see it being an issue, we should make a few implementations of this similar to what /// we've done for bit steps #[derive(CompactStep)] pub(crate) enum SaturatedAdditionStep { - #[step(child = crate::protocol::boolean::step::SixteenBitStep)] + #[step(child = crate::protocol::boolean::step::ThirtyTwoBitStep)] Add, - #[step(child = crate::protocol::boolean::step::SixteenBitStep)] + #[step(child = crate::protocol::boolean::step::ThirtyTwoBitStep)] Select, } From 3da73c9eb94d12f8b5a3f9f37b2859d6e294b60e Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Wed, 18 Sep 2024 23:30:05 -0700 Subject: [PATCH 036/191] Move TestExecutionStep to protocol folder Otherwise, I can't make it work with compact gate `track_steps`. It could be possible if we include stuff from `test_fixture` folder conditionally, but `track_steps` does not support that currently. --- ipa-core/build.rs | 1 - ipa-core/src/net/client/mod.rs | 2 +- ipa-core/src/protocol/step.rs | 9 ++++++++- ipa-core/src/test_fixture/circuit.rs | 4 ++-- ipa-core/src/test_fixture/mod.rs | 1 - ipa-core/src/test_fixture/step.rs | 8 -------- ipa-core/src/test_fixture/test_gate.rs | 2 +- ipa-step/src/gate.rs | 3 ++- ipa-step/src/hash.rs | 1 + 9 files changed, 15 insertions(+), 16 deletions(-) delete mode 100644 ipa-core/src/test_fixture/step.rs diff --git a/ipa-core/build.rs b/ipa-core/build.rs index ed45e74f2..26155f794 100644 --- a/ipa-core/build.rs +++ b/ipa-core/build.rs @@ -27,7 +27,6 @@ track_steps!( dp::step, step, }, - test_fixture::step ); fn main() { diff --git a/ipa-core/src/net/client/mod.rs b/ipa-core/src/net/client/mod.rs index e7180b6c5..61fcfece0 100644 --- a/ipa-core/src/net/client/mod.rs +++ b/ipa-core/src/net/client/mod.rs @@ -468,10 +468,10 @@ pub(crate) mod tests { RequestHandler, RoleAssignment, Transport, MESSAGE_PAYLOAD_SIZE_BYTES, }, net::test::TestServer, + protocol::step::TestExecutionStep, query::ProtocolResult, secret_sharing::replicated::semi_honest::AdditiveShare as Replicated, sync::Arc, - test_fixture::step::TestExecutionStep, }; #[tokio::test] diff --git a/ipa-core/src/protocol/step.rs b/ipa-core/src/protocol/step.rs index a9125c84f..c19bdb53b 100644 --- a/ipa-core/src/protocol/step.rs +++ b/ipa-core/src/protocol/step.rs @@ -12,7 +12,7 @@ pub enum ProtocolStep { /// Steps used in unit tests are grouped under this one. Ideally it should be /// gated behind test configuration, but it does not work with build.rs that /// does not enable any features when creating protocol gate file - #[step(child = crate::test_fixture::step::TestExecutionStep)] + #[step(child = TestExecutionStep)] Test, /// This step includes all the steps that are currently not linked into a top-level protocol. @@ -48,3 +48,10 @@ pub enum DeadCodeStep { #[step(child = crate::protocol::ipa_prf::boolean_ops::step::MultiplicationStep)] Multiplication, } + +/// Provides a unique per-iteration context in tests. 
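// How this step gets used (annotation, not part of the patch): `Test` above is
// declared with `#[step(child = TestExecutionStep)]`, so each test run can
// derive a fresh gate, roughly
//
//     Gate::default().narrow(&ProtocolStep::Test).narrow(&TestExecutionStep::Iter(run))
//
// which is what the per-run gate vendor in `test_fixture/test_gate.rs` relies on.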
+#[derive(CompactStep)] +pub enum TestExecutionStep { + #[step(count = 999)] + Iter(usize), +} diff --git a/ipa-core/src/test_fixture/circuit.rs b/ipa-core/src/test_fixture/circuit.rs index 03d7aeff1..5a1ecd67e 100644 --- a/ipa-core/src/test_fixture/circuit.rs +++ b/ipa-core/src/test_fixture/circuit.rs @@ -10,13 +10,13 @@ use crate::{ protocol::{ basics::SecureMul, context::{Context, SemiHonestContext}, - step::ProtocolStep, + step::{ProtocolStep, TestExecutionStep as Step}, Gate, RecordId, }, rand::thread_rng, secret_sharing::{replicated::semi_honest::AdditiveShare as Replicated, FieldSimd, IntoShares}, seq_join::seq_join, - test_fixture::{step::TestExecutionStep as Step, ReconstructArr, TestWorld, TestWorldConfig}, + test_fixture::{ReconstructArr, TestWorld, TestWorldConfig}, utils::array::zip3, }; diff --git a/ipa-core/src/test_fixture/mod.rs b/ipa-core/src/test_fixture/mod.rs index c5ebc446a..9ac8eb51f 100644 --- a/ipa-core/src/test_fixture/mod.rs +++ b/ipa-core/src/test_fixture/mod.rs @@ -15,7 +15,6 @@ pub mod hybrid; pub mod ipa; pub mod logging; pub mod metrics; -pub mod step; #[cfg(feature = "in-memory-infra")] mod test_gate; diff --git a/ipa-core/src/test_fixture/step.rs b/ipa-core/src/test_fixture/step.rs deleted file mode 100644 index bb90df607..000000000 --- a/ipa-core/src/test_fixture/step.rs +++ /dev/null @@ -1,8 +0,0 @@ -use ipa_step_derive::CompactStep; - -/// Provides a unique per-iteration context in tests. -#[derive(CompactStep)] -pub enum TestExecutionStep { - #[step(count = 999)] - Iter(usize), -} diff --git a/ipa-core/src/test_fixture/test_gate.rs b/ipa-core/src/test_fixture/test_gate.rs index a79a7dfca..a59802ff0 100644 --- a/ipa-core/src/test_fixture/test_gate.rs +++ b/ipa-core/src/test_fixture/test_gate.rs @@ -2,7 +2,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use ipa_step::StepNarrow; -use crate::{protocol::Gate, test_fixture::step::TestExecutionStep}; +use crate::protocol::{step::TestExecutionStep, Gate}; /// This manages the gate information for test runs. Most unit tests want to have multiple runs /// using the same instance of [`TestWorld`], but they don't care about the name of that particular diff --git a/ipa-step/src/gate.rs b/ipa-step/src/gate.rs index 0eac78393..0aa45cc46 100644 --- a/ipa-step/src/gate.rs +++ b/ipa-step/src/gate.rs @@ -29,7 +29,7 @@ fn build_narrows( let short_name = t.rsplit_once("::").map_or_else(|| t.as_ref(), |(_a, b)| b); let msg = format!("unexpected narrow for {gate_name}({{s}}) => {short_name}({{ss}})"); syntax.extend(quote! { - #[allow(clippy::too_many_lines)] + #[allow(clippy::too_many_lines, clippy::unreadable_literal)] impl ::ipa_step::StepNarrow<#ty> for #ident { fn narrow(&self, step: &#ty) -> Self { match self.0 { @@ -87,6 +87,7 @@ fn compact_gate_impl(gate_name: &str) -> TokenStream { let gate_lookup_type = step_hasher.lookup_type(); let mut syntax = quote! 
{ + #[allow(clippy::unreadable_literal)] static STR_LOOKUP: [&str; #step_count] = [#(#gate_names),*]; static GATE_LOOKUP: #gate_lookup_type = #step_hasher diff --git a/ipa-step/src/hash.rs b/ipa-step/src/hash.rs index 6b036dc74..133d3713a 100644 --- a/ipa-step/src/hash.rs +++ b/ipa-step/src/hash.rs @@ -62,6 +62,7 @@ impl ToTokens for HashingSteps { }; struct #lookup_type { + #[allow(clippy::unreadable_literal)] inner: [(u64, u32); #sz] } From 952e277db133778f21dff1e9fd664ca961606b2e Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Thu, 19 Sep 2024 00:17:34 -0700 Subject: [PATCH 037/191] Add in-memory compact gate tests to the CI --- .github/workflows/check.yml | 5 ++++- scripts/coverage-ci | 3 +++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index 58d37d514..b760ac908 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -145,8 +145,11 @@ jobs: - name: Run arithmetic bench run: cargo bench --bench oneshot_arithmetic --no-default-features --features "enable-benches compact-gate" - - name: Run compact gate tests + - name: Run compact gate tests for HTTP stack run: cargo test --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate" + + - name: Run in-memory compact gate tests + run: cargo test --features "compact-gate" slow: name: Slow tests env: diff --git a/scripts/coverage-ci b/scripts/coverage-ci index 9bb5e87ea..6a5836c55 100755 --- a/scripts/coverage-ci +++ b/scripts/coverage-ci @@ -19,4 +19,7 @@ done cargo test --bench oneshot_ipa --no-default-features --features "enable-benches compact-gate" -- -n 62 -c 16 cargo test --bench criterion_arithmetic --no-default-features --features "enable-benches compact-gate" +# compact gate + in-memory-infra +cargo test --features "compact-gate" + cargo llvm-cov report "$@" From 35272ac12cc6936cb831473006fcb11b183eff81 Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Thu, 19 Sep 2024 08:41:57 -0700 Subject: [PATCH 038/191] event gen --- ipa-core/src/test_fixture/hybrid_event_gen.rs | 73 ++++++++++++++++++- 1 file changed, 69 insertions(+), 4 deletions(-) diff --git a/ipa-core/src/test_fixture/hybrid_event_gen.rs b/ipa-core/src/test_fixture/hybrid_event_gen.rs index 73dd7a898..2c104c452 100644 --- a/ipa-core/src/test_fixture/hybrid_event_gen.rs +++ b/ipa-core/src/test_fixture/hybrid_event_gen.rs @@ -106,7 +106,7 @@ impl EventGenerator { ) { assert!(unmatched_conversions + unmatched_impressions <= 1.0); let match_key = self.rng.gen::(); - let rand = self.rng.gen::(); + let rand = self.rng.gen_range(0.0..1.0); if rand < unmatched_conversions { let conv = self.gen_conversion(match_key); self.in_flight.push(conv); @@ -122,7 +122,7 @@ impl EventGenerator { // long-tailed distribution of # of conversions per impression // will not exceed the configured maximum number of conversions per impression while conv_count < self.config.max_convs_per_imp.get() - && self.rng.gen::() < subsequent_conversion_prob + && self.rng.gen_range(0.0..1.0) < subsequent_conversion_prob { let conv = self.gen_conversion(match_key); self.in_flight.push(conv); @@ -155,12 +155,14 @@ impl Iterator for EventGenerator { if self.in_flight.is_empty() { self.gen_batch(); } - self.in_flight.pop() + Some(self.in_flight.pop().unwrap()) } } #[cfg(all(test, unit_test))] mod tests { + use std::collections::HashMap; + use rand::thread_rng; use super::*; @@ -171,6 +173,69 @@ mod tests { assert_eq!(10, gen.take(10).collect::>().len()); let gen = 
EventGenerator::with_default_config(thread_rng()); - assert_eq!(59, gen.take(59).collect::>().len()); + assert_eq!(1000, gen.take(1000).collect::>().len()); + } + + #[test] + fn subsequent_convs() { + let gen = EventGenerator::with_default_config(thread_rng()); + let max_convs_per_imp = gen.config.max_convs_per_imp.get(); + let mut match_key_to_event_count = HashMap::new(); + for event in gen.take(10000) { + match event { + TestHybridRecord::TestImpression { match_key, .. } => { + match_key_to_event_count + .entry(match_key) + .and_modify(|count| *count += 1) + .or_insert(1); + } + TestHybridRecord::TestConversion { match_key, .. } => { + match_key_to_event_count + .entry(match_key) + .and_modify(|count| *count += 1) + .or_insert(1); + } + } + } + let histogram_size = usize::try_from(max_convs_per_imp + 2).unwrap(); + let mut histogram: Vec = vec![0; histogram_size]; + for (_, count) in match_key_to_event_count { + histogram[count] += 1; + } + + assert!( + (6470 - histogram[1]).abs() < 200, + "expected {:?} unmatched events, got {:?}", + 647, + histogram[1] + ); + + assert!( + (1370 - histogram[2]).abs() < 100, + "expected {:?} unmatched events, got {:?}", + 137, + histogram[2] + ); + + assert!( + (200 - histogram[3]).abs() < 50, + "expected {:?} unmatched events, got {:?}", + 20, + histogram[3] + ); + + assert!( + (30 - histogram[4]).abs() < 40, + "expected {:?} unmatched events, got {:?}", + 3, + histogram[4] + ); + + assert!( + (0 - histogram[11]).abs() < 10, + "expected {:?} unmatched events, got {:?}", + 0, + histogram[11] + ); } } From 9963b5efbe6e18c7e563d6634c959c9e4593d172 Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Thu, 19 Sep 2024 08:48:58 -0700 Subject: [PATCH 039/191] Marking dead code as such --- ipa-core/src/test_fixture/hybrid_event_gen.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ipa-core/src/test_fixture/hybrid_event_gen.rs b/ipa-core/src/test_fixture/hybrid_event_gen.rs index 2c104c452..7a004a062 100644 --- a/ipa-core/src/test_fixture/hybrid_event_gen.rs +++ b/ipa-core/src/test_fixture/hybrid_event_gen.rs @@ -64,14 +64,14 @@ pub struct EventGenerator { } impl EventGenerator { - #[must_use] + #[allow(dead_code)] pub fn with_default_config(rng: R) -> Self { Self::with_config(rng, Config::default()) } /// # Panics /// If the configuration is not valid. - #[must_use] + #[allow(dead_code)] pub fn with_config(rng: R, config: Config) -> Self { let max_capacity = usize::try_from(config.max_convs_per_imp.get() + 1).unwrap(); Self { From 03dfb30129c462d194758891b81d2cd9b5e99213 Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Thu, 19 Sep 2024 09:03:09 -0700 Subject: [PATCH 040/191] Kill some dead code --- ipa-core/src/bin/report_collector.rs | 86 ++++++++++------------------ ipa-core/src/test_fixture/ipa.rs | 4 -- 2 files changed, 30 insertions(+), 60 deletions(-) diff --git a/ipa-core/src/bin/report_collector.rs b/ipa-core/src/bin/report_collector.rs index 42b765b50..80dbc078d 100644 --- a/ipa-core/src/bin/report_collector.rs +++ b/ipa-core/src/bin/report_collector.rs @@ -25,7 +25,7 @@ use ipa_core::{ net::MpcHelperClient, report::{EncryptedOprfReportStreams, DEFAULT_KEY_ID}, test_fixture::{ - ipa::{ipa_in_the_clear, CappingOrder, IpaQueryStyle, IpaSecurityModel, TestRawDataRecord}, + ipa::{ipa_in_the_clear, CappingOrder, IpaSecurityModel, TestRawDataRecord}, EventGenerator, EventGeneratorConfig, }, }; @@ -171,7 +171,6 @@ async fn main() -> Result<(), Box> { IpaSecurityModel::SemiHonest, config, &clients, - IpaQueryStyle::Oprf, ) .await? 
} @@ -182,7 +181,6 @@ async fn main() -> Result<(), Box> { IpaSecurityModel::Malicious, config, &clients, - IpaQueryStyle::Oprf, ) .await? } @@ -195,7 +193,6 @@ async fn main() -> Result<(), Box> { IpaSecurityModel::Malicious, ipa_query_config, &clients, - IpaQueryStyle::Oprf, encrypted_inputs, ) .await? @@ -209,7 +206,6 @@ async fn main() -> Result<(), Box> { IpaSecurityModel::SemiHonest, ipa_query_config, &clients, - IpaQueryStyle::Oprf, encrypted_inputs, ) .await? @@ -245,20 +241,10 @@ fn gen_inputs( Ok(()) } -/// Panics -/// if (security_model, query_style) tuple is undefined -fn get_query_type( - security_model: IpaSecurityModel, - query_style: &IpaQueryStyle, - ipa_query_config: IpaQueryConfig, -) -> QueryType { - match (security_model, query_style) { - (IpaSecurityModel::SemiHonest, IpaQueryStyle::Oprf) => { - QueryType::SemiHonestOprfIpa(ipa_query_config) - } - (IpaSecurityModel::Malicious, IpaQueryStyle::Oprf) => { - QueryType::MaliciousOprfIpa(ipa_query_config) - } +fn get_query_type(security_model: IpaSecurityModel, ipa_query_config: IpaQueryConfig) -> QueryType { + match security_model { + IpaSecurityModel::SemiHonest => QueryType::SemiHonestOprfIpa(ipa_query_config), + IpaSecurityModel::Malicious => QueryType::MaliciousOprfIpa(ipa_query_config), } } @@ -306,10 +292,9 @@ async fn ipa( security_model: IpaSecurityModel, ipa_query_config: IpaQueryConfig, helper_clients: &[MpcHelperClient; 3], - query_style: IpaQueryStyle, encrypted_inputs: &EncryptedInputs, ) -> Result<(), Box> { - let query_type = get_query_type(security_model, &query_style, ipa_query_config); + let query_type = get_query_type(security_model, ipa_query_config); let files = [ &encrypted_inputs.enc_input_file1, @@ -331,21 +316,17 @@ async fn ipa( .expect("Unable to create query!"); tracing::info!("Starting query for OPRF"); - let actual = match query_style { - IpaQueryStyle::Oprf => { - // the value for histogram values (BA32) must be kept in sync with the server-side - // implementation, otherwise a runtime reconstruct error will be generated. - // see ipa-core/src/query/executor.rs - run_query_and_validate::( - encrypted_oprf_report_streams.streams, - encrypted_oprf_report_streams.query_size, - helper_clients, - query_id, - ipa_query_config, - ) - .await - } - }; + // the value for histogram values (BA32) must be kept in sync with the server-side + // implementation, otherwise a runtime reconstruct error will be generated. 
+ // see ipa-core/src/query/executor.rs + let actual = run_query_and_validate::( + encrypted_oprf_report_streams.streams, + encrypted_oprf_report_streams.query_size, + helper_clients, + query_id, + ipa_query_config, + ) + .await; if let Some(ref path) = args.output_file { write_ipa_output_file(path, &actual)?; @@ -361,10 +342,9 @@ async fn ipa_test( security_model: IpaSecurityModel, ipa_query_config: IpaQueryConfig, helper_clients: &[MpcHelperClient; 3], - query_style: IpaQueryStyle, ) -> Result<(), Box> { let input = InputSource::from(&args.input); - let query_type = get_query_type(security_model, &query_style, ipa_query_config); + let query_type = get_query_type(security_model, ipa_query_config); let input_rows = input.iter::().collect::>(); let query_config = QueryConfig { @@ -383,9 +363,7 @@ async fn ipa_test( ipa_query_config.per_user_credit_cap, ipa_query_config.attribution_window_seconds, ipa_query_config.max_breakdown_key, - &(match query_style { - IpaQueryStyle::Oprf => CappingOrder::CapMostRecentFirst, - }), + &CappingOrder::CapMostRecentFirst, ); // pad the output vector to the max breakdown key, to make sure it is aligned with the MPC results @@ -401,21 +379,17 @@ async fn ipa_test( let Some(key_registries) = key_registries.init_from(network) else { panic!("could not load network file") }; - let actual = match query_style { - IpaQueryStyle::Oprf => { - // the value for histogram values (BA32) must be kept in sync with the server-side - // implementation, otherwise a runtime reconstruct error will be generated. - // see ipa-core/src/query/executor.rs - playbook_oprf_ipa::( - input_rows, - helper_clients, - query_id, - ipa_query_config, - Some((DEFAULT_KEY_ID, key_registries)), - ) - .await - } - }; + // the value for histogram values (BA32) must be kept in sync with the server-side + // implementation, otherwise a runtime reconstruct error will be generated. 
+ // see ipa-core/src/query/executor.rs + let actual = playbook_oprf_ipa::( + input_rows, + helper_clients, + query_id, + ipa_query_config, + Some((DEFAULT_KEY_ID, key_registries)), + ) + .await; if let Some(ref path) = args.output_file { write_ipa_output_file(path, &actual)?; diff --git a/ipa-core/src/test_fixture/ipa.rs b/ipa-core/src/test_fixture/ipa.rs index fd91081db..c38a4bc9f 100644 --- a/ipa-core/src/test_fixture/ipa.rs +++ b/ipa-core/src/test_fixture/ipa.rs @@ -27,10 +27,6 @@ pub enum IpaSecurityModel { Malicious, } -pub enum IpaQueryStyle { - Oprf, -} - #[derive(Debug, Clone, Ord, PartialEq, PartialOrd, Eq)] pub struct TestRawDataRecord { pub timestamp: u64, From 6d7c1d237cbe0a2dc9e9c2cb56284d1373f05c98 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Thu, 19 Sep 2024 10:03:06 -0700 Subject: [PATCH 041/191] Remove commented code --- ipa-core/src/secret_sharing/vector/transpose.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/ipa-core/src/secret_sharing/vector/transpose.rs b/ipa-core/src/secret_sharing/vector/transpose.rs index 65bbacb53..a9afa8492 100644 --- a/ipa-core/src/secret_sharing/vector/transpose.rs +++ b/ipa-core/src/secret_sharing/vector/transpose.rs @@ -614,8 +614,6 @@ impl_transpose_shares_bool_to_ba!(BA32, 32, 256, test_transpose_shares_bool_to_b impl_transpose_shares_bool_to_ba_small!(BA8, 8, 32, test_transpose_shares_bool_to_ba_8x32); // added to support HV = BA32 to hold results when adding Binomial noise impl_transpose_shares_bool_to_ba_small!(BA32, 32, 32, test_transpose_shares_bool_to_ba_32x32); -// // Usage: IPA test case for saturated additions -// impl_transpose_shares_bool_to_ba_small!(BA3, 3, 32, test_transpose_shares_bool_to_ba_3x32); // Usage: Aggregation output tests impl_transpose_shares_bool_to_ba_small!(BA8, 8, 8, test_transpose_shares_bool_to_ba_8x8); From 375fed77d8d0e8943cd7f3d8ab0854657d92922f Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Thu, 19 Sep 2024 11:04:59 -0700 Subject: [PATCH 042/191] improving coverage --- ipa-core/src/test_fixture/hybrid_event_gen.rs | 158 ++++++++++++++++-- 1 file changed, 148 insertions(+), 10 deletions(-) diff --git a/ipa-core/src/test_fixture/hybrid_event_gen.rs b/ipa-core/src/test_fixture/hybrid_event_gen.rs index 7a004a062..9bc30a676 100644 --- a/ipa-core/src/test_fixture/hybrid_event_gen.rs +++ b/ipa-core/src/test_fixture/hybrid_event_gen.rs @@ -1,4 +1,4 @@ -use std::num::{NonZeroU32, NonZeroU64}; +use std::num::NonZeroU32; use rand::Rng; @@ -16,8 +16,6 @@ pub enum ConversionDistribution { #[derive(Debug, Clone)] #[cfg_attr(feature = "clap", derive(clap::Args))] pub struct Config { - #[cfg_attr(feature = "clap", arg(long, default_value = "1000000000000"))] - pub num_events: NonZeroU64, #[cfg_attr(feature = "clap", arg(long, default_value = "5"))] pub max_conversion_value: NonZeroU32, #[cfg_attr(feature = "clap", arg(long, default_value = "20"))] @@ -31,7 +29,7 @@ pub struct Config { impl Default for Config { fn default() -> Self { - Self::new(1_000, 5, 20, 10) + Self::new(5, 20, 10, ConversionDistribution::Default) } } @@ -42,17 +40,16 @@ impl Config { /// If any argument is 0. 
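    // Example of the reworked signature (annotation, not part of the diff):
    // the `num_events` limit is gone and the conversion distribution is now an
    // explicit argument, so the `Default` impl above is equivalent to
    //
    //     Config::new(5, 20, 10, ConversionDistribution::Default)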
#[must_use] pub fn new( - num_events: u64, max_conversion_value: u32, max_breakdown_key: u32, max_convs_per_imp: u32, + conversion_distribution: ConversionDistribution, ) -> Self { Self { - num_events: NonZeroU64::try_from(num_events).unwrap(), max_conversion_value: NonZeroU32::try_from(max_conversion_value).unwrap(), max_breakdown_key: NonZeroU32::try_from(max_breakdown_key).unwrap(), max_convs_per_imp: NonZeroU32::try_from(max_convs_per_imp).unwrap(), - conversion_distribution: ConversionDistribution::Default, + conversion_distribution, } } } @@ -161,7 +158,7 @@ impl Iterator for EventGenerator { #[cfg(all(test, unit_test))] mod tests { - use std::collections::HashMap; + use std::collections::{HashMap, HashSet}; use rand::thread_rng; @@ -177,11 +174,11 @@ mod tests { } #[test] - fn subsequent_convs() { + fn default_config() { let gen = EventGenerator::with_default_config(thread_rng()); let max_convs_per_imp = gen.config.max_convs_per_imp.get(); let mut match_key_to_event_count = HashMap::new(); - for event in gen.take(10000) { + for event in gen.take(10_000) { match event { TestHybridRecord::TestImpression { match_key, .. } => { match_key_to_event_count @@ -238,4 +235,145 @@ mod tests { histogram[11] ); } + + #[test] + fn lots_of_repeat_conversions() { + const MAX_CONVS_PER_IMP: u32 = 10; + const MAX_BREAKDOWN_KEY: u32 = 20; + const MAX_VALUE: u32 = 3; + let gen = EventGenerator::with_config( + thread_rng(), + Config::new( + MAX_VALUE, + MAX_BREAKDOWN_KEY, + MAX_CONVS_PER_IMP, + ConversionDistribution::LotsOfConversionsPerImpression, + ), + ); + let max_convs_per_imp = gen.config.max_convs_per_imp.get(); + let mut match_key_to_event_count = HashMap::new(); + for event in gen.take(100_000) { + match event { + TestHybridRecord::TestImpression { + match_key, + breakdown_key, + } => { + assert!(breakdown_key <= MAX_BREAKDOWN_KEY); + match_key_to_event_count + .entry(match_key) + .and_modify(|count| *count += 1) + .or_insert(1); + } + TestHybridRecord::TestConversion { match_key, value } => { + assert!(value <= MAX_VALUE); + match_key_to_event_count + .entry(match_key) + .and_modify(|count| *count += 1) + .or_insert(1); + } + } + } + let histogram_size = usize::try_from(max_convs_per_imp + 2).unwrap(); + let mut histogram: Vec = vec![0; histogram_size]; + for (_, count) in match_key_to_event_count { + histogram[count] += 1; + } + println!("Histogram: {:?}", histogram); + + assert!( + (30_032 - histogram[1]).abs() < 800, + "expected {:?} unmatched events, got {:?}", + 30_032, + histogram[1] + ); + + assert!( + (2_572 - histogram[2]).abs() < 300, + "expected {:?} unmatched events, got {:?}", + 2_572, + histogram[2] + ); + + assert!( + (2_048 - histogram[3]).abs() < 200, + "expected {:?} unmatched events, got {:?}", + 2_048, + histogram[3] + ); + + assert!( + (1_650 - histogram[4]).abs() < 100, + "expected {:?} unmatched events, got {:?}", + 1_650, + histogram[4] + ); + + assert!( + (1_718 - histogram[11]).abs() < 100, + "expected {:?} unmatched events, got {:?}", + 1_718, + histogram[11] + ); + } + + #[test] + fn only_impressions_config() { + const NUM_EVENTS: usize = 100; + const MAX_CONVS_PER_IMP: u32 = 1; + const MAX_BREAKDOWN_KEY: u32 = 10; + let gen = EventGenerator::with_config( + thread_rng(), + Config::new( + 10, + MAX_BREAKDOWN_KEY, + MAX_CONVS_PER_IMP, + ConversionDistribution::OnlyImpressions, + ), + ); + let mut match_keys = HashSet::new(); + for event in gen.take(NUM_EVENTS) { + match event { + TestHybridRecord::TestImpression { + match_key, + breakdown_key, + } => { + 
assert!(breakdown_key <= MAX_BREAKDOWN_KEY); + match_keys.insert(match_key); + } + TestHybridRecord::TestConversion { .. } => { + panic!("No conversions should be generated"); + } + } + } + assert_eq!(match_keys.len(), NUM_EVENTS); + } + + #[test] + fn only_conversions_config() { + const NUM_EVENTS: usize = 100; + const MAX_CONVS_PER_IMP: u32 = 1; + const MAX_VALUE: u32 = 10; + let gen = EventGenerator::with_config( + thread_rng(), + Config::new( + MAX_VALUE, + 10, + MAX_CONVS_PER_IMP, + ConversionDistribution::OnlyConversions, + ), + ); + let mut match_keys = HashSet::new(); + for event in gen.take(NUM_EVENTS) { + match event { + TestHybridRecord::TestConversion { match_key, value } => { + assert!(value <= MAX_VALUE); + match_keys.insert(match_key); + } + TestHybridRecord::TestImpression { .. } => { + panic!("No impressions should be generated"); + } + } + } + assert_eq!(match_keys.len(), NUM_EVENTS); + } } From 7b65ea8df3bb1a2c35a987e47d1ca68d78d06b6a Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Thu, 19 Sep 2024 11:14:04 -0700 Subject: [PATCH 043/191] Clippy --- ipa-core/src/test_fixture/hybrid_event_gen.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/ipa-core/src/test_fixture/hybrid_event_gen.rs b/ipa-core/src/test_fixture/hybrid_event_gen.rs index 9bc30a676..8d44d38b3 100644 --- a/ipa-core/src/test_fixture/hybrid_event_gen.rs +++ b/ipa-core/src/test_fixture/hybrid_event_gen.rs @@ -278,7 +278,6 @@ mod tests { for (_, count) in match_key_to_event_count { histogram[count] += 1; } - println!("Histogram: {:?}", histogram); assert!( (30_032 - histogram[1]).abs() < 800, From 6570672a53b9d3ebcc7ba207e6ecddcffb3ad9b4 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Thu, 19 Sep 2024 11:24:40 -0700 Subject: [PATCH 044/191] Improve pre-commit runtime by running fewer checks Our observation it takes 10+ minutes for all tests to pass and it is blocking people from working on other things. This change intends to keep only basic checks where we see the majority of failures: formatting, clippy and basic tests. It is still possible to run all checks by setting `EXEC_SLOW_TESTS` environment variable to 1 --- pre-commit | 47 ++++++++++++++++++++++++----------------------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/pre-commit b/pre-commit index 713b614d2..3edbe2781 100755 --- a/pre-commit +++ b/pre-commit @@ -82,40 +82,41 @@ check() { fi } -check "Benchmark compilation" \ - cargo build --benches --no-default-features --features "enable-benches compact-gate" - check "Clippy checks" \ - cargo clippy --tests -- -D warnings + cargo clippy --features="cli test-fixture" --tests -- -D warnings -check "Clippy concurrency checks" \ - cargo clippy --tests --features shuttle -- -D warnings +check "Tests" \ + cargo test --features="cli test-fixture" -check "Clippy web checks" \ - cargo clippy --tests --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate" -- -D warnings +if [ -z "$EXEC_SLOW_TESTS" ] +then -# The tests here need to be kept in sync with scripts/coverage-ci. 
+ check "Benchmark compilation" \ + cargo build --benches --no-default-features --features "enable-benches compact-gate" -check "Tests" \ - cargo test --features="cli test-fixture" + check "Clippy concurrency checks" \ + cargo clippy --tests --features shuttle -- -D warnings -check "Web tests" \ - cargo test -p ipa-core --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate" + check "Clippy web checks" \ + cargo clippy --tests --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate" -- -D warnings -check "Web tests (descriptive gate)" \ - cargo test -p ipa-core --no-default-features --features "cli web-app real-world-infra test-fixture" + # The tests here need to be kept in sync with scripts/coverage-ci. -check "Concurrency tests" \ - cargo test -p ipa-core --release --features "shuttle multi-threading" + check "Web tests" \ + cargo test -p ipa-core --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate" -check "IPA benchmark" \ - cargo bench --bench oneshot_ipa --no-default-features --features="enable-benches compact-gate" -- -n 62 -c 16 + check "Web tests (descriptive gate)" \ + cargo test -p ipa-core --no-default-features --features "cli web-app real-world-infra test-fixture" -check "Arithmetic circuit benchmark" \ - cargo bench --bench oneshot_arithmetic --no-default-features --features "enable-benches compact-gate" + check "Concurrency tests" \ + cargo test -p ipa-core --release --features "shuttle multi-threading" + + check "IPA benchmark" \ + cargo bench --bench oneshot_ipa --no-default-features --features="enable-benches compact-gate" -- -n 62 -c 16 + + check "Arithmetic circuit benchmark" \ + cargo bench --bench oneshot_arithmetic --no-default-features --features "enable-benches compact-gate" -if [ -z "$EXEC_SLOW_TESTS" ] -then check "Slow tests" \ cargo test --release --test "*" --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate" fi From 95384a1becd6ccb32b1a34b4a7e8d0fb0f3fce99 Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Thu, 19 Sep 2024 13:27:47 -0700 Subject: [PATCH 045/191] Script to generate test data for hybrid --- ipa-core/src/bin/report_collector.rs | 47 +++++++++++++++++++++++++++- ipa-core/src/cli/csv.rs | 23 ++++++++++++++ ipa-core/src/test_fixture/mod.rs | 5 ++- 3 files changed, 73 insertions(+), 2 deletions(-) diff --git a/ipa-core/src/bin/report_collector.rs b/ipa-core/src/bin/report_collector.rs index 80dbc078d..50a938de0 100644 --- a/ipa-core/src/bin/report_collector.rs +++ b/ipa-core/src/bin/report_collector.rs @@ -25,8 +25,9 @@ use ipa_core::{ net::MpcHelperClient, report::{EncryptedOprfReportStreams, DEFAULT_KEY_ID}, test_fixture::{ + hybrid::{hybrid_in_the_clear, TestHybridRecord}, ipa::{ipa_in_the_clear, CappingOrder, IpaSecurityModel, TestRawDataRecord}, - EventGenerator, EventGeneratorConfig, + EventGenerator, EventGeneratorConfig, HybridEventGenerator, HybridGeneratorConfig, }, }; use rand::{distributions::Alphanumeric, rngs::StdRng, thread_rng, Rng}; @@ -96,6 +97,18 @@ enum ReportCollectorCommand { #[clap(flatten)] gen_args: EventGeneratorConfig, }, + GenHybridInputs { + /// Number of records to generate + #[clap(long, short = 'n')] + count: u32, + + /// The seed for random generator. 
+ #[clap(long, short = 's')] + seed: Option, + + #[clap(flatten)] + gen_args: HybridGeneratorConfig, + }, /// Execute OPRF IPA in a semi-honest majority setting with known test data /// and compare results against expectation SemiHonestOprfIpaTest(IpaQueryConfig), @@ -164,6 +177,11 @@ async fn main() -> Result<(), Box> { seed, gen_args, } => gen_inputs(count, seed, args.output_file, gen_args)?, + ReportCollectorCommand::GenHybridInputs { + count, + seed, + gen_args, + } => gen_hybrid_inputs(count, seed, args.output_file, gen_args)?, ReportCollectorCommand::SemiHonestOprfIpaTest(config) => { ipa_test( &args, @@ -215,6 +233,33 @@ async fn main() -> Result<(), Box> { Ok(()) } +fn gen_hybrid_inputs( + count: u32, + seed: Option, + output_file: Option, + args: HybridGeneratorConfig, +) -> io::Result<()> { + let rng = seed + .map(StdRng::seed_from_u64) + .unwrap_or_else(StdRng::from_entropy); + let event_gen = HybridEventGenerator::with_config(rng, args) + .take(count as usize) + .collect::>(); + + let mut writer: Box = if let Some(path) = output_file { + Box::new(OpenOptions::new().write(true).create_new(true).open(path)?) + } else { + Box::new(stdout().lock()) + }; + + for event in event_gen { + event.to_csv(&mut writer)?; + writer.write_all(b"\n")?; + } + + Ok(()) +} + fn gen_inputs( count: u32, seed: Option, diff --git a/ipa-core/src/cli/csv.rs b/ipa-core/src/cli/csv.rs index 37772ac55..a69523950 100644 --- a/ipa-core/src/cli/csv.rs +++ b/ipa-core/src/cli/csv.rs @@ -20,3 +20,26 @@ impl Serializer for crate::test_fixture::ipa::TestRawDataRecord { Ok(()) } } + +#[cfg(any(test, feature = "test-fixture"))] +impl Serializer for crate::test_fixture::hybrid::TestHybridRecord { + fn to_csv(&self, buf: &mut W) -> io::Result<()> { + match self { + crate::test_fixture::hybrid::TestHybridRecord::TestImpression { + match_key, + breakdown_key, + } => { + write!(buf, "i,")?; + write!(buf, "{match_key},")?; + write!(buf, "{breakdown_key},")?; + } + crate::test_fixture::hybrid::TestHybridRecord::TestConversion { match_key, value } => { + write!(buf, "c,")?; + write!(buf, "{match_key},")?; + write!(buf, "{value}")?; + } + } + + Ok(()) + } +} diff --git a/ipa-core/src/test_fixture/mod.rs b/ipa-core/src/test_fixture/mod.rs index 54df493ee..086c62ed2 100644 --- a/ipa-core/src/test_fixture/mod.rs +++ b/ipa-core/src/test_fixture/mod.rs @@ -12,7 +12,7 @@ mod app; pub mod circuit; mod event_gen; pub mod hybrid; -mod hybrid_event_gen; +pub mod hybrid_event_gen; pub mod ipa; pub mod logging; pub mod metrics; @@ -26,6 +26,9 @@ use std::fmt::Debug; pub use app::TestApp; pub use event_gen::{Config as EventGeneratorConfig, EventGenerator}; use futures::TryFuture; +pub use hybrid_event_gen::{ + Config as HybridGeneratorConfig, EventGenerator as HybridEventGenerator, +}; use rand::{distributions::Standard, prelude::Distribution, rngs::mock::StepRng}; use rand_core::{CryptoRng, RngCore}; pub use sharing::{get_bits, into_bits, Reconstruct, ReconstructArr}; From 3004de350904b71c1f4ea925b5abda87d8cb6322 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Thu, 19 Sep 2024 13:31:38 -0700 Subject: [PATCH 046/191] Fix clippy errors --- ipa-core/src/cli/crypto.rs | 48 +++++++++++------------------- ipa-core/src/cli/playbook/input.rs | 6 ++-- ipa-core/src/hpke/registry.rs | 8 ++--- 3 files changed, 23 insertions(+), 39 deletions(-) diff --git a/ipa-core/src/cli/crypto.rs b/ipa-core/src/cli/crypto.rs index b66b6ee24..0c652b01d 100644 --- a/ipa-core/src/cli/crypto.rs +++ b/ipa-core/src/cli/crypto.rs @@ -539,7 +539,7 @@ public_key = 
"cfdbaaff16b30aa8a4ab07eaad2cdd80458208a1317aefbb807e46dce596617e" async fn encrypt_and_execute_query() { const EXPECTED: &[u128] = &[0, 8, 5]; - let records: Vec = vec![ + let records = vec![ TestRawDataRecord { timestamp: 0, user_id: 12345, @@ -587,62 +587,50 @@ public_key = "cfdbaaff16b30aa8a4ab07eaad2cdd80458208a1317aefbb807e46dce596617e" let mut input_file = NamedTempFile::new().unwrap(); for event in records { - let _ = event.to_csv(input_file.as_file_mut()); + event.to_csv(input_file.as_file_mut()).unwrap(); writeln!(input_file.as_file()).unwrap(); } - input_file.as_file_mut().flush().unwrap(); + input_file.flush().unwrap(); let output_dir = tempdir().unwrap(); let network_file = write_network_file(); - let encrypt_args = - build_encrypt_args(input_file.path(), output_dir.path(), network_file.path()); - let _ = encrypt(&encrypt_args); + encrypt(&build_encrypt_args(input_file.path(), output_dir.path(), network_file.path())).unwrap(); let files = [ &output_dir.path().join("helper1.enc"), &output_dir.path().join("helper2.enc"), &output_dir.path().join("helper3.enc"), ]; - let encrypted_oprf_report_streams = EncryptedOprfReportStreams::from(files); let world = TestWorld::default(); - let contexts = world.contexts(); - - let mk_private_keys = vec![ - hex::decode("53d58e022981f2edbf55fec1b45dbabd08a3442cb7b7c598839de5d7a5888bff") - .expect("manually provided for test"), - hex::decode("3a0a993a3cfc7e8d381addac586f37de50c2a14b1a6356d71e94ca2afaeb2569") - .expect("manually provided for test"), - hex::decode("1fb5c5274bf85fbe6c7935684ef05499f6cfb89ac21640c28330135cc0e8a0f7") - .expect("manually provided for test"), + + let mk_private_keys = [ + "53d58e022981f2edbf55fec1b45dbabd08a3442cb7b7c598839de5d7a5888bff", + "3a0a993a3cfc7e8d381addac586f37de50c2a14b1a6356d71e94ca2afaeb2569", + "1fb5c5274bf85fbe6c7935684ef05499f6cfb89ac21640c28330135cc0e8a0f7", ]; #[allow(clippy::large_futures)] let results = join3v( - encrypted_oprf_report_streams + EncryptedOprfReportStreams::from(files) .streams .into_iter() - .zip(contexts) - .zip(mk_private_keys) + .zip(world.contexts()) + .zip(mk_private_keys.into_iter()) .map(|((input, ctx), mk_private_key)| { + let mk_private_key = hex::decode(mk_private_key) + .map(|bytes| IpaPrivateKey::from_bytes(&bytes).unwrap()) + .unwrap(); let query_config = IpaQueryConfig { - per_user_credit_cap: 8, - attribution_window_seconds: None, max_breakdown_key: 3, with_dp: 0, epsilon: 1.0, - plaintext_match_keys: false, + ..Default::default() }; - let private_registry = - Arc::new(KeyRegistry::::from_keys([PrivateKeyOnly( - IpaPrivateKey::from_bytes(&mk_private_key) - .expect("manually constructed for test"), - )])); - - OprfIpaQuery::<_, BA16, KeyRegistry>::new( + OprfIpaQuery::<_, BA16, _>::new( query_config, - private_registry, + Arc::new(KeyRegistry::from_keys([PrivateKeyOnly(mk_private_key)])), ) .execute(ctx, query_size, input) }), diff --git a/ipa-core/src/cli/playbook/input.rs b/ipa-core/src/cli/playbook/input.rs index 70efa3bcb..b7d31484d 100644 --- a/ipa-core/src/cli/playbook/input.rs +++ b/ipa-core/src/cli/playbook/input.rs @@ -203,13 +203,13 @@ mod tests { } #[test] - #[should_panic] + #[should_panic(expected = "ParseIntError")] fn parse_negative() { Fp31::from_str("-1"); } #[test] - #[should_panic] + #[should_panic(expected = "ParseIntError")] fn parse_empty() { Fp31::from_str(""); } @@ -229,7 +229,7 @@ mod tests { } #[test] - #[should_panic] + #[should_panic(expected = "ParseIntError")] fn tuple_parse_error() { <(Fp31, Fp31)>::from_str("20,"); } diff --git 
a/ipa-core/src/hpke/registry.rs b/ipa-core/src/hpke/registry.rs
index 281f4af47..283d1fbd0 100644
--- a/ipa-core/src/hpke/registry.rs
+++ b/ipa-core/src/hpke/registry.rs
@@ -93,13 +93,9 @@ impl KeyRegistry {
         Self { keys: Box::new([]) }
     }
 
-    pub fn from_keys>(pairs: [I; N]) -> Self {
+    pub fn from_keys(pairs: [K; N]) -> Self {
         Self {
-            keys: pairs
-                .into_iter()
-                .map(Into::into)
-                .collect::>()
-                .into_boxed_slice(),
+            keys: pairs.into_iter().collect::>().into_boxed_slice(),
         }
     }
 

From 1137955dbaa787660686d909cb0602551faac9ee Mon Sep 17 00:00:00 2001
From: Alex Koshelev
Date: Thu, 19 Sep 2024 13:32:08 -0700
Subject: [PATCH 047/191] Cover more code with Clippy in CI

---
 .github/workflows/check.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml
index 58d37d514..dbbc9246b 100644
--- a/.github/workflows/check.yml
+++ b/.github/workflows/check.yml
@@ -53,7 +53,7 @@ jobs:
 
       - name: Clippy
         if: ${{ success() || failure() }}
-        run: cargo clippy --tests
+        run: cargo clippy --features "cli test-fixture" --tests
 
       - name: Clippy concurrency tests
         if: ${{ success() || failure() }}

From e9e36175ef0d604e2e4000b44368947853458384 Mon Sep 17 00:00:00 2001
From: Alex Koshelev
Date: Thu, 19 Sep 2024 13:32:55 -0700
Subject: [PATCH 048/191] Formatting

---
 ipa-core/src/cli/crypto.rs | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/ipa-core/src/cli/crypto.rs b/ipa-core/src/cli/crypto.rs
index 0c652b01d..b66b182a6 100644
--- a/ipa-core/src/cli/crypto.rs
+++ b/ipa-core/src/cli/crypto.rs
@@ -594,7 +594,12 @@ public_key = "cfdbaaff16b30aa8a4ab07eaad2cdd80458208a1317aefbb807e46dce596617e"
 
         let output_dir = tempdir().unwrap();
         let network_file = write_network_file();
-        encrypt(&build_encrypt_args(input_file.path(), output_dir.path(), network_file.path())).unwrap();
+        encrypt(&build_encrypt_args(
+            input_file.path(),
+            output_dir.path(),
+            network_file.path(),
+        ))
+        .unwrap();
 
         let files = [
             &output_dir.path().join("helper1.enc"),
             &output_dir.path().join("helper2.enc"),
             &output_dir.path().join("helper3.enc"),
         ];

From 599f3c3fd472722599797fdfabc54596dd8afb1e Mon Sep 17 00:00:00 2001
From: Alex Koshelev
Date: Thu, 19 Sep 2024 13:40:02 -0700
Subject: [PATCH 049/191] Fix more clippy issues

---
 ipa-core/src/cli/crypto.rs | 16 +---------------
 1 file changed, 1 insertion(+), 15 deletions(-)

diff --git a/ipa-core/src/cli/crypto.rs b/ipa-core/src/cli/crypto.rs
index b66b182a6..9c8fe7a4b 100644
--- a/ipa-core/src/cli/crypto.rs
+++ b/ipa-core/src/cli/crypto.rs
@@ -537,7 +537,7 @@ public_key = "cfdbaaff16b30aa8a4ab07eaad2cdd80458208a1317aefbb807e46dce596617e"
 
     #[tokio::test]
     async fn encrypt_and_execute_query() {
-        const EXPECTED: &[u128] = &[0, 8, 5];
+        const EXPECTED: &[u128] = &[0, 2, 5];
 
         let records = vec![
             TestRawDataRecord {
@@ -568,20 +568,6 @@ public_key = "cfdbaaff16b30aa8a4ab07eaad2cdd80458208a1317aefbb807e46dce596617e"
                 breakdown_key: 0,
                 trigger_value: 2,
             },
-            TestRawDataRecord {
-                timestamp: 20,
-                user_id: 68362,
-                is_trigger_report: false,
-                breakdown_key: 1,
-                trigger_value: 0,
-            },
-            TestRawDataRecord {
-                timestamp: 30,
-                user_id: 68362,
-                is_trigger_report: true,
-                breakdown_key: 1,
-                trigger_value: 7,
-            },
         ];
         let query_size = QuerySize::try_from(records.len()).unwrap();
        let mut input_file = NamedTempFile::new().unwrap();

From 7ecbf5ab0907485b9ab35530d96fbaf6d650ee10 Mon Sep 17 00:00:00 2001
From: Erik Taubeneck
Date: Thu, 19 Sep 2024 13:49:13 -0700
Subject: [PATCH 050/191] hybrid-structs: add struct for hybrid reports, build
 them from existing encrypted OPRF reports

---
 ipa-core/src/report/hybrid.rs             | 196
++++++++++++++++++++++ ipa-core/src/{report.rs => report/ipa.rs} | 2 +- ipa-core/src/report/mod.rs | 3 + 3 files changed, 200 insertions(+), 1 deletion(-) create mode 100644 ipa-core/src/report/hybrid.rs rename ipa-core/src/{report.rs => report/ipa.rs} (99%) create mode 100644 ipa-core/src/report/mod.rs diff --git a/ipa-core/src/report/hybrid.rs b/ipa-core/src/report/hybrid.rs new file mode 100644 index 000000000..a87fa47bb --- /dev/null +++ b/ipa-core/src/report/hybrid.rs @@ -0,0 +1,196 @@ +use std::{ + marker::PhantomData, + ops::{Add, Deref}, +}; + +use generic_array::ArrayLength; +use typenum::{Sum, U16}; + +use crate::{ + ff::{boolean_array::BA64, Serializable}, + hpke::PrivateKeyRegistry, + report::{EncryptedOprfReport, EventType, InvalidReportError, KeyIdentifier}, + secret_sharing::{replicated::semi_honest::AdditiveShare as Replicated, SharedValue}, +}; + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct HybridImpressionReport +where + BK: SharedValue, +{ + match_key: Replicated, + breakdown_key: Replicated, +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct HybridConversionReport +where + V: SharedValue, + TS: SharedValue, +{ + match_key: Replicated, + value: Replicated, + _phantom: PhantomData, +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum HybridReport +where + BK: SharedValue, + V: SharedValue, + TS: SharedValue, +{ + Impression(HybridImpressionReport), + Conversion(HybridConversionReport), +} + +#[allow(dead_code)] +pub struct HybridImpressionInfo<'a> { + pub key_id: KeyIdentifier, + pub helper_origin: &'a str, +} + +#[allow(dead_code)] +pub struct HybridConversionInfo<'a> { + pub key_id: KeyIdentifier, + pub helper_origin: &'a str, + pub converion_site_domain: &'a str, + pub timestamp: u64, + pub epsilon: f64, + pub sensitivity: f64, +} + +#[allow(dead_code)] +pub enum HybridInfo<'a> { + Impression(HybridImpressionInfo<'a>), + Conversion(HybridConversionInfo<'a>), +} + +impl HybridReport +where + BK: SharedValue, + V: SharedValue, + TS: SharedValue, // this is only needed for the backpart from EncryptedOprfReport + Replicated: Serializable, + Replicated: Serializable, + Replicated: Serializable, + as Serializable>::Size: Add< as Serializable>::Size>, + Sum< as Serializable>::Size, as Serializable>::Size>: + Add< as Serializable>::Size>, + Sum< + Sum< as Serializable>::Size, as Serializable>::Size>, + as Serializable>::Size, + >: Add, + Sum< + Sum< + Sum< as Serializable>::Size, as Serializable>::Size>, + as Serializable>::Size, + >, + U16, + >: ArrayLength, +{ + /// ## Errors + /// If the report contents are invalid. 
+ pub fn from_bytes>( + data: B, + key_registry: &P, + ) -> Result { + let encrypted_oprf_report = EncryptedOprfReport::::from_bytes(data)?; + let oprf_report = encrypted_oprf_report.decrypt(key_registry)?; + match oprf_report.event_type { + EventType::Source => Ok(Self::Impression(HybridImpressionReport { + match_key: oprf_report.match_key, + breakdown_key: oprf_report.breakdown_key, + })), + EventType::Trigger => Ok(Self::Conversion(HybridConversionReport { + match_key: oprf_report.match_key, + value: oprf_report.trigger_value, + _phantom: PhantomData::, + })), + } + } +} + +#[cfg(test)] +mod test { + use std::marker::PhantomData; + + use rand::{distributions::Alphanumeric, rngs::ThreadRng, thread_rng, Rng}; + + use super::{HybridConversionReport, HybridImpressionReport, HybridReport}; + use crate::{ + ff::boolean_array::{BA20, BA3, BA8}, + hpke::{KeyPair, KeyRegistry}, + report::{EventType, OprfReport}, + secret_sharing::replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing}, + }; + + fn build_oprf_report(event_type: EventType, rng: &mut ThreadRng) -> OprfReport { + OprfReport:: { + match_key: AdditiveShare::new(rng.gen(), rng.gen()), + timestamp: AdditiveShare::new(rng.gen(), rng.gen()), + breakdown_key: AdditiveShare::new(rng.gen(), rng.gen()), + trigger_value: AdditiveShare::new(rng.gen(), rng.gen()), + event_type, + epoch: rng.gen(), + site_domain: (rng) + .sample_iter(Alphanumeric) + .map(char::from) + .take(10) + .collect(), + } + } + + #[test] + fn convert_to_hybrid_impression_report() { + let mut rng = thread_rng(); + + let b = EventType::Source; + + let oprf_report = build_oprf_report(b, &mut rng); + let hybrid_report = + HybridReport::Impression::(HybridImpressionReport:: { + match_key: oprf_report.match_key.clone(), + breakdown_key: oprf_report.breakdown_key.clone(), + }); + + let key_registry = KeyRegistry::::random(1, &mut rng); + let key_id = 0; + + let enc_report_bytes = oprf_report + .encrypt(key_id, &key_registry, &mut rng) + .unwrap(); + let hybrid_report2 = + HybridReport::::from_bytes(enc_report_bytes.as_slice(), &key_registry) + .unwrap(); + + assert_eq!(hybrid_report, hybrid_report2); + } + + #[test] + fn convert_to_hybrid_report() { + let mut rng = thread_rng(); + + let b = EventType::Trigger; + + let oprf_report = build_oprf_report(b, &mut rng); + let hybrid_report = + HybridReport::Conversion::(HybridConversionReport:: { + match_key: oprf_report.match_key.clone(), + value: oprf_report.trigger_value.clone(), + _phantom: PhantomData::, + }); + + let key_registry = KeyRegistry::::random(1, &mut rng); + let key_id = 0; + + let enc_report_bytes = oprf_report + .encrypt(key_id, &key_registry, &mut rng) + .unwrap(); + let hybrid_report2 = + HybridReport::::from_bytes(enc_report_bytes.as_slice(), &key_registry) + .unwrap(); + + assert_eq!(hybrid_report, hybrid_report2); + } +} diff --git a/ipa-core/src/report.rs b/ipa-core/src/report/ipa.rs similarity index 99% rename from ipa-core/src/report.rs rename to ipa-core/src/report/ipa.rs index a9da93454..29ecf78ea 100644 --- a/ipa-core/src/report.rs +++ b/ipa-core/src/report/ipa.rs @@ -49,7 +49,7 @@ use crate::{ }; // TODO(679): This needs to come from configuration. 
-static HELPER_ORIGIN: &str = "github.com/private-attribution"; +pub static HELPER_ORIGIN: &str = "github.com/private-attribution"; pub type KeyIdentifier = u8; pub const DEFAULT_KEY_ID: KeyIdentifier = 0; diff --git a/ipa-core/src/report/mod.rs b/ipa-core/src/report/mod.rs new file mode 100644 index 000000000..28fd9e683 --- /dev/null +++ b/ipa-core/src/report/mod.rs @@ -0,0 +1,3 @@ +pub mod ipa; +pub use self::ipa::*; +pub mod hybrid; From 7fd93e156cbf6509a6682269bd22adbc5bd4fb33 Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Thu, 19 Sep 2024 13:50:30 -0700 Subject: [PATCH 051/191] no more trailing endlines --- ipa-core/src/bin/report_collector.rs | 5 +---- ipa-core/src/cli/csv.rs | 8 ++------ 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/ipa-core/src/bin/report_collector.rs b/ipa-core/src/bin/report_collector.rs index 50a938de0..a642b7bfc 100644 --- a/ipa-core/src/bin/report_collector.rs +++ b/ipa-core/src/bin/report_collector.rs @@ -25,7 +25,6 @@ use ipa_core::{ net::MpcHelperClient, report::{EncryptedOprfReportStreams, DEFAULT_KEY_ID}, test_fixture::{ - hybrid::{hybrid_in_the_clear, TestHybridRecord}, ipa::{ipa_in_the_clear, CappingOrder, IpaSecurityModel, TestRawDataRecord}, EventGenerator, EventGeneratorConfig, HybridEventGenerator, HybridGeneratorConfig, }, @@ -242,9 +241,7 @@ fn gen_hybrid_inputs( let rng = seed .map(StdRng::seed_from_u64) .unwrap_or_else(StdRng::from_entropy); - let event_gen = HybridEventGenerator::with_config(rng, args) - .take(count as usize) - .collect::>(); + let event_gen = HybridEventGenerator::with_config(rng, args).take(count as usize); let mut writer: Box = if let Some(path) = output_file { Box::new(OpenOptions::new().write(true).create_new(true).open(path)?) diff --git a/ipa-core/src/cli/csv.rs b/ipa-core/src/cli/csv.rs index a69523950..621c9b352 100644 --- a/ipa-core/src/cli/csv.rs +++ b/ipa-core/src/cli/csv.rs @@ -29,14 +29,10 @@ impl Serializer for crate::test_fixture::hybrid::TestHybridRecord { match_key, breakdown_key, } => { - write!(buf, "i,")?; - write!(buf, "{match_key},")?; - write!(buf, "{breakdown_key},")?; + write!(buf, "i,{match_key},{breakdown_key}")?; } crate::test_fixture::hybrid::TestHybridRecord::TestConversion { match_key, value } => { - write!(buf, "c,")?; - write!(buf, "{match_key},")?; - write!(buf, "{value}")?; + write!(buf, "c,{match_key},{value}")?; } } From 94715cf0d3ac7fe3026151c9cc8b7b60775ce0dc Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Thu, 19 Sep 2024 14:25:49 -0700 Subject: [PATCH 052/191] remove phantom data --- ipa-core/src/report/hybrid.rs | 98 ++++++++++++++++------------------- ipa-core/src/report/ipa.rs | 2 +- 2 files changed, 47 insertions(+), 53 deletions(-) diff --git a/ipa-core/src/report/hybrid.rs b/ipa-core/src/report/hybrid.rs index a87fa47bb..7c442de38 100644 --- a/ipa-core/src/report/hybrid.rs +++ b/ipa-core/src/report/hybrid.rs @@ -1,7 +1,4 @@ -use std::{ - marker::PhantomData, - ops::{Add, Deref}, -}; +use std::ops::{Add, Deref}; use generic_array::ArrayLength; use typenum::{Sum, U16}; @@ -23,25 +20,22 @@ where } #[derive(Clone, Debug, Eq, PartialEq)] -pub struct HybridConversionReport +pub struct HybridConversionReport where V: SharedValue, - TS: SharedValue, { match_key: Replicated, value: Replicated, - _phantom: PhantomData, } #[derive(Clone, Debug, Eq, PartialEq)] -pub enum HybridReport +pub enum HybridReport where BK: SharedValue, V: SharedValue, - TS: SharedValue, { Impression(HybridImpressionReport), - Conversion(HybridConversionReport), + 
Conversion(HybridConversionReport), } #[allow(dead_code)] @@ -66,35 +60,36 @@ pub enum HybridInfo<'a> { Conversion(HybridConversionInfo<'a>), } -impl HybridReport +impl HybridReport where BK: SharedValue, V: SharedValue, - TS: SharedValue, // this is only needed for the backpart from EncryptedOprfReport - Replicated: Serializable, - Replicated: Serializable, - Replicated: Serializable, - as Serializable>::Size: Add< as Serializable>::Size>, - Sum< as Serializable>::Size, as Serializable>::Size>: - Add< as Serializable>::Size>, - Sum< - Sum< as Serializable>::Size, as Serializable>::Size>, - as Serializable>::Size, - >: Add, - Sum< - Sum< - Sum< as Serializable>::Size, as Serializable>::Size>, - as Serializable>::Size, - >, - U16, - >: ArrayLength, { /// ## Errors /// If the report contents are invalid. - pub fn from_bytes>( - data: B, - key_registry: &P, - ) -> Result { + pub fn from_bytes(data: B, key_registry: &P) -> Result + where + P: PrivateKeyRegistry, + B: Deref, + TS: SharedValue, // this is only needed for the backport from EncryptedOprfReport + Replicated: Serializable, + Replicated: Serializable, + Replicated: Serializable, + as Serializable>::Size: Add< as Serializable>::Size>, + Sum< as Serializable>::Size, as Serializable>::Size>: + Add< as Serializable>::Size>, + Sum< + Sum< as Serializable>::Size, as Serializable>::Size>, + as Serializable>::Size, + >: Add, + Sum< + Sum< + Sum< as Serializable>::Size, as Serializable>::Size>, + as Serializable>::Size, + >, + U16, + >: ArrayLength, + { let encrypted_oprf_report = EncryptedOprfReport::::from_bytes(data)?; let oprf_report = encrypted_oprf_report.decrypt(key_registry)?; match oprf_report.event_type { @@ -105,7 +100,6 @@ where EventType::Trigger => Ok(Self::Conversion(HybridConversionReport { match_key: oprf_report.match_key, value: oprf_report.trigger_value, - _phantom: PhantomData::, })), } } @@ -113,7 +107,6 @@ where #[cfg(test)] mod test { - use std::marker::PhantomData; use rand::{distributions::Alphanumeric, rngs::ThreadRng, thread_rng, Rng}; @@ -148,11 +141,10 @@ mod test { let b = EventType::Source; let oprf_report = build_oprf_report(b, &mut rng); - let hybrid_report = - HybridReport::Impression::(HybridImpressionReport:: { - match_key: oprf_report.match_key.clone(), - breakdown_key: oprf_report.breakdown_key.clone(), - }); + let hybrid_report = HybridReport::Impression::(HybridImpressionReport:: { + match_key: oprf_report.match_key.clone(), + breakdown_key: oprf_report.breakdown_key.clone(), + }); let key_registry = KeyRegistry::::random(1, &mut rng); let key_id = 0; @@ -160,9 +152,11 @@ mod test { let enc_report_bytes = oprf_report .encrypt(key_id, &key_registry, &mut rng) .unwrap(); - let hybrid_report2 = - HybridReport::::from_bytes(enc_report_bytes.as_slice(), &key_registry) - .unwrap(); + let hybrid_report2 = HybridReport::::from_bytes::<_, _, BA20>( + enc_report_bytes.as_slice(), + &key_registry, + ) + .unwrap(); assert_eq!(hybrid_report, hybrid_report2); } @@ -174,12 +168,10 @@ mod test { let b = EventType::Trigger; let oprf_report = build_oprf_report(b, &mut rng); - let hybrid_report = - HybridReport::Conversion::(HybridConversionReport:: { - match_key: oprf_report.match_key.clone(), - value: oprf_report.trigger_value.clone(), - _phantom: PhantomData::, - }); + let hybrid_report = HybridReport::Conversion::(HybridConversionReport:: { + match_key: oprf_report.match_key.clone(), + value: oprf_report.trigger_value.clone(), + }); let key_registry = KeyRegistry::::random(1, &mut rng); let key_id = 0; @@ -187,9 
+179,11 @@ mod test { let enc_report_bytes = oprf_report .encrypt(key_id, &key_registry, &mut rng) .unwrap(); - let hybrid_report2 = - HybridReport::::from_bytes(enc_report_bytes.as_slice(), &key_registry) - .unwrap(); + let hybrid_report2 = HybridReport::::from_bytes::<_, _, BA20>( + enc_report_bytes.as_slice(), + &key_registry, + ) + .unwrap(); assert_eq!(hybrid_report, hybrid_report2); } diff --git a/ipa-core/src/report/ipa.rs b/ipa-core/src/report/ipa.rs index 29ecf78ea..a9da93454 100644 --- a/ipa-core/src/report/ipa.rs +++ b/ipa-core/src/report/ipa.rs @@ -49,7 +49,7 @@ use crate::{ }; // TODO(679): This needs to come from configuration. -pub static HELPER_ORIGIN: &str = "github.com/private-attribution"; +static HELPER_ORIGIN: &str = "github.com/private-attribution"; pub type KeyIdentifier = u8; pub const DEFAULT_KEY_ID: KeyIdentifier = 0; From 3538defe030240c13796ce5c1e8a38ada0950eb4 Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Thu, 19 Sep 2024 14:26:52 -0700 Subject: [PATCH 053/191] add detail to test name --- ipa-core/src/report/hybrid.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ipa-core/src/report/hybrid.rs b/ipa-core/src/report/hybrid.rs index 7c442de38..c5e4703a1 100644 --- a/ipa-core/src/report/hybrid.rs +++ b/ipa-core/src/report/hybrid.rs @@ -162,7 +162,7 @@ mod test { } #[test] - fn convert_to_hybrid_report() { + fn convert_to_hybrid_conversion_report() { let mut rng = thread_rng(); let b = EventType::Trigger; From 628b16a831cff1537da8ebb3d1884e3a792741d8 Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Thu, 19 Sep 2024 15:31:13 -0700 Subject: [PATCH 054/191] In the clear script to generate expected result from input file --- ipa-core/Cargo.toml | 5 +++ ipa-core/src/bin/in_the_clear.rs | 57 ++++++++++++++++++++++++++++++ ipa-core/src/cli/playbook/input.rs | 32 ++++++++++++++++- 3 files changed, 93 insertions(+), 1 deletion(-) create mode 100644 ipa-core/src/bin/in_the_clear.rs diff --git a/ipa-core/Cargo.toml b/ipa-core/Cargo.toml index 819eecdb7..144398799 100644 --- a/ipa-core/Cargo.toml +++ b/ipa-core/Cargo.toml @@ -197,6 +197,11 @@ name = "crypto_util" required-features = ["cli", "test-fixture", "web-app"] bench = false +[[bin]] +name = "in_the_clear" +required-features = ["cli", "test-fixture"] +bench = false + [[bench]] name = "criterion_arithmetic" path = "benches/ct/arithmetic_circuit.rs" diff --git a/ipa-core/src/bin/in_the_clear.rs b/ipa-core/src/bin/in_the_clear.rs new file mode 100644 index 000000000..a58d88fb2 --- /dev/null +++ b/ipa-core/src/bin/in_the_clear.rs @@ -0,0 +1,57 @@ +use std::{ + error::Error, + path::{Path, PathBuf}, +}; + +use clap::Parser; +use ipa_core::{cli::playbook::InputSource, test_fixture::hybrid::{hybrid_in_the_clear, TestHybridRecord}}; + +#[derive(Debug, Parser)] +pub struct CommandInput { + #[arg( + long, + help = "Read the input from the provided file, instead of standard input" + )] + input_file: Option, +} + +impl From<&CommandInput> for InputSource { + fn from(source: &CommandInput) -> Self { + if let Some(ref file_name) = source.input_file { + InputSource::from_file(file_name) + } else { + InputSource::from_stdin() + } + } +} + +#[derive(Debug, Parser)] +#[clap(name = "rc", about = "Report Collector CLI")] +#[command(about)] +struct Args { + #[clap(flatten)] + input: CommandInput, + + /// The destination file for output. 
+ #[arg(long, value_name = "OUTPUT_FILE")] + output_file: PathBuf, +} + +fn main() -> Result<(), Box> { + let args = Args::parse(); + + let input = InputSource::from(&args.input); + + let input_rows = input.iter::().collect::>(); + let expected = hybrid_in_the_clear(&input_rows, 10); + + let mut file = File::options() + .write(true) + .create_new(true) + .open(args.output_file) + .map_err(|e| format!("Failed to create output file {}: {e}", args.output_file.display()))?; + + write!(file, "{}", serde_json::to_string_pretty(&expected)?)?; + + Ok(()) +} diff --git a/ipa-core/src/cli/playbook/input.rs b/ipa-core/src/cli/playbook/input.rs index b7d31484d..2d9011d69 100644 --- a/ipa-core/src/cli/playbook/input.rs +++ b/ipa-core/src/cli/playbook/input.rs @@ -6,9 +6,16 @@ use std::{ path::PathBuf, }; +use ipa_core::{ + cli::playbook::InputSource, +}; + use crate::{ cli::playbook::generator::U128Generator, ff::U128Conversions, - test_fixture::ipa::TestRawDataRecord, + test_fixture::{ + ipa::TestRawDataRecord, + hybrid::{hybrid_in_the_clear, TestHybridRecord}, + } }; pub trait InputItem { @@ -56,6 +63,29 @@ impl InputItem for TestRawDataRecord { } } +impl InputItem for TestHybridRecord { + fn from_str(s: &str) -> Self { + if let [event_type, match_key, number] = s.splitn(3, ',').collect::>()[..] { + match event_type { + 'i' => TestHybridRecord::TestImpression { + match_key, + breakdown_key: number, + }, + 'c' => TestHybridRecord::TestImpression { + match_key, + value: number, + }, + _ => panic!( + "Invalid input. Rows should start with 'i' or 'c'. Did not expect {:?}", + event_type + ), + } + } else { + panic!("{s} is not a valid {}", type_name::()) + } + } +} + pub struct InputSource { inner: Box, sz: Option, From bc78a1e6e63eee3e25e1b5fabcefd14d54ac9bca Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Thu, 19 Sep 2024 16:25:54 -0700 Subject: [PATCH 055/191] Revert "In the clear script to generate expected result from input file" This reverts commit 628b16a831cff1537da8ebb3d1884e3a792741d8. 
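
For the record, the reverted change could not build as rendered here: `in_the_clear.rs` called `File::options()` and `write!` without importing `std::fs::File` or `std::io::Write`, and the `InputItem` impl for `TestHybridRecord` matched a `&str` column against `char` patterns ('i'/'c') and constructed the conversion arm as a `TestImpression` with a `value` field. Below is a minimal compiling sketch of what the binary was aiming for, assuming `hybrid_in_the_clear(&[TestHybridRecord], usize)` returns a serde-serializable value and that a working `InputItem` impl for `TestHybridRecord` exists; treat it as an illustration, not the committed code.

    use std::{error::Error, fs::File, io::Write, path::PathBuf};

    use clap::Parser;
    use ipa_core::{
        cli::playbook::InputSource,
        test_fixture::hybrid::{hybrid_in_the_clear, TestHybridRecord},
    };

    #[derive(Debug, Parser)]
    #[clap(name = "in_the_clear", about = "Hybrid in-the-clear expected results")]
    struct Args {
        /// Read the input from the provided file, instead of standard input
        #[arg(long)]
        input_file: Option<PathBuf>,

        /// The destination file for output.
        #[arg(long, value_name = "OUTPUT_FILE")]
        output_file: PathBuf,
    }

    fn main() -> Result<(), Box<dyn Error>> {
        let args = Args::parse();
        let input = match &args.input_file {
            Some(path) => InputSource::from_file(path),
            None => InputSource::from_stdin(),
        };

        // Replay the hybrid protocol in the clear; 10 is the max breakdown
        // key the reverted code passed.
        let input_rows = input.iter::<TestHybridRecord>().collect::<Vec<_>>();
        let expected = hybrid_in_the_clear(&input_rows, 10);

        let mut file = File::options()
            .write(true)
            .create_new(true)
            .open(&args.output_file)
            .map_err(|e| {
                format!(
                    "Failed to create output file {}: {e}",
                    args.output_file.display()
                )
            })?;

        write!(file, "{}", serde_json::to_string_pretty(&expected)?)?;

        Ok(())
    }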
--- ipa-core/Cargo.toml | 5 --- ipa-core/src/bin/in_the_clear.rs | 57 ------------------------------ ipa-core/src/cli/playbook/input.rs | 32 +---------------- 3 files changed, 1 insertion(+), 93 deletions(-) delete mode 100644 ipa-core/src/bin/in_the_clear.rs diff --git a/ipa-core/Cargo.toml b/ipa-core/Cargo.toml index 144398799..819eecdb7 100644 --- a/ipa-core/Cargo.toml +++ b/ipa-core/Cargo.toml @@ -197,11 +197,6 @@ name = "crypto_util" required-features = ["cli", "test-fixture", "web-app"] bench = false -[[bin]] -name = "in_the_clear" -required-features = ["cli", "test-fixture"] -bench = false - [[bench]] name = "criterion_arithmetic" path = "benches/ct/arithmetic_circuit.rs" diff --git a/ipa-core/src/bin/in_the_clear.rs b/ipa-core/src/bin/in_the_clear.rs deleted file mode 100644 index a58d88fb2..000000000 --- a/ipa-core/src/bin/in_the_clear.rs +++ /dev/null @@ -1,57 +0,0 @@ -use std::{ - error::Error, - path::{Path, PathBuf}, -}; - -use clap::Parser; -use ipa_core::{cli::playbook::InputSource, test_fixture::hybrid::{hybrid_in_the_clear, TestHybridRecord}}; - -#[derive(Debug, Parser)] -pub struct CommandInput { - #[arg( - long, - help = "Read the input from the provided file, instead of standard input" - )] - input_file: Option, -} - -impl From<&CommandInput> for InputSource { - fn from(source: &CommandInput) -> Self { - if let Some(ref file_name) = source.input_file { - InputSource::from_file(file_name) - } else { - InputSource::from_stdin() - } - } -} - -#[derive(Debug, Parser)] -#[clap(name = "rc", about = "Report Collector CLI")] -#[command(about)] -struct Args { - #[clap(flatten)] - input: CommandInput, - - /// The destination file for output. - #[arg(long, value_name = "OUTPUT_FILE")] - output_file: PathBuf, -} - -fn main() -> Result<(), Box> { - let args = Args::parse(); - - let input = InputSource::from(&args.input); - - let input_rows = input.iter::().collect::>(); - let expected = hybrid_in_the_clear(&input_rows, 10); - - let mut file = File::options() - .write(true) - .create_new(true) - .open(args.output_file) - .map_err(|e| format!("Failed to create output file {}: {e}", args.output_file.display()))?; - - write!(file, "{}", serde_json::to_string_pretty(&expected)?)?; - - Ok(()) -} diff --git a/ipa-core/src/cli/playbook/input.rs b/ipa-core/src/cli/playbook/input.rs index 2d9011d69..b7d31484d 100644 --- a/ipa-core/src/cli/playbook/input.rs +++ b/ipa-core/src/cli/playbook/input.rs @@ -6,16 +6,9 @@ use std::{ path::PathBuf, }; -use ipa_core::{ - cli::playbook::InputSource, -}; - use crate::{ cli::playbook::generator::U128Generator, ff::U128Conversions, - test_fixture::{ - ipa::TestRawDataRecord, - hybrid::{hybrid_in_the_clear, TestHybridRecord}, - } + test_fixture::ipa::TestRawDataRecord, }; pub trait InputItem { @@ -63,29 +56,6 @@ impl InputItem for TestRawDataRecord { } } -impl InputItem for TestHybridRecord { - fn from_str(s: &str) -> Self { - if let [event_type, match_key, number] = s.splitn(3, ',').collect::>()[..] { - match event_type { - 'i' => TestHybridRecord::TestImpression { - match_key, - breakdown_key: number, - }, - 'c' => TestHybridRecord::TestImpression { - match_key, - value: number, - }, - _ => panic!( - "Invalid input. Rows should start with 'i' or 'c'. 
Did not expect {:?}", - event_type - ), - } - } else { - panic!("{s} is not a valid {}", type_name::()) - } - } -} - pub struct InputSource { inner: Box, sz: Option, From e9ec812a44713c4722dea75ff03a19240007ba90 Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Thu, 19 Sep 2024 17:18:03 -0700 Subject: [PATCH 056/191] Disable jemalloc --- ipa-core/Cargo.toml | 2 +- ipa-core/src/bin/helper.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ipa-core/Cargo.toml b/ipa-core/Cargo.toml index 819eecdb7..3e816b86c 100644 --- a/ipa-core/Cargo.toml +++ b/ipa-core/Cargo.toml @@ -148,7 +148,7 @@ typenum = { version = "1.17", features = ["i128"] } # hpke is pinned to it x25519-dalek = "2.0.0-rc.3" -[target.'cfg(not(target_env = "msvc"))'.dependencies] +[target.'cfg(all(not(target_env = "msvc"), not(target_os = "macos")))'.dependencies] tikv-jemallocator = "0.5.0" [build-dependencies] diff --git a/ipa-core/src/bin/helper.rs b/ipa-core/src/bin/helper.rs index 02f1b2101..790245587 100644 --- a/ipa-core/src/bin/helper.rs +++ b/ipa-core/src/bin/helper.rs @@ -22,7 +22,7 @@ use ipa_core::{ }; use tracing::{error, info}; -#[cfg(not(target_env = "msvc"))] +#[cfg(all(not(target_env = "msvc"), not(target_os = "macos")))] #[global_allocator] static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; From 05cfa128f85736fdc98a39670adcea932a47b0a5 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Thu, 19 Sep 2024 18:40:25 -0700 Subject: [PATCH 057/191] Check for large increases in step count --- ipa-core/src/protocol/ipa_prf/mod.rs | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/ipa-core/src/protocol/ipa_prf/mod.rs b/ipa-core/src/protocol/ipa_prf/mod.rs index a14da50fe..90e00cb64 100644 --- a/ipa-core/src/protocol/ipa_prf/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/mod.rs @@ -724,8 +724,7 @@ pub mod tests { #[cfg(all(test, all(feature = "compact-gate", feature = "in-memory-infra")))] mod compact_gate_tests { - - use ipa_step::StepNarrow; + use ipa_step::{CompactStep, StepNarrow}; use crate::{ ff::{ @@ -741,6 +740,18 @@ mod compact_gate_tests { test_fixture::{ipa::TestRawDataRecord, Reconstruct, Runner, TestWorld, TestWorldConfig}, }; + #[test] + fn step_count_limit() { + // This is an arbitrary limit intended to catch changes that unintentionally + // blow up the step count. It can be increased, within reason. + const STEP_COUNT_LIMIT: u32 = 200_000; + assert!( + ProtocolStep::STEP_COUNT < STEP_COUNT_LIMIT, + "Step count of {actual} exceeds limit of {STEP_COUNT_LIMIT}.", + actual = ProtocolStep::STEP_COUNT, + ); + } + #[test] fn saturated_agg() { const EXPECTED: &[u128] = &[0, 255, 255, 0, 0, 0, 0, 0]; From 63bb05e1d365dcd9f9695ff28d7587764d50011f Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Thu, 19 Sep 2024 23:47:45 -0700 Subject: [PATCH 058/191] Move decrypt/encrypt functions to their own modules We will add secret sharing, so crypto module will grow in size. It makes sense to keep things in their own modules now. It also does a bit of a refactoring to prepare to encrypt something that is not `IpaPrivateData`, but no new functionality is added by this change. 
--- ipa-core/src/bin/crypto_util.rs | 6 +- ipa-core/src/cli/crypto.rs | 639 ----------------------------- ipa-core/src/cli/crypto/decrypt.rs | 278 +++++++++++++ ipa-core/src/cli/crypto/encrypt.rs | 269 ++++++++++++ ipa-core/src/cli/crypto/mod.rs | 233 +++++++++++ 5 files changed, 783 insertions(+), 642 deletions(-) delete mode 100644 ipa-core/src/cli/crypto.rs create mode 100644 ipa-core/src/cli/crypto/decrypt.rs create mode 100644 ipa-core/src/cli/crypto/encrypt.rs create mode 100644 ipa-core/src/cli/crypto/mod.rs diff --git a/ipa-core/src/bin/crypto_util.rs b/ipa-core/src/bin/crypto_util.rs index 4ded7026c..99556089f 100644 --- a/ipa-core/src/bin/crypto_util.rs +++ b/ipa-core/src/bin/crypto_util.rs @@ -2,7 +2,7 @@ use std::fmt::Debug; use clap::{Parser, Subcommand}; use ipa_core::{ - cli::crypto::{decrypt_and_reconstruct, encrypt, DecryptArgs, EncryptArgs}, + cli::crypto::{DecryptArgs, EncryptArgs}, error::BoxError, }; @@ -24,8 +24,8 @@ enum CryptoUtilCommand { async fn main() -> Result<(), BoxError> { let args = Args::parse(); match args.action { - CryptoUtilCommand::Encrypt(encrypt_args) => encrypt(&encrypt_args)?, - CryptoUtilCommand::Decrypt(decrypt_args) => decrypt_and_reconstruct(decrypt_args).await?, + CryptoUtilCommand::Encrypt(encrypt_args) => encrypt_args.encrypt()?, + CryptoUtilCommand::Decrypt(decrypt_args) => decrypt_args.decrypt_and_reconstruct().await?, } Ok(()) } diff --git a/ipa-core/src/cli/crypto.rs b/ipa-core/src/cli/crypto.rs deleted file mode 100644 index 9c8fe7a4b..000000000 --- a/ipa-core/src/cli/crypto.rs +++ /dev/null @@ -1,639 +0,0 @@ -use std::{ - fmt::Debug, - fs::{read_to_string, File, OpenOptions}, - io::{BufRead, BufReader, Write}, - iter::zip, - path::PathBuf, -}; - -use clap::Parser; -use rand::thread_rng; - -use crate::{ - cli::playbook::{BreakdownKey, InputSource, Timestamp, TriggerValue}, - config::{hpke_registry, HpkeServerConfig, KeyRegistries, NetworkConfig}, - error::BoxError, - ff::{ - boolean_array::{BA20, BA3, BA8}, - U128Conversions, - }, - hpke::{KeyRegistry, PrivateKeyOnly}, - report::{EncryptedOprfReport, EventType, OprfReport, DEFAULT_KEY_ID}, - secret_sharing::IntoShares, - test_fixture::{ipa::TestRawDataRecord, Reconstruct}, -}; - -#[derive(Debug, Parser)] -#[clap(name = "test_encrypt", about = "Test Encrypt")] -#[command(about)] -pub struct EncryptArgs { - /// Path to file to secret share and encrypt - #[arg(long)] - input_file: PathBuf, - // /// The destination dir for encrypted output. - // /// In that dir, it will create helper1.enc, - // /// helper2.enc, and helper3.enc - #[arg(long, value_name = "FILE")] - output_dir: PathBuf, - /// Path to helper network configuration file - #[arg(long)] - network: PathBuf, -} - -#[derive(Debug, Parser)] -#[clap(name = "test_decrypt", about = "Test Decrypt")] -#[command(about)] -pub struct DecryptArgs { - /// Path to helper1 file to decrypt - #[arg(long)] - input_file1: PathBuf, - - /// Helper1 Private key for decrypting match keys - #[arg(long)] - mk_private_key1: PathBuf, - - /// Path to helper2 file to decrypt - #[arg(long)] - input_file2: PathBuf, - - /// Helper2 Private key for decrypting match keys - #[arg(long)] - mk_private_key2: PathBuf, - - /// Path to helper3 file to decrypt - #[arg(long)] - input_file3: PathBuf, - - /// Helper3 Private key for decrypting match keys - #[arg(long)] - mk_private_key3: PathBuf, - - /// The destination file for decrypted output. 
- #[arg(long, value_name = "FILE")] - output_file: PathBuf, -} - -/// # Panics -/// if input file or network file are not correctly formatted -/// # Errors -/// if it cannot open the files -pub fn encrypt(args: &EncryptArgs) -> Result<(), BoxError> { - let input = InputSource::from_file(&args.input_file); - - let mut rng = thread_rng(); - let mut key_registries = KeyRegistries::default(); - - let network = NetworkConfig::from_toml_str( - &read_to_string(&args.network) - .unwrap_or_else(|e| panic!("Failed to open network file: {:?}. {}", &args.network, e)), - ) - .unwrap_or_else(|e| { - panic!( - "Failed to parse network file into toml: {:?}. {}", - &args.network, e - ) - }); - let Some(key_registries) = key_registries.init_from(&network) else { - panic!("could not load network file") - }; - - let shares: [Vec>; 3] = - input.iter::().share(); - - for (index, (shares, key_registry)) in zip(shares, key_registries).enumerate() { - let output_filename = format!("helper{}.enc", index + 1); - let mut writer = OpenOptions::new() - .write(true) - .create_new(true) - .open(args.output_dir.join(&output_filename)) - .unwrap_or_else(|e| panic!("unable write to {}. {}", &output_filename, e)); - - for share in shares { - let output = share - .encrypt(DEFAULT_KEY_ID, key_registry, &mut rng) - .unwrap(); - let hex_output = hex::encode(&output); - writeln!(writer, "{hex_output}")?; - } - } - - Ok(()) -} - -async fn build_hpke_registry( - private_key_file: PathBuf, -) -> Result, BoxError> { - let mk_encryption = Some(HpkeServerConfig::File { private_key_file }); - let key_registry = hpke_registry(mk_encryption.as_ref()).await?; - Ok(key_registry) -} - -struct DecryptedReports { - reader: BufReader, - key_registry: KeyRegistry, -} - -impl DecryptedReports { - fn new(filename: &PathBuf, key_registry: KeyRegistry) -> Self { - let file = File::open(filename) - .unwrap_or_else(|e| panic!("unable to open file {filename:?}. 
{e}")); - let reader = BufReader::new(file); - Self { - reader, - key_registry, - } - } -} - -impl Iterator for DecryptedReports { - type Item = OprfReport; - - fn next(&mut self) -> Option { - let mut line = String::new(); - if self.reader.read_line(&mut line).unwrap() > 0 { - let encrypted_report_bytes = hex::decode(line.trim()).unwrap(); - let enc_report = - EncryptedOprfReport::from_bytes(encrypted_report_bytes.as_slice()).unwrap(); - let dec_report: OprfReport = - enc_report.decrypt(&self.key_registry).unwrap(); - Some(dec_report) - } else { - None - } - } -} - -/// # Panics -// if input files or private_keys are not correctly formatted -/// # Errors -/// if it cannot open the files -pub async fn decrypt_and_reconstruct(args: DecryptArgs) -> Result<(), BoxError> { - let key_registry1 = build_hpke_registry(args.mk_private_key1).await?; - let key_registry2 = build_hpke_registry(args.mk_private_key2).await?; - let key_registry3 = build_hpke_registry(args.mk_private_key3).await?; - let decrypted_reports1 = DecryptedReports::new(&args.input_file1, key_registry1); - let decrypted_reports2 = DecryptedReports::new(&args.input_file2, key_registry2); - let decrypted_reports3 = DecryptedReports::new(&args.input_file3, key_registry3); - - let mut writer = Box::new( - OpenOptions::new() - .write(true) - .create_new(true) - .open(args.output_file)?, - ); - - for (dec_report1, (dec_report2, dec_report3)) in - decrypted_reports1.zip(decrypted_reports2.zip(decrypted_reports3)) - { - let timestamp = [ - dec_report1.timestamp, - dec_report2.timestamp, - dec_report3.timestamp, - ] - .reconstruct() - .as_u128(); - - let match_key = [ - dec_report1.match_key, - dec_report2.match_key, - dec_report3.match_key, - ] - .reconstruct() - .as_u128(); - - // these aren't reconstucted, so we explictly make sure - // they are consistent across all three files, then set - // it to the first one (without loss of generality) - assert_eq!(dec_report1.event_type, dec_report2.event_type); - assert_eq!(dec_report2.event_type, dec_report3.event_type); - let is_trigger_report = dec_report1.event_type == EventType::Trigger; - - let breakdown_key = [ - dec_report1.breakdown_key, - dec_report2.breakdown_key, - dec_report3.breakdown_key, - ] - .reconstruct() - .as_u128(); - - let trigger_value = [ - dec_report1.trigger_value, - dec_report2.trigger_value, - dec_report3.trigger_value, - ] - .reconstruct() - .as_u128(); - - writeln!( - writer, - "{},{},{},{},{}", - timestamp, - match_key, - u8::from(is_trigger_report), - breakdown_key, - trigger_value, - )?; - } - - Ok(()) -} - -#[cfg(all(test, feature = "in-memory-infra"))] -mod tests { - use std::{ - fs::File, - io::{BufRead, BufReader, Write}, - path::Path, - sync::Arc, - }; - - use clap::Parser; - use hpke::Deserializable; - use rand::thread_rng; - use tempfile::{tempdir, NamedTempFile}; - - use crate::{ - cli::{ - crypto::{decrypt_and_reconstruct, encrypt, DecryptArgs, EncryptArgs}, - CsvSerializer, - }, - ff::{boolean_array::BA16, U128Conversions}, - helpers::query::{IpaQueryConfig, QuerySize}, - hpke::{IpaPrivateKey, KeyRegistry, PrivateKeyOnly}, - query::OprfIpaQuery, - report::EncryptedOprfReportStreams, - test_fixture::{ - ipa::TestRawDataRecord, join3v, EventGenerator, EventGeneratorConfig, Reconstruct, - TestWorld, - }, - }; - - fn are_files_equal(file1: &Path, file2: &Path) { - let file1 = - File::open(file1).unwrap_or_else(|e| panic!("unable to open {}: {e}", file1.display())); - let file2 = - File::open(file2).unwrap_or_else(|e| panic!("unable to open {}: {e}", 
file2.display())); - let reader1 = BufReader::new(file1).lines(); - let mut reader2 = BufReader::new(file2).lines(); - for line1 in reader1 { - let line2 = reader2.next().expect("Files have different lengths"); - assert_eq!(line1.unwrap(), line2.unwrap()); - } - assert!(reader2.next().is_none(), "Files have different lengths"); - } - - fn write_input_file() -> NamedTempFile { - let count = 10; - let rng = thread_rng(); - let event_gen_args = EventGeneratorConfig::new(10, 5, 20, 1, 10, 604_800); - - let event_gen = EventGenerator::with_config(rng, event_gen_args) - .take(count) - .collect::>(); - let mut input = NamedTempFile::new().unwrap(); - - for event in event_gen { - let _ = event.to_csv(input.as_file_mut()); - writeln!(input.as_file()).unwrap(); - } - input.as_file_mut().flush().unwrap(); - input - } - - fn write_network_file() -> NamedTempFile { - let network_data = r#" -[[peers]] -url = "helper1.test" -[peers.hpke] -public_key = "92a6fb666c37c008defd74abf3204ebea685742eab8347b08e2f7c759893947a" -[[peers]] -url = "helper2.test" -[peers.hpke] -public_key = "cfdbaaff16b30aa8a4ab07eaad2cdd80458208a1317aefbb807e46dce596617e" -[[peers]] -url = "helper3.test" -[peers.hpke] -public_key = "b900be35da06106a83ed73c33f733e03e4ea5888b7ea4c912ab270b0b0f8381e" -"#; - let mut network = NamedTempFile::new().unwrap(); - writeln!(network.as_file_mut(), "{network_data}").unwrap(); - network - } - - fn write_mk_private_key(mk_private_key_data: &str) -> NamedTempFile { - let mut mk_private_key = NamedTempFile::new().unwrap(); - writeln!(mk_private_key.as_file_mut(), "{mk_private_key_data}").unwrap(); - mk_private_key - } - - fn build_encrypt_args( - input_file: &Path, - output_dir: &Path, - network_file: &Path, - ) -> EncryptArgs { - EncryptArgs::try_parse_from([ - "test_encrypt", - "--input-file", - input_file.to_str().unwrap(), - "--output-dir", - output_dir.to_str().unwrap(), - "--network", - network_file.to_str().unwrap(), - ]) - .unwrap() - } - - fn build_decrypt_args( - enc1: &Path, - enc2: &Path, - enc3: &Path, - mk_private_key1: &Path, - mk_private_key2: &Path, - mk_private_key3: &Path, - decrypt_output: &Path, - ) -> DecryptArgs { - DecryptArgs::try_parse_from([ - "test_decrypt", - "--input-file1", - enc1.to_str().unwrap(), - "--input-file2", - enc2.to_str().unwrap(), - "--input-file3", - enc3.to_str().unwrap(), - "--mk-private-key1", - mk_private_key1.to_str().unwrap(), - "--mk-private-key2", - mk_private_key2.to_str().unwrap(), - "--mk-private-key3", - mk_private_key3.to_str().unwrap(), - "--output-file", - decrypt_output.to_str().unwrap(), - ]) - .unwrap() - } - - #[test] - #[should_panic = "Failed to open network file:"] - fn encrypt_no_network_file() { - let input_file = write_input_file(); - let output_dir = tempdir().unwrap(); - let network_dir = tempdir().unwrap(); - let network_file = network_dir.path().join("does_not_exist"); - let encrypt_args = - build_encrypt_args(input_file.path(), output_dir.path(), network_file.as_path()); - let _ = encrypt(&encrypt_args); - } - - #[test] - #[should_panic = "TOML parse error at"] - fn encrypt_bad_network_file() { - let input_file = write_input_file(); - let output_dir = tempdir().unwrap(); - let network_data = r" -this is not toml! -%^& weird characters -(\deadbeef>? 
-"; - let mut network_file = NamedTempFile::new().unwrap(); - writeln!(network_file.as_file_mut(), "{network_data}").unwrap(); - - let encrypt_args = - build_encrypt_args(input_file.path(), output_dir.path(), network_file.path()); - let _ = encrypt(&encrypt_args); - } - - #[test] - #[should_panic = "invalid length 2, expected an array of length 3"] - fn encrypt_incomplete_network_file() { - let input_file = write_input_file(); - let output_dir = tempdir().unwrap(); - let network_data = r#" -[[peers]] -url = "helper1.test" -[peers.hpke] -public_key = "92a6fb666c37c008defd74abf3204ebea685742eab8347b08e2f7c759893947a" -[[peers]] -url = "helper2.test" -[peers.hpke] -public_key = "cfdbaaff16b30aa8a4ab07eaad2cdd80458208a1317aefbb807e46dce596617e" -"#; - let mut network_file = NamedTempFile::new().unwrap(); - writeln!(network_file.as_file_mut(), "{network_data}").unwrap(); - - let encrypt_args = - build_encrypt_args(input_file.path(), output_dir.path(), network_file.path()); - let _ = encrypt(&encrypt_args); - } - - #[tokio::test] - #[should_panic = "No such file or directory (os error 2)"] - async fn decrypt_no_enc_file() { - let input_file = write_input_file(); - let output_dir = tempdir().unwrap(); - let network_file = write_network_file(); - let encrypt_args = - build_encrypt_args(input_file.path(), output_dir.path(), network_file.path()); - let _ = encrypt(&encrypt_args); - - let decrypt_output = output_dir.path().join("output"); - let enc1 = output_dir.path().join("DOES_NOT_EXIST.enc"); - let enc2 = output_dir.path().join("helper2.enc"); - let enc3 = output_dir.path().join("helper3.enc"); - - let mk_private_key1 = write_mk_private_key( - "53d58e022981f2edbf55fec1b45dbabd08a3442cb7b7c598839de5d7a5888bff", - ); - let mk_private_key2 = write_mk_private_key( - "3a0a993a3cfc7e8d381addac586f37de50c2a14b1a6356d71e94ca2afaeb2569", - ); - let mk_private_key3 = write_mk_private_key( - "1fb5c5274bf85fbe6c7935684ef05499f6cfb89ac21640c28330135cc0e8a0f7", - ); - - let decrypt_args = build_decrypt_args( - enc1.as_path(), - enc2.as_path(), - enc3.as_path(), - mk_private_key1.path(), - mk_private_key2.path(), - mk_private_key3.path(), - &decrypt_output, - ); - let _ = decrypt_and_reconstruct(decrypt_args).await; - } - - #[tokio::test] - #[should_panic = "called `Result::unwrap()` on an `Err` value: Crypt(Other)"] - async fn decrypt_bad_private_key() { - let input_file = write_input_file(); - let output_dir = tempdir().unwrap(); - let network_file = write_network_file(); - let encrypt_args = - build_encrypt_args(input_file.path(), output_dir.path(), network_file.path()); - let _ = encrypt(&encrypt_args); - - let decrypt_output = output_dir.path().join("output"); - let enc1 = output_dir.path().join("helper1.enc"); - let enc2 = output_dir.path().join("helper2.enc"); - let enc3 = output_dir.path().join("helper3.enc"); - let mk_private_key1 = write_mk_private_key( - "bad9fdc79d98471cedd07ee6743d3bb43aabbddabc49cd9fae1d5daef3f2b3ba", - ); - let mk_private_key2 = write_mk_private_key( - "3a0a993a3cfc7e8d381addac586f37de50c2a14b1a6356d71e94ca2afaeb2569", - ); - let mk_private_key3 = write_mk_private_key( - "1fb5c5274bf85fbe6c7935684ef05499f6cfb89ac21640c28330135cc0e8a0f7", - ); - - let decrypt_args = build_decrypt_args( - enc1.as_path(), - enc2.as_path(), - enc3.as_path(), - mk_private_key1.path(), - mk_private_key2.path(), - mk_private_key3.path(), - &decrypt_output, - ); - let _ = decrypt_and_reconstruct(decrypt_args).await; - } - - #[tokio::test] - async fn encrypt_and_decrypt() { - let input_file = 
write_input_file(); - let output_dir = tempdir().unwrap(); - let network_file = write_network_file(); - let encrypt_args = - build_encrypt_args(input_file.path(), output_dir.path(), network_file.path()); - let _ = encrypt(&encrypt_args); - - let decrypt_output = output_dir.path().join("output"); - let enc1 = output_dir.path().join("helper1.enc"); - let enc2 = output_dir.path().join("helper2.enc"); - let enc3 = output_dir.path().join("helper3.enc"); - let mk_private_key1 = write_mk_private_key( - "53d58e022981f2edbf55fec1b45dbabd08a3442cb7b7c598839de5d7a5888bff", - ); - let mk_private_key2 = write_mk_private_key( - "3a0a993a3cfc7e8d381addac586f37de50c2a14b1a6356d71e94ca2afaeb2569", - ); - let mk_private_key3 = write_mk_private_key( - "1fb5c5274bf85fbe6c7935684ef05499f6cfb89ac21640c28330135cc0e8a0f7", - ); - - let decrypt_args = build_decrypt_args( - enc1.as_path(), - enc2.as_path(), - enc3.as_path(), - mk_private_key1.path(), - mk_private_key2.path(), - mk_private_key3.path(), - &decrypt_output, - ); - let _ = decrypt_and_reconstruct(decrypt_args).await; - - are_files_equal(input_file.path(), &decrypt_output); - } - - #[tokio::test] - async fn encrypt_and_execute_query() { - const EXPECTED: &[u128] = &[0, 2, 5]; - - let records = vec![ - TestRawDataRecord { - timestamp: 0, - user_id: 12345, - is_trigger_report: false, - breakdown_key: 2, - trigger_value: 0, - }, - TestRawDataRecord { - timestamp: 4, - user_id: 68362, - is_trigger_report: false, - breakdown_key: 1, - trigger_value: 0, - }, - TestRawDataRecord { - timestamp: 10, - user_id: 12345, - is_trigger_report: true, - breakdown_key: 0, - trigger_value: 5, - }, - TestRawDataRecord { - timestamp: 12, - user_id: 68362, - is_trigger_report: true, - breakdown_key: 0, - trigger_value: 2, - }, - ]; - let query_size = QuerySize::try_from(records.len()).unwrap(); - let mut input_file = NamedTempFile::new().unwrap(); - - for event in records { - event.to_csv(input_file.as_file_mut()).unwrap(); - writeln!(input_file.as_file()).unwrap(); - } - input_file.flush().unwrap(); - - let output_dir = tempdir().unwrap(); - let network_file = write_network_file(); - encrypt(&build_encrypt_args( - input_file.path(), - output_dir.path(), - network_file.path(), - )) - .unwrap(); - - let files = [ - &output_dir.path().join("helper1.enc"), - &output_dir.path().join("helper2.enc"), - &output_dir.path().join("helper3.enc"), - ]; - - let world = TestWorld::default(); - - let mk_private_keys = [ - "53d58e022981f2edbf55fec1b45dbabd08a3442cb7b7c598839de5d7a5888bff", - "3a0a993a3cfc7e8d381addac586f37de50c2a14b1a6356d71e94ca2afaeb2569", - "1fb5c5274bf85fbe6c7935684ef05499f6cfb89ac21640c28330135cc0e8a0f7", - ]; - - #[allow(clippy::large_futures)] - let results = join3v( - EncryptedOprfReportStreams::from(files) - .streams - .into_iter() - .zip(world.contexts()) - .zip(mk_private_keys.into_iter()) - .map(|((input, ctx), mk_private_key)| { - let mk_private_key = hex::decode(mk_private_key) - .map(|bytes| IpaPrivateKey::from_bytes(&bytes).unwrap()) - .unwrap(); - let query_config = IpaQueryConfig { - max_breakdown_key: 3, - with_dp: 0, - epsilon: 1.0, - ..Default::default() - }; - - OprfIpaQuery::<_, BA16, _>::new( - query_config, - Arc::new(KeyRegistry::from_keys([PrivateKeyOnly(mk_private_key)])), - ) - .execute(ctx, query_size, input) - }), - ) - .await; - - assert_eq!( - results.reconstruct()[0..3] - .iter() - .map(U128Conversions::as_u128) - .collect::>(), - EXPECTED - ); - } -} diff --git a/ipa-core/src/cli/crypto/decrypt.rs b/ipa-core/src/cli/crypto/decrypt.rs new 
file mode 100644 index 000000000..63ec8e3f9 --- /dev/null +++ b/ipa-core/src/cli/crypto/decrypt.rs @@ -0,0 +1,278 @@ +use std::{ + fs::{File, OpenOptions}, + io::{BufRead, BufReader, Write}, + path::{Path, PathBuf}, +}; + +use clap::Parser; + +use crate::{ + config::{hpke_registry, HpkeServerConfig}, + error::BoxError, + ff::{ + boolean_array::{BA20, BA3, BA8}, + U128Conversions, + }, + hpke::{KeyRegistry, PrivateKeyOnly}, + report::{EncryptedOprfReport, EventType, OprfReport}, + test_fixture::Reconstruct, +}; + +#[derive(Debug, Parser)] +#[clap(name = "test_decrypt", about = "Test Decrypt")] +#[command(about)] +pub struct DecryptArgs { + /// Path to helper1 file to decrypt + #[arg(long)] + input_file1: PathBuf, + + /// Helper1 Private key for decrypting match keys + #[arg(long)] + mk_private_key1: PathBuf, + + /// Path to helper2 file to decrypt + #[arg(long)] + input_file2: PathBuf, + + /// Helper2 Private key for decrypting match keys + #[arg(long)] + mk_private_key2: PathBuf, + + /// Path to helper3 file to decrypt + #[arg(long)] + input_file3: PathBuf, + + /// Helper3 Private key for decrypting match keys + #[arg(long)] + mk_private_key3: PathBuf, + + /// The destination file for decrypted output. + #[arg(long, value_name = "FILE")] + output_file: PathBuf, +} + +impl DecryptArgs { + #[must_use] + pub fn new( + input_file1: &Path, + input_file2: &Path, + input_file3: &Path, + mk_private_key1: &Path, + mk_private_key2: &Path, + mk_private_key3: &Path, + output_file: &Path, + ) -> Self { + Self { + input_file1: input_file1.to_path_buf(), + mk_private_key1: mk_private_key1.to_path_buf(), + input_file2: input_file2.to_path_buf(), + mk_private_key2: mk_private_key2.to_path_buf(), + input_file3: input_file3.to_path_buf(), + mk_private_key3: mk_private_key3.to_path_buf(), + output_file: output_file.to_path_buf(), + } + } + + /// # Panics + // if input files or private_keys are not correctly formatted + /// # Errors + /// if it cannot open the files + pub async fn decrypt_and_reconstruct(self) -> Result<(), BoxError> { + let Self { + input_file1, + mk_private_key1, + input_file2, + mk_private_key2, + input_file3, + mk_private_key3, + output_file, + } = self; + let key_registry1 = build_hpke_registry(mk_private_key1).await?; + let key_registry2 = build_hpke_registry(mk_private_key2).await?; + let key_registry3 = build_hpke_registry(mk_private_key3).await?; + let decrypted_reports1 = DecryptedReports::new(&input_file1, key_registry1); + let decrypted_reports2 = DecryptedReports::new(&input_file2, key_registry2); + let decrypted_reports3 = DecryptedReports::new(&input_file3, key_registry3); + + let mut writer = Box::new( + OpenOptions::new() + .write(true) + .create_new(true) + .open(output_file)?, + ); + + for (dec_report1, (dec_report2, dec_report3)) in + decrypted_reports1.zip(decrypted_reports2.zip(decrypted_reports3)) + { + let timestamp = [ + dec_report1.timestamp, + dec_report2.timestamp, + dec_report3.timestamp, + ] + .reconstruct() + .as_u128(); + + let match_key = [ + dec_report1.match_key, + dec_report2.match_key, + dec_report3.match_key, + ] + .reconstruct() + .as_u128(); + + // these aren't reconstucted, so we explictly make sure + // they are consistent across all three files, then set + // it to the first one (without loss of generality) + assert_eq!(dec_report1.event_type, dec_report2.event_type); + assert_eq!(dec_report2.event_type, dec_report3.event_type); + let is_trigger_report = dec_report1.event_type == EventType::Trigger; + + let breakdown_key = [ + 
dec_report1.breakdown_key, + dec_report2.breakdown_key, + dec_report3.breakdown_key, + ] + .reconstruct() + .as_u128(); + + let trigger_value = [ + dec_report1.trigger_value, + dec_report2.trigger_value, + dec_report3.trigger_value, + ] + .reconstruct() + .as_u128(); + + writeln!( + writer, + "{},{},{},{},{}", + timestamp, + match_key, + u8::from(is_trigger_report), + breakdown_key, + trigger_value, + )?; + } + + Ok(()) + } +} + +struct DecryptedReports { + reader: BufReader, + key_registry: KeyRegistry, +} + +impl Iterator for DecryptedReports { + type Item = OprfReport; + + fn next(&mut self) -> Option { + let mut line = String::new(); + if self.reader.read_line(&mut line).unwrap() > 0 { + let encrypted_report_bytes = hex::decode(line.trim()).unwrap(); + let enc_report = + EncryptedOprfReport::from_bytes(encrypted_report_bytes.as_slice()).unwrap(); + let dec_report: OprfReport = + enc_report.decrypt(&self.key_registry).unwrap(); + Some(dec_report) + } else { + None + } + } +} + +impl DecryptedReports { + fn new(filename: &PathBuf, key_registry: KeyRegistry) -> Self { + let file = File::open(filename) + .unwrap_or_else(|e| panic!("unable to open file {filename:?}. {e}")); + let reader = BufReader::new(file); + Self { + reader, + key_registry, + } + } +} + +async fn build_hpke_registry( + private_key_file: PathBuf, +) -> Result, BoxError> { + let mk_encryption = Some(HpkeServerConfig::File { private_key_file }); + let key_registry = hpke_registry(mk_encryption.as_ref()).await?; + Ok(key_registry) +} + +#[cfg(test)] +mod tests { + + use tempfile::tempdir; + + use crate::cli::crypto::{decrypt::DecryptArgs, encrypt::EncryptArgs, sample_data}; + + #[tokio::test] + #[should_panic = "No such file or directory (os error 2)"] + async fn decrypt_no_enc_file() { + let input_file = sample_data::write_csv(sample_data::test_ipa_data().take(10)).unwrap(); + + let output_dir = tempdir().unwrap(); + let network_file = sample_data::test_keys().network_config(); + EncryptArgs::new(input_file.path(), output_dir.path(), network_file.path()) + .encrypt() + .unwrap(); + + let decrypt_output = output_dir.path().join("output"); + let enc1 = output_dir.path().join("DOES_NOT_EXIST.enc"); + let enc2 = output_dir.path().join("helper2.enc"); + let enc3 = output_dir.path().join("helper3.enc"); + + let [mk_private_key1, mk_private_key2, mk_private_key3] = + sample_data::test_keys().sk_files(); + + let decrypt_args = DecryptArgs::new( + enc1.as_path(), + enc2.as_path(), + enc3.as_path(), + mk_private_key1.path(), + mk_private_key2.path(), + mk_private_key3.path(), + &decrypt_output, + ); + decrypt_args.decrypt_and_reconstruct().await.unwrap(); + } + + #[tokio::test] + #[should_panic = "called `Result::unwrap()` on an `Err` value: Crypt(Other)"] + async fn decrypt_bad_private_key() { + let input_file = sample_data::write_csv(sample_data::test_ipa_data().take(10)).unwrap(); + + let network_file = sample_data::test_keys().network_config(); + let output_dir = tempdir().unwrap(); + EncryptArgs::new(input_file.path(), output_dir.path(), network_file.path()) + .encrypt() + .unwrap(); + + let decrypt_output = output_dir.path().join("output"); + let enc1 = output_dir.path().join("helper1.enc"); + let enc2 = output_dir.path().join("helper2.enc"); + let enc3 = output_dir.path().join("helper3.enc"); + + // corrupt the secret key for H1 + let mut keys = sample_data::test_keys().clone(); + let mut sk = keys.get_sk(0); + sk[0] = !sk[0]; + keys.set_sk(0, sk); + let [mk_private_key1, mk_private_key2, mk_private_key3] = keys.sk_files(); + 
+ DecryptArgs::new( + enc1.as_path(), + enc2.as_path(), + enc3.as_path(), + mk_private_key1.path(), + mk_private_key2.path(), + mk_private_key3.path(), + &decrypt_output, + ) + .decrypt_and_reconstruct() + .await + .unwrap(); + } +} diff --git a/ipa-core/src/cli/crypto/encrypt.rs b/ipa-core/src/cli/crypto/encrypt.rs new file mode 100644 index 000000000..8c74f169f --- /dev/null +++ b/ipa-core/src/cli/crypto/encrypt.rs @@ -0,0 +1,269 @@ +use std::{ + fs::{read_to_string, OpenOptions}, + io::Write, + iter::zip, + path::{Path, PathBuf}, +}; + +use clap::Parser; +use rand::thread_rng; + +use crate::{ + cli::playbook::{BreakdownKey, InputSource, Timestamp, TriggerValue}, + config::{KeyRegistries, NetworkConfig}, + error::BoxError, + report::{OprfReport, DEFAULT_KEY_ID}, + secret_sharing::IntoShares, + test_fixture::ipa::TestRawDataRecord, +}; + +#[derive(Debug, Parser)] +#[clap(name = "test_encrypt", about = "Test Encrypt")] +#[command(about)] +pub struct EncryptArgs { + /// Path to file to secret share and encrypt + #[arg(long)] + input_file: PathBuf, + /// The destination dir for encrypted output. + /// In that dir, it will create helper1.enc, + /// helper2.enc, and helper3.enc + #[arg(long, value_name = "FILE")] + output_dir: PathBuf, + /// Path to helper network configuration file + #[arg(long)] + network: PathBuf, +} + +impl EncryptArgs { + #[must_use] + pub fn new(input_file: &Path, output_dir: &Path, network: &Path) -> Self { + Self { + input_file: input_file.to_path_buf(), + output_dir: output_dir.to_path_buf(), + network: network.to_path_buf(), + } + } + + /// # Panics + /// if input file or network file are not correctly formatted + /// # Errors + /// if it cannot open the files + pub fn encrypt(&self) -> Result<(), BoxError> { + let input = InputSource::from_file(&self.input_file); + + let mut rng = thread_rng(); + let mut key_registries = KeyRegistries::default(); + + let network = + NetworkConfig::from_toml_str(&read_to_string(&self.network).unwrap_or_else(|e| { + panic!("Failed to open network file: {:?}. {}", &self.network, e) + })) + .unwrap_or_else(|e| { + panic!( + "Failed to parse network file into toml: {:?}. {}", + &self.network, e + ) + }); + let Some(key_registries) = key_registries.init_from(&network) else { + panic!("could not load network file") + }; + + let shares: [Vec>; 3] = + input.iter::().share(); + + for (index, (shares, key_registry)) in zip(shares, key_registries).enumerate() { + let output_filename = format!("helper{}.enc", index + 1); + let mut writer = OpenOptions::new() + .write(true) + .create_new(true) + .open(self.output_dir.join(&output_filename)) + .unwrap_or_else(|e| panic!("unable write to {}. 
{}", &output_filename, e)); + + for share in shares { + let output = share + .encrypt(DEFAULT_KEY_ID, key_registry, &mut rng) + .unwrap(); + let hex_output = hex::encode(&output); + writeln!(writer, "{hex_output}")?; + } + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::{io::Write, sync::Arc}; + + use hpke::Deserializable; + use tempfile::{tempdir, NamedTempFile}; + + use crate::{ + cli::{ + crypto::{encrypt::EncryptArgs, sample_data}, + CsvSerializer, + }, + ff::{boolean_array::BA16, U128Conversions}, + helpers::query::{IpaQueryConfig, QuerySize}, + hpke::{IpaPrivateKey, KeyRegistry, PrivateKeyOnly}, + query::OprfIpaQuery, + report::EncryptedOprfReportStreams, + test_fixture::{ipa::TestRawDataRecord, join3v, Reconstruct, TestWorld}, + }; + + #[tokio::test] + async fn encrypt_and_execute_query() { + const EXPECTED: &[u128] = &[0, 2, 5]; + + let records = vec![ + TestRawDataRecord { + timestamp: 0, + user_id: 12345, + is_trigger_report: false, + breakdown_key: 2, + trigger_value: 0, + }, + TestRawDataRecord { + timestamp: 4, + user_id: 68362, + is_trigger_report: false, + breakdown_key: 1, + trigger_value: 0, + }, + TestRawDataRecord { + timestamp: 10, + user_id: 12345, + is_trigger_report: true, + breakdown_key: 0, + trigger_value: 5, + }, + TestRawDataRecord { + timestamp: 12, + user_id: 68362, + is_trigger_report: true, + breakdown_key: 0, + trigger_value: 2, + }, + ]; + let query_size = QuerySize::try_from(records.len()).unwrap(); + let mut input_file = NamedTempFile::new().unwrap(); + + for event in records { + event.to_csv(input_file.as_file_mut()).unwrap(); + writeln!(input_file.as_file()).unwrap(); + } + input_file.flush().unwrap(); + + let output_dir = tempdir().unwrap(); + let network_file = sample_data::test_keys().network_config(); + + EncryptArgs::new(input_file.path(), output_dir.path(), network_file.path()) + .encrypt() + .unwrap(); + + let files = [ + &output_dir.path().join("helper1.enc"), + &output_dir.path().join("helper2.enc"), + &output_dir.path().join("helper3.enc"), + ]; + + let world = TestWorld::default(); + + let mk_private_keys = [ + "53d58e022981f2edbf55fec1b45dbabd08a3442cb7b7c598839de5d7a5888bff", + "3a0a993a3cfc7e8d381addac586f37de50c2a14b1a6356d71e94ca2afaeb2569", + "1fb5c5274bf85fbe6c7935684ef05499f6cfb89ac21640c28330135cc0e8a0f7", + ]; + + #[allow(clippy::large_futures)] + let results = join3v( + EncryptedOprfReportStreams::from(files) + .streams + .into_iter() + .zip(world.contexts()) + .zip(mk_private_keys.into_iter()) + .map(|((input, ctx), mk_private_key)| { + let mk_private_key = hex::decode(mk_private_key) + .map(|bytes| IpaPrivateKey::from_bytes(&bytes).unwrap()) + .unwrap(); + let query_config = IpaQueryConfig { + max_breakdown_key: 3, + with_dp: 0, + epsilon: 1.0, + ..Default::default() + }; + + OprfIpaQuery::<_, BA16, _>::new( + query_config, + Arc::new(KeyRegistry::from_keys([PrivateKeyOnly(mk_private_key)])), + ) + .execute(ctx, query_size, input) + }), + ) + .await; + + assert_eq!( + results.reconstruct()[0..3] + .iter() + .map(U128Conversions::as_u128) + .collect::>(), + EXPECTED + ); + } + + #[test] + #[should_panic = "Failed to open network file:"] + fn encrypt_no_network_file() { + let input_file = sample_data::write_csv(sample_data::test_ipa_data().take(10)).unwrap(); + + let output_dir = tempdir().unwrap(); + let network_dir = tempdir().unwrap(); + let network_file = network_dir.path().join("does_not_exist"); + EncryptArgs::new(input_file.path(), output_dir.path(), &network_file) + .encrypt() + .unwrap(); + } + + #[test] + 
#[should_panic = "TOML parse error at"] + fn encrypt_bad_network_file() { + let input_file = sample_data::write_csv(sample_data::test_ipa_data().take(10)).unwrap(); + let output_dir = tempdir().unwrap(); + let network_data = r" +this is not toml! +%^& weird characters +(\deadbeef>? +"; + let mut network_file = NamedTempFile::new().unwrap(); + writeln!(network_file.as_file_mut(), "{network_data}").unwrap(); + + EncryptArgs::new(input_file.path(), output_dir.path(), network_file.path()) + .encrypt() + .unwrap(); + } + + #[test] + #[should_panic = "invalid length 2, expected an array of length 3"] + fn encrypt_incomplete_network_file() { + let input_file = sample_data::write_csv(sample_data::test_ipa_data().take(10)).unwrap(); + + let output_dir = tempdir().unwrap(); + let network_data = r#" +[[peers]] +url = "helper1.test" +[peers.hpke] +public_key = "92a6fb666c37c008defd74abf3204ebea685742eab8347b08e2f7c759893947a" +[[peers]] +url = "helper2.test" +[peers.hpke] +public_key = "cfdbaaff16b30aa8a4ab07eaad2cdd80458208a1317aefbb807e46dce596617e" +"#; + let mut network_file = NamedTempFile::new().unwrap(); + writeln!(network_file.as_file_mut(), "{network_data}").unwrap(); + + EncryptArgs::new(input_file.path(), output_dir.path(), network_file.path()) + .encrypt() + .unwrap(); + } +} diff --git a/ipa-core/src/cli/crypto/mod.rs b/ipa-core/src/cli/crypto/mod.rs new file mode 100644 index 000000000..3703038e2 --- /dev/null +++ b/ipa-core/src/cli/crypto/mod.rs @@ -0,0 +1,233 @@ +mod decrypt; +mod encrypt; + +pub use decrypt::DecryptArgs; +pub use encrypt::EncryptArgs; + +#[cfg(test)] +mod sample_data { + use std::{io, io::Write, sync::OnceLock}; + + use hpke::{Deserializable, Serializable}; + use rand::thread_rng; + use tempfile::NamedTempFile; + + use crate::{ + cli::CsvSerializer, + hpke::{IpaPrivateKey, IpaPublicKey}, + test_fixture::{ipa::TestRawDataRecord, EventGenerator, EventGeneratorConfig}, + }; + + /// Keys that are used in crypto tests + #[derive(Clone)] + pub(super) struct TestKeys { + inner: [(IpaPublicKey, IpaPrivateKey); 3], + } + + static TEST_KEYS: OnceLock = OnceLock::new(); + pub fn test_keys() -> &'static TestKeys { + TEST_KEYS.get_or_init(TestKeys::new) + } + + impl TestKeys { + pub fn new() -> Self { + Self { + inner: [ + ( + decode_key::<_, IpaPublicKey>( + "92a6fb666c37c008defd74abf3204ebea685742eab8347b08e2f7c759893947a", + ), + decode_key::<_, IpaPrivateKey>( + "53d58e022981f2edbf55fec1b45dbabd08a3442cb7b7c598839de5d7a5888bff", + ), + ), + ( + decode_key::<_, IpaPublicKey>( + "cfdbaaff16b30aa8a4ab07eaad2cdd80458208a1317aefbb807e46dce596617e", + ), + decode_key::<_, IpaPrivateKey>( + "3a0a993a3cfc7e8d381addac586f37de50c2a14b1a6356d71e94ca2afaeb2569", + ), + ), + ( + decode_key::<_, IpaPublicKey>( + "b900be35da06106a83ed73c33f733e03e4ea5888b7ea4c912ab270b0b0f8381e", + ), + decode_key::<_, IpaPrivateKey>( + "1fb5c5274bf85fbe6c7935684ef05499f6cfb89ac21640c28330135cc0e8a0f7", + ), + ), + ], + } + } + + pub fn network_config(&self) -> NamedTempFile { + let mut file = NamedTempFile::new().unwrap(); + let [pk1, pk2, pk3] = self.inner.each_ref().map(|(pk, _)| pk); + let [pk1, pk2, pk3] = [ + hex::encode(pk1.to_bytes()), + hex::encode(pk2.to_bytes()), + hex::encode(pk3.to_bytes()), + ]; + let network_data = format!( + r#" + [[peers]] + url = "helper1.test" + [peers.hpke] + public_key = "{pk1}" + [[peers]] + url = "helper2.test" + [peers.hpke] + public_key = "{pk2}" + [[peers]] + url = "helper3.test" + [peers.hpke] + public_key = "{pk3}" + "# + ); + 
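+            // Persist the rendered TOML so callers can hand EncryptArgs a
+            // real on-disk path. Note that `Write::write` may perform a short
+            // write; `write_all` would guarantee the whole buffer lands.
+            // (Descriptive comment added for readers; not part of the
+            // original commit.)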
file.write(network_data.as_bytes()).unwrap(); + + file + } + + pub fn set_sk>(&mut self, idx: usize, data: I) { + self.inner[idx].1 = IpaPrivateKey::from_bytes(data.as_ref()).unwrap(); + } + + pub fn get_sk(&self, idx: usize) -> Vec { + self.inner[idx].1.to_bytes().to_vec() + } + + pub fn sk_files(&self) -> [NamedTempFile; 3] { + let files = [ + NamedTempFile::new().unwrap(), + NamedTempFile::new().unwrap(), + NamedTempFile::new().unwrap(), + ]; + + self.inner.each_ref().map(|(_, sk)| sk).map(|sk| { + let mut file = NamedTempFile::new().unwrap(); + file.write(hex::encode(sk.to_bytes()).as_bytes()).unwrap(); + file.flush().unwrap(); + + file + }) + } + } + + fn decode_key, T: Deserializable>(input: I) -> T { + let bytes = hex::decode(input).unwrap(); + T::from_bytes(&bytes).unwrap() + } + + pub fn test_ipa_data() -> impl Iterator { + let rng = thread_rng(); + let event_gen_args = EventGeneratorConfig::new(10, 5, 20, 1, 10, 604_800); + + EventGenerator::with_config(rng, event_gen_args) + } + + pub fn write_csv( + data: impl Iterator, + ) -> Result { + let mut file = NamedTempFile::new()?; + for event in data { + let () = event.to_csv(&mut file)?; + writeln!(file)?; + } + + file.flush()?; + + Ok(file) + } +} + +#[cfg(all(test, unit_test))] +mod tests { + use std::{ + fs::File, + io::{BufRead, BufReader}, + path::Path, + }; + + use clap::Parser; + use tempfile::tempdir; + + use crate::cli::crypto::{decrypt::DecryptArgs, encrypt::EncryptArgs, sample_data}; + + fn are_files_equal(file1: &Path, file2: &Path) { + let file1 = + File::open(file1).unwrap_or_else(|e| panic!("unable to open {}: {e}", file1.display())); + let file2 = + File::open(file2).unwrap_or_else(|e| panic!("unable to open {}: {e}", file2.display())); + let reader1 = BufReader::new(file1).lines(); + let mut reader2 = BufReader::new(file2).lines(); + for line1 in reader1 { + let line2 = reader2.next().expect("Files have different lengths"); + assert_eq!(line1.unwrap(), line2.unwrap()); + } + assert!(reader2.next().is_none(), "Files have different lengths"); + } + + fn build_decrypt_args( + enc1: &Path, + enc2: &Path, + enc3: &Path, + mk_private_key1: &Path, + mk_private_key2: &Path, + mk_private_key3: &Path, + decrypt_output: &Path, + ) -> DecryptArgs { + DecryptArgs::try_parse_from([ + "test_decrypt", + "--input-file1", + enc1.to_str().unwrap(), + "--input-file2", + enc2.to_str().unwrap(), + "--input-file3", + enc3.to_str().unwrap(), + "--mk-private-key1", + mk_private_key1.to_str().unwrap(), + "--mk-private-key2", + mk_private_key2.to_str().unwrap(), + "--mk-private-key3", + mk_private_key3.to_str().unwrap(), + "--output-file", + decrypt_output.to_str().unwrap(), + ]) + .unwrap() + } + + #[tokio::test] + async fn encrypt_and_decrypt() { + let output_dir = tempdir().unwrap(); + let input = sample_data::test_ipa_data().take(10); + let input_file = sample_data::write_csv(input).unwrap(); + let network_file = sample_data::test_keys().network_config(); + EncryptArgs::new(input_file.path(), output_dir.path(), network_file.path()) + .encrypt() + .unwrap(); + + let decrypt_output = output_dir.path().join("output"); + let enc1 = output_dir.path().join("helper1.enc"); + let enc2 = output_dir.path().join("helper2.enc"); + let enc3 = output_dir.path().join("helper3.enc"); + let [mk_private_key1, mk_private_key2, mk_private_key3] = + sample_data::test_keys().sk_files(); + + DecryptArgs::new( + enc1.as_path(), + enc2.as_path(), + enc3.as_path(), + mk_private_key1.path(), + mk_private_key2.path(), + mk_private_key3.path(), + &decrypt_output, + ) 
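+        // Reconstruction combines the three helpers' replicated shares row
+        // by row; the resulting cleartext is then compared line by line
+        // against the original input via `are_files_equal` below.
+        // (Descriptive comment added for readers; not part of the original
+        // commit.)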
+ .decrypt_and_reconstruct() + .await + .unwrap(); + + are_files_equal(input_file.path(), &decrypt_output); + } +} From f329272f02b8baa03428252c876e95727cc2a26c Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Fri, 20 Sep 2024 09:24:49 -0700 Subject: [PATCH 059/191] Clippy and compact-gate fixes --- ipa-core/src/cli/crypto/encrypt.rs | 2 +- ipa-core/src/cli/crypto/mod.rs | 41 +++--------------------------- 2 files changed, 4 insertions(+), 39 deletions(-) diff --git a/ipa-core/src/cli/crypto/encrypt.rs b/ipa-core/src/cli/crypto/encrypt.rs index 8c74f169f..6f174f89a 100644 --- a/ipa-core/src/cli/crypto/encrypt.rs +++ b/ipa-core/src/cli/crypto/encrypt.rs @@ -92,7 +92,7 @@ impl EncryptArgs { } } -#[cfg(test)] +#[cfg(all(test, unit_test))] mod tests { use std::{io::Write, sync::Arc}; diff --git a/ipa-core/src/cli/crypto/mod.rs b/ipa-core/src/cli/crypto/mod.rs index 3703038e2..08322cd85 100644 --- a/ipa-core/src/cli/crypto/mod.rs +++ b/ipa-core/src/cli/crypto/mod.rs @@ -85,7 +85,7 @@ mod sample_data { public_key = "{pk3}" "# ); - file.write(network_data.as_bytes()).unwrap(); + file.write_all(network_data.as_bytes()).unwrap(); file } @@ -99,15 +99,10 @@ mod sample_data { } pub fn sk_files(&self) -> [NamedTempFile; 3] { - let files = [ - NamedTempFile::new().unwrap(), - NamedTempFile::new().unwrap(), - NamedTempFile::new().unwrap(), - ]; - self.inner.each_ref().map(|(_, sk)| sk).map(|sk| { let mut file = NamedTempFile::new().unwrap(); - file.write(hex::encode(sk.to_bytes()).as_bytes()).unwrap(); + file.write_all(hex::encode(sk.to_bytes()).as_bytes()) + .unwrap(); file.flush().unwrap(); file @@ -150,7 +145,6 @@ mod tests { path::Path, }; - use clap::Parser; use tempfile::tempdir; use crate::cli::crypto::{decrypt::DecryptArgs, encrypt::EncryptArgs, sample_data}; @@ -169,35 +163,6 @@ mod tests { assert!(reader2.next().is_none(), "Files have different lengths"); } - fn build_decrypt_args( - enc1: &Path, - enc2: &Path, - enc3: &Path, - mk_private_key1: &Path, - mk_private_key2: &Path, - mk_private_key3: &Path, - decrypt_output: &Path, - ) -> DecryptArgs { - DecryptArgs::try_parse_from([ - "test_decrypt", - "--input-file1", - enc1.to_str().unwrap(), - "--input-file2", - enc2.to_str().unwrap(), - "--input-file3", - enc3.to_str().unwrap(), - "--mk-private-key1", - mk_private_key1.to_str().unwrap(), - "--mk-private-key2", - mk_private_key2.to_str().unwrap(), - "--mk-private-key3", - mk_private_key3.to_str().unwrap(), - "--output-file", - decrypt_output.to_str().unwrap(), - ]) - .unwrap() - } - #[tokio::test] async fn encrypt_and_decrypt() { let output_dir = tempdir().unwrap(); From 4c393e3ec49649a33d0d819aaba55a2840762130 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Fri, 20 Sep 2024 09:47:42 -0700 Subject: [PATCH 060/191] Fix pre-commit script This fix makes pre-commit not to run checks behind environment variable flag if it is not set --- .pre-commit.stash4zjacb | 0 pre-commit | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 .pre-commit.stash4zjacb diff --git a/.pre-commit.stash4zjacb b/.pre-commit.stash4zjacb new file mode 100644 index 000000000..e69de29bb diff --git a/pre-commit b/pre-commit index 3edbe2781..babada0ef 100755 --- a/pre-commit +++ b/pre-commit @@ -88,7 +88,7 @@ check "Clippy checks" \ check "Tests" \ cargo test --features="cli test-fixture" -if [ -z "$EXEC_SLOW_TESTS" ] +if [ -n "$EXEC_SLOW_TESTS" ] then check "Benchmark compilation" \ From db732e96cc2444356ac3d5f9542171e160259643 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Fri, 20 Sep 2024 11:38:10 
-0700 Subject: [PATCH 061/191] Feedback --- ipa-core/src/cli/crypto/mod.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/ipa-core/src/cli/crypto/mod.rs b/ipa-core/src/cli/crypto/mod.rs index 08322cd85..0bcb1f629 100644 --- a/ipa-core/src/cli/crypto/mod.rs +++ b/ipa-core/src/cli/crypto/mod.rs @@ -21,7 +21,7 @@ mod sample_data { /// Keys that are used in crypto tests #[derive(Clone)] pub(super) struct TestKeys { - inner: [(IpaPublicKey, IpaPrivateKey); 3], + key_pairs: [(IpaPublicKey, IpaPrivateKey); 3], } static TEST_KEYS: OnceLock = OnceLock::new(); @@ -32,7 +32,7 @@ mod sample_data { impl TestKeys { pub fn new() -> Self { Self { - inner: [ + key_pairs: [ ( decode_key::<_, IpaPublicKey>( "92a6fb666c37c008defd74abf3204ebea685742eab8347b08e2f7c759893947a", @@ -63,7 +63,7 @@ mod sample_data { pub fn network_config(&self) -> NamedTempFile { let mut file = NamedTempFile::new().unwrap(); - let [pk1, pk2, pk3] = self.inner.each_ref().map(|(pk, _)| pk); + let [pk1, pk2, pk3] = self.key_pairs.each_ref().map(|(pk, _)| pk); let [pk1, pk2, pk3] = [ hex::encode(pk1.to_bytes()), hex::encode(pk2.to_bytes()), @@ -91,15 +91,15 @@ mod sample_data { } pub fn set_sk>(&mut self, idx: usize, data: I) { - self.inner[idx].1 = IpaPrivateKey::from_bytes(data.as_ref()).unwrap(); + self.key_pairs[idx].1 = IpaPrivateKey::from_bytes(data.as_ref()).unwrap(); } pub fn get_sk(&self, idx: usize) -> Vec { - self.inner[idx].1.to_bytes().to_vec() + self.key_pairs[idx].1.to_bytes().to_vec() } pub fn sk_files(&self) -> [NamedTempFile; 3] { - self.inner.each_ref().map(|(_, sk)| sk).map(|sk| { + self.key_pairs.each_ref().map(|(_, sk)| sk).map(|sk| { let mut file = NamedTempFile::new().unwrap(); file.write_all(hex::encode(sk.to_bytes()).as_bytes()) .unwrap(); From 2d7b95dd68566fa1147ac815e10b347f4c4dca8e Mon Sep 17 00:00:00 2001 From: Christian Berkhoff Date: Fri, 20 Sep 2024 13:53:43 -0700 Subject: [PATCH 062/191] Adding missing feature check --- ipa-core/benches/oneshot/ipa.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ipa-core/benches/oneshot/ipa.rs b/ipa-core/benches/oneshot/ipa.rs index 02f6e0304..3d19836ee 100644 --- a/ipa-core/benches/oneshot/ipa.rs +++ b/ipa-core/benches/oneshot/ipa.rs @@ -19,7 +19,11 @@ use ipa_step::StepNarrow; use rand::{random, rngs::StdRng, SeedableRng}; use tokio::runtime::Builder; -#[cfg(all(not(target_env = "msvc"), not(feature = "dhat-heap")))] +#[cfg(all( + not(target_env = "msvc"), + not(feature = "dhat-heap"), + not(target_os = "macos") +))] #[global_allocator] static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; From c03e154ebf56f1e8c3de46f9a649483e47393bca Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Fri, 20 Sep 2024 15:16:30 -0700 Subject: [PATCH 063/191] Hybrid event gen (#1291) * fix compilation bugs * add integration test for hybrid in the clear * add another cfg flag * remove test mod * add test to Cargo.toml * Update ipa-core/Cargo.toml --- ipa-core/Cargo.toml | 12 +++++ ipa-core/src/bin/in_the_clear.rs | 72 ++++++++++++++++++++++++++++++ ipa-core/src/cli/playbook/input.rs | 39 +++++++++++++++- ipa-core/tests/hybrid.rs | 48 ++++++++++++++++++++ 4 files changed, 169 insertions(+), 2 deletions(-) create mode 100644 ipa-core/src/bin/in_the_clear.rs create mode 100644 ipa-core/tests/hybrid.rs diff --git a/ipa-core/Cargo.toml b/ipa-core/Cargo.toml index 3e816b86c..3646acacd 100644 --- a/ipa-core/Cargo.toml +++ b/ipa-core/Cargo.toml @@ -197,6 +197,11 @@ name = "crypto_util" required-features = ["cli", 
"test-fixture", "web-app"] bench = false +[[bin]] +name = "in_the_clear" +required-features = ["cli", "test-fixture", "web-app"] +bench = false + [[bench]] name = "criterion_arithmetic" path = "benches/ct/arithmetic_circuit.rs" @@ -250,3 +255,10 @@ required-features = [ "real-world-infra", "test-fixture", ] + +[[test]] +name = "hybrid" +required-features = [ + "test-fixture", + "cli", +] diff --git a/ipa-core/src/bin/in_the_clear.rs b/ipa-core/src/bin/in_the_clear.rs new file mode 100644 index 000000000..16b2235df --- /dev/null +++ b/ipa-core/src/bin/in_the_clear.rs @@ -0,0 +1,72 @@ +use std::{error::Error, fs::File, io::Write, num::NonZeroU32, path::PathBuf}; + +use clap::Parser; +use ipa_core::{ + cli::{playbook::InputSource, Verbosity}, + test_fixture::hybrid::{hybrid_in_the_clear, TestHybridRecord}, +}; + +#[derive(Debug, Parser)] +pub struct CommandInput { + #[arg( + long, + help = "Read the input from the provided file, instead of standard input" + )] + input_file: Option, +} + +impl From<&CommandInput> for InputSource { + fn from(source: &CommandInput) -> Self { + if let Some(ref file_name) = source.input_file { + InputSource::from_file(file_name) + } else { + InputSource::from_stdin() + } + } +} + +#[derive(Debug, Parser)] +#[clap(name = "in_the_clear", about = "In the Clear CLI")] +#[command(about)] +struct Args { + #[clap(flatten)] + logging: Verbosity, + + #[clap(flatten)] + input: CommandInput, + + /// The destination file for output. + #[arg(long, value_name = "OUTPUT_FILE")] + output_file: PathBuf, + + #[arg(long, default_value = "20")] + max_breakdown_key: NonZeroU32, +} + +fn main() -> Result<(), Box> { + let args = Args::parse(); + let _handle = args.logging.setup_logging(); + + let input = InputSource::from(&args.input); + + let input_rows = input.iter::().collect::>(); + let expected = hybrid_in_the_clear( + &input_rows, + usize::try_from(args.max_breakdown_key.get()).unwrap(), + ); + + let mut file = File::options() + .write(true) + .create_new(true) + .open(&args.output_file) + .map_err(|e| { + format!( + "Failed to create output file {}: {e}", + &args.output_file.display() + ) + })?; + + write!(file, "{}", serde_json::to_string_pretty(&expected)?)?; + + Ok(()) +} diff --git a/ipa-core/src/cli/playbook/input.rs b/ipa-core/src/cli/playbook/input.rs index b7d31484d..5015d5587 100644 --- a/ipa-core/src/cli/playbook/input.rs +++ b/ipa-core/src/cli/playbook/input.rs @@ -7,8 +7,9 @@ use std::{ }; use crate::{ - cli::playbook::generator::U128Generator, ff::U128Conversions, - test_fixture::ipa::TestRawDataRecord, + cli::playbook::generator::U128Generator, + ff::U128Conversions, + test_fixture::{hybrid::TestHybridRecord, ipa::TestRawDataRecord}, }; pub trait InputItem { @@ -56,6 +57,40 @@ impl InputItem for TestRawDataRecord { } } +impl InputItem for TestHybridRecord { + fn from_str(s: &str) -> Self { + if let [event_type, match_key, number] = s.splitn(3, ',').collect::>()[..] { + let match_key: u64 = match_key + .parse() + .unwrap_or_else(|e| panic!("Expected an u64, got {match_key}: {e}")); + + let number: u32 = number + .parse() + .unwrap_or_else(|e| panic!("Expected an u32, got {number}: {e}")); + + match event_type { + "i" => TestHybridRecord::TestImpression { + match_key, + breakdown_key: number, + }, + + "c" => TestHybridRecord::TestConversion { + match_key, + value: number, + }, + _ => panic!( + "{}", + format!( + "Invalid input. Rows should start with 'i' or 'c'. 
Did not expect {event_type}" + ) + ), + } + } else { + panic!("{s} is not a valid {}", type_name::()) + } + } +} + pub struct InputSource { inner: Box, sz: Option, diff --git a/ipa-core/tests/hybrid.rs b/ipa-core/tests/hybrid.rs new file mode 100644 index 000000000..06caabbce --- /dev/null +++ b/ipa-core/tests/hybrid.rs @@ -0,0 +1,48 @@ +// some pub functions in `common` to be compiled, and rust complains about dead code. +#[allow(dead_code)] +mod common; + +use std::process::{Command, Stdio}; + +use common::{tempdir::TempDir, CommandExt, UnwrapStatusExt, TEST_RC_BIN}; +use rand::thread_rng; +use rand_core::RngCore; + +pub const IN_THE_CLEAR_BIN: &str = env!("CARGO_BIN_EXE_in_the_clear"); + +// this currently only generates data and runs in the clear +// eventaully we'll want to add the MPC as well +#[test] +fn test_hybrid() { + const INPUT_SIZE: usize = 100; + const MAX_CONVERSION_VALUE: usize = 5; + const MAX_BREAKDOWN_KEY: usize = 20; + const MAX_CONVS_PER_IMP: usize = 10; + + let dir = TempDir::new_delete_on_drop(); + + // Gen inputs + let input_file = dir.path().join("ipa_inputs.txt"); + let output_file = dir.path().join("ipa_output.json"); + + let mut command = Command::new(TEST_RC_BIN); + command + .args(["--output-file".as_ref(), input_file.as_os_str()]) + .arg("gen-hybrid-inputs") + .args(["--count", &INPUT_SIZE.to_string()]) + .args(["--max-conversion-value", &MAX_CONVERSION_VALUE.to_string()]) + .args(["--max-breakdown-key", &MAX_BREAKDOWN_KEY.to_string()]) + .args(["--max-convs-per-imp", &MAX_CONVS_PER_IMP.to_string()]) + .args(["--seed", &thread_rng().next_u64().to_string()]) + .silent() + .stdin(Stdio::piped()); + command.status().unwrap_status(); + + let mut command = Command::new(IN_THE_CLEAR_BIN); + command + .args(["--input-file".as_ref(), input_file.as_os_str()]) + .args(["--output-file".as_ref(), output_file.as_os_str()]) + .silent() + .stdin(Stdio::piped()); + command.status().unwrap_status(); +} From 9147925d86e0adde24758d2ef9dadf917a85b581 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Fri, 20 Sep 2024 15:29:35 -0700 Subject: [PATCH 064/191] Delete .pre-commit.stash4zjacb --- .pre-commit.stash4zjacb | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 .pre-commit.stash4zjacb diff --git a/.pre-commit.stash4zjacb b/.pre-commit.stash4zjacb deleted file mode 100644 index e69de29bb..000000000 From 6fc683b8fab2b14e5b2bd1b7eef2ab7b7a808f22 Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Fri, 20 Sep 2024 16:43:53 -0700 Subject: [PATCH 065/191] Making this test less flaky, and good error reports --- ipa-core/src/test_fixture/hybrid_event_gen.rs | 112 +++++++----------- 1 file changed, 41 insertions(+), 71 deletions(-) diff --git a/ipa-core/src/test_fixture/hybrid_event_gen.rs b/ipa-core/src/test_fixture/hybrid_event_gen.rs index 8d44d38b3..eb2ee616a 100644 --- a/ipa-core/src/test_fixture/hybrid_event_gen.rs +++ b/ipa-core/src/test_fixture/hybrid_event_gen.rs @@ -1,4 +1,4 @@ -use std::num::NonZeroU32; +use std::{iter::zip, num::NonZeroU32}; use rand::Rng; @@ -175,10 +175,25 @@ mod tests { #[test] fn default_config() { + const EXPECTED_HISTOGRAM_WITH_TOLERANCE: [(i32, f64); 12] = [ + (0, 0.0), + (647634, 0.01), + (137626, 0.01), + (20652, 0.02), + (3085, 0.05), + (463, 0.12), + (70, 0.5), + (10, 1.0), + (2, 1.0), + (0, 1.0), + (0, 1.0), + (0, 1.0), + ]; + const TEST_COUNT: usize = 1_000_000; let gen = EventGenerator::with_default_config(thread_rng()); let max_convs_per_imp = gen.config.max_convs_per_imp.get(); let mut match_key_to_event_count = 
HashMap::new(); - for event in gen.take(10_000) { + for event in gen.take(TEST_COUNT) { match event { TestHybridRecord::TestImpression { match_key, .. } => { match_key_to_event_count @@ -200,44 +215,25 @@ mod tests { histogram[count] += 1; } - assert!( - (6470 - histogram[1]).abs() < 200, - "expected {:?} unmatched events, got {:?}", - 647, - histogram[1] - ); - - assert!( - (1370 - histogram[2]).abs() < 100, - "expected {:?} unmatched events, got {:?}", - 137, - histogram[2] - ); - - assert!( - (200 - histogram[3]).abs() < 50, - "expected {:?} unmatched events, got {:?}", - 20, - histogram[3] - ); - - assert!( - (30 - histogram[4]).abs() < 40, - "expected {:?} unmatched events, got {:?}", - 3, - histogram[4] - ); - - assert!( - (0 - histogram[11]).abs() < 10, - "expected {:?} unmatched events, got {:?}", - 0, - histogram[11] - ); + for (actual, (expected, tolerance)) in + zip(histogram, EXPECTED_HISTOGRAM_WITH_TOLERANCE.iter()) + { + let max_tolerance = (*expected as f64) * tolerance + 10.0; + assert!( + (expected - actual).abs() as f64 <= max_tolerance, + "expected {:?} unmatched events, got {:?}", + expected, + actual, + ); + } } #[test] fn lots_of_repeat_conversions() { + const EXPECTED_HISTOGRAM: [i32; 12] = [ + 0, 299296, 25640, 20542, 16421, 13133, 10503, 8417, 6730, 5391, 4289, 17206, + ]; + const TEST_COUNT: usize = 1_000_000; const MAX_CONVS_PER_IMP: u32 = 10; const MAX_BREAKDOWN_KEY: u32 = 20; const MAX_VALUE: u32 = 3; @@ -252,7 +248,7 @@ mod tests { ); let max_convs_per_imp = gen.config.max_convs_per_imp.get(); let mut match_key_to_event_count = HashMap::new(); - for event in gen.take(100_000) { + for event in gen.take(TEST_COUNT) { match event { TestHybridRecord::TestImpression { match_key, @@ -279,40 +275,14 @@ mod tests { histogram[count] += 1; } - assert!( - (30_032 - histogram[1]).abs() < 800, - "expected {:?} unmatched events, got {:?}", - 30_032, - histogram[1] - ); - - assert!( - (2_572 - histogram[2]).abs() < 300, - "expected {:?} unmatched events, got {:?}", - 2_572, - histogram[2] - ); - - assert!( - (2_048 - histogram[3]).abs() < 200, - "expected {:?} unmatched events, got {:?}", - 2_048, - histogram[3] - ); - - assert!( - (1_650 - histogram[4]).abs() < 100, - "expected {:?} unmatched events, got {:?}", - 1_650, - histogram[4] - ); - - assert!( - (1_718 - histogram[11]).abs() < 100, - "expected {:?} unmatched events, got {:?}", - 1_718, - histogram[11] - ); + for (expected, actual) in zip(EXPECTED_HISTOGRAM.iter(), histogram) { + assert!( + (expected - actual).abs() <= expected / 20 + 10, + "expected {:?} unmatched events, got {:?}", + expected, + actual, + ); + } } #[test] From dc9fad53623488fb9c71769d9f0fd62bffde2b6d Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Fri, 20 Sep 2024 17:03:29 -0700 Subject: [PATCH 066/191] Improving it --- ipa-core/src/test_fixture/hybrid_event_gen.rs | 21 ++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/ipa-core/src/test_fixture/hybrid_event_gen.rs b/ipa-core/src/test_fixture/hybrid_event_gen.rs index eb2ee616a..c332cde2b 100644 --- a/ipa-core/src/test_fixture/hybrid_event_gen.rs +++ b/ipa-core/src/test_fixture/hybrid_event_gen.rs @@ -175,6 +175,11 @@ mod tests { #[test] fn default_config() { + // Since there is randomness, the actual number will be a bit different + // from the expected value. + // The "tolerance" is used to compute the allowable range of values. + // It is multiplied by the expected value. 
So a tolerance of 0.05 means
+        // we will accept a value within 5% of the expected value
         const EXPECTED_HISTOGRAM_WITH_TOLERANCE: [(i32, f64); 12] = [
             (0, 0.0),
             (647634, 0.01),
             (137626, 0.01),
             (20652, 0.02),
             (3085, 0.05),
             (463, 0.12),
             (70, 0.5),
             (10, 1.0),
             (2, 1.0),
             (0, 1.0),
             (0, 1.0),
             (0, 1.0),
         ];
@@ -218,12 +223,16 @@
         for (actual, (expected, tolerance)) in
             zip(histogram, EXPECTED_HISTOGRAM_WITH_TOLERANCE.iter())
         {
+            // Adding a constant value of 10 is a way of dealing with the high
+            // variability of small values, which vary a lot more (as a
+            // percentage), because 10 is an increasingly large percentage of
+            // a smaller and smaller expected value.
             let max_tolerance = (*expected as f64) * tolerance + 10.0;
             assert!(
                 (expected - actual).abs() as f64 <= max_tolerance,
-                "expected {:?} unmatched events, got {:?}",
-                expected,
+                "{:?} is outside of the expected range: ({:?}..{:?})",
                 actual,
+                (*expected as f64) - max_tolerance,
+                (*expected as f64) + max_tolerance,
             );
         }
     }
@@ -276,11 +285,13 @@
         }
 
         for (expected, actual) in zip(EXPECTED_HISTOGRAM.iter(), histogram) {
+            let max_tolerance = (*expected as f64) * 0.05 + 10.0;
             assert!(
-                (expected - actual).abs() <= expected / 20 + 10,
-                "expected {:?} unmatched events, got {:?}",
-                expected,
+                (expected - actual).abs() as f64 <= max_tolerance,
+                "{:?} is outside of the expected range: ({:?}..{:?})",
                 actual,
+                (*expected as f64) - max_tolerance,
+                (*expected as f64) + max_tolerance,
             );
         }
     }

From 797c39f415c08a5ec49cf7b86c0ad9174b1c35b9 Mon Sep 17 00:00:00 2001
From: Ben Savage
Date: Fri, 20 Sep 2024 17:13:03 -0700
Subject: [PATCH 067/191] Fix import

---
 ipa-core/src/bin/in_the_clear.rs              | 65 +++++++++++++++++++
 ipa-core/src/test_fixture/hybrid_event_gen.rs | 13 ++--
 2 files changed, 73 insertions(+), 5 deletions(-)
 create mode 100644 ipa-core/src/bin/in_the_clear.rs

diff --git a/ipa-core/src/bin/in_the_clear.rs b/ipa-core/src/bin/in_the_clear.rs
new file mode 100644
index 000000000..e48fa77a4
--- /dev/null
+++ b/ipa-core/src/bin/in_the_clear.rs
@@ -0,0 +1,65 @@
+use std::{
+    error::Error,
+    path::{Path, PathBuf},
+};
+
+use clap::Parser;
+use ipa_core::{
+    cli::playbook::InputSource,
+ #[arg(long, value_name = "OUTPUT_FILE")] + output_file: PathBuf, +} + +fn main() -> Result<(), Box> { + let args = Args::parse(); + + let input = InputSource::from(&args.input); + + let input_rows = input.iter::().collect::>(); + let expected = hybrid_in_the_clear(&input_rows, 10); + + let mut file = File::options() + .write(true) + .create_new(true) + .open(args.output_file) + .map_err(|e| { + format!( + "Failed to create output file {}: {e}", + args.output_file.display() + ) + })?; + + write!(file, "{}", serde_json::to_string_pretty(&expected)?)?; + + Ok(()) +} diff --git a/ipa-core/src/test_fixture/hybrid_event_gen.rs b/ipa-core/src/test_fixture/hybrid_event_gen.rs index c332cde2b..7eddb0ba4 100644 --- a/ipa-core/src/test_fixture/hybrid_event_gen.rs +++ b/ipa-core/src/test_fixture/hybrid_event_gen.rs @@ -1,4 +1,4 @@ -use std::{iter::zip, num::NonZeroU32}; +use std::num::NonZeroU32; use rand::Rng; @@ -158,7 +158,10 @@ impl Iterator for EventGenerator { #[cfg(all(test, unit_test))] mod tests { - use std::collections::{HashMap, HashSet}; + use std::{ + collections::{HashMap, HashSet}, + iter::zip, + }; use rand::thread_rng; @@ -181,9 +184,9 @@ mod tests { // It is multiplied by the expected value. So a tolerance of 0.05 means // we will accept a value within 5% of the expected value const EXPECTED_HISTOGRAM_WITH_TOLERANCE: [(i32, f64); 12] = [ - (0, 0.0), - (647634, 0.01), - (137626, 0.01), + (0, 0.0), + (647634, 0.01), + (137626, 0.01), (20652, 0.02), (3085, 0.05), (463, 0.12), From 43104a17808ff5a5439144df09e55e9184720d3c Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Fri, 20 Sep 2024 17:15:14 -0700 Subject: [PATCH 068/191] whoops --- ipa-core/src/bin/in_the_clear.rs | 65 -------------------------------- 1 file changed, 65 deletions(-) delete mode 100644 ipa-core/src/bin/in_the_clear.rs diff --git a/ipa-core/src/bin/in_the_clear.rs b/ipa-core/src/bin/in_the_clear.rs deleted file mode 100644 index e48fa77a4..000000000 --- a/ipa-core/src/bin/in_the_clear.rs +++ /dev/null @@ -1,65 +0,0 @@ -use std::{ - error::Error, - path::{Path, PathBuf}, -}; - -use clap::Parser; -use ipa_core::{ - cli::playbook::InputSource, - test_fixture::hybrid::{hybrid_in_the_clear, TestHybridRecord}, -}; - -#[derive(Debug, Parser)] -pub struct CommandInput { - #[arg( - long, - help = "Read the input from the provided file, instead of standard input" - )] - input_file: Option, -} - -impl From<&CommandInput> for InputSource { - fn from(source: &CommandInput) -> Self { - if let Some(ref file_name) = source.input_file { - InputSource::from_file(file_name) - } else { - InputSource::from_stdin() - } - } -} - -#[derive(Debug, Parser)] -#[clap(name = "rc", about = "Report Collector CLI")] -#[command(about)] -struct Args { - #[clap(flatten)] - input: CommandInput, - - /// The destination file for output. 
- #[arg(long, value_name = "OUTPUT_FILE")] - output_file: PathBuf, -} - -fn main() -> Result<(), Box> { - let args = Args::parse(); - - let input = InputSource::from(&args.input); - - let input_rows = input.iter::().collect::>(); - let expected = hybrid_in_the_clear(&input_rows, 10); - - let mut file = File::options() - .write(true) - .create_new(true) - .open(args.output_file) - .map_err(|e| { - format!( - "Failed to create output file {}: {e}", - args.output_file.display() - ) - })?; - - write!(file, "{}", serde_json::to_string_pretty(&expected)?)?; - - Ok(()) -} From 2dfe00c9a37fff0e74c299bbd263e81f82754a40 Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Fri, 20 Sep 2024 18:12:27 -0700 Subject: [PATCH 069/191] add feature flag for relaxed dp. integration tests run with only non-relaxed-dp (#1297) --- .github/workflows/check.yml | 15 +++++++-- ipa-core/Cargo.toml | 13 ++++++++ ipa-core/src/query/runner/oprf_ipa.rs | 4 +-- ipa-core/tests/compact_gate.rs | 10 ++++++ ipa-core/tests/helper_networks.rs | 1 + ipa-core/tests/ipa_with_relaxed_dp.rs | 48 +++++++++++++++++++++++++++ pre-commit | 16 +++++++-- scripts/coverage-ci | 5 ++- 8 files changed, 103 insertions(+), 9 deletions(-) create mode 100644 ipa-core/tests/ipa_with_relaxed_dp.rs diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index bcec6b0f8..4dfdbdf2d 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -68,7 +68,7 @@ jobs: run: cargo build --tests - name: Run tests - run: cargo test --features "cli test-fixture" + run: cargo test --features "cli test-fixture relaxed-dp" - name: Run tests with multithreading feature enabled run: cargo test --features "multi-threading" @@ -172,8 +172,17 @@ jobs: target/ key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.toml') }} - - name: End-to-end tests - run: cargo test --release --test "*" --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate" + - name: Integration Tests - Compact Gate + run: cargo test --release --test "compact_gate" --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate" + + - name: Integration Tests - Helper Networks + run: cargo test --release --test "helper_networks" --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate" + + - name: Integration Tests - Hybrid + run: cargo test --release --test "hybrid" --features "cli test-fixture" + + - name: Integration Tests - IPA with Relaxed DP + run: cargo test --release --test "ipa_with_relaxed_dp" --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate relaxed-dp" # sanitizers currently require nightly https://github.com/rust-lang/rust/issues/39699 sanitize: diff --git a/ipa-core/Cargo.toml b/ipa-core/Cargo.toml index 3646acacd..4a495d50e 100644 --- a/ipa-core/Cargo.toml +++ b/ipa-core/Cargo.toml @@ -74,6 +74,8 @@ reveal-aggregation = [] aggregate-circuit = [] # IPA protocol based on OPRF ipa-prf = [] +# relaxed DP, off by default +relaxed-dp = [] [dependencies] ipa-step = { version = "*", path = "../ipa-step" } @@ -256,6 +258,17 @@ required-features = [ "test-fixture", ] +[[test]] +name = "ipa_with_relaxed_dp" +required-features = [ + "cli", + "compact-gate", + "web-app", + "real-world-infra", + "test-fixture", + "relaxed-dp", +] + [[test]] name = "hybrid" required-features = [ diff --git a/ipa-core/src/query/runner/oprf_ipa.rs b/ipa-core/src/query/runner/oprf_ipa.rs index b2ec0cffc..fa8b787d8 100644 --- a/ipa-core/src/query/runner/oprf_ipa.rs 
+++ b/ipa-core/src/query/runner/oprf_ipa.rs @@ -142,9 +142,9 @@ where }, }; - #[cfg(any(test, feature = "cli", feature = "test-fixture"))] + #[cfg(feature = "relaxed-dp")] let padding_params = PaddingParameters::relaxed(); - #[cfg(not(any(test, feature = "cli", feature = "test-fixture")))] + #[cfg(not(feature = "relaxed-dp"))] let padding_params = PaddingParameters::default(); match config.per_user_credit_cap { 8 => oprf_ipa::<_, BA8, BA3, HV, BA20, 3, 256>(ctx, input, aws, dp_params, padding_params).await, diff --git a/ipa-core/tests/compact_gate.rs b/ipa-core/tests/compact_gate.rs index 05d635f19..5cc83beaf 100644 --- a/ipa-core/tests/compact_gate.rs +++ b/ipa-core/tests/compact_gate.rs @@ -37,11 +37,21 @@ fn compact_gate_cap_8_no_window_semi_honest_plaintext_input() { } #[test] +/// This test is turned off because of [`issue`]. +/// +/// This test will hang without `relaxed-dp` feature turned out until it is fixed +/// [`issue`]: https://github.com/private-attribution/ipa/issues/1298 +#[ignore] fn compact_gate_cap_8_no_window_malicious_encrypted_input() { test_compact_gate(IpaSecurityModel::Malicious, 8, 0, true); } #[test] +/// This test is turned off because of [`issue`]. +/// +/// This test will hang without `relaxed-dp` feature turned out until it is fixed +/// [`issue`]: https://github.com/private-attribution/ipa/issues/1298 +#[ignore] fn compact_gate_cap_8_no_window_malicious_plaintext_input() { test_compact_gate(IpaSecurityModel::Malicious, 8, 0, false); } diff --git a/ipa-core/tests/helper_networks.rs b/ipa-core/tests/helper_networks.rs index ce0b469c9..7775ffba4 100644 --- a/ipa-core/tests/helper_networks.rs +++ b/ipa-core/tests/helper_networks.rs @@ -56,6 +56,7 @@ fn https_semi_honest_ipa() { #[test] #[cfg(all(test, web_test))] +#[ignore] fn https_malicious_ipa() { test_ipa(IpaSecurityModel::Malicious, true, true); } diff --git a/ipa-core/tests/ipa_with_relaxed_dp.rs b/ipa-core/tests/ipa_with_relaxed_dp.rs new file mode 100644 index 000000000..84c4c2a7b --- /dev/null +++ b/ipa-core/tests/ipa_with_relaxed_dp.rs @@ -0,0 +1,48 @@ +#[allow(dead_code)] +mod common; + +use std::num::NonZeroU32; + +use common::{test_ipa, test_ipa_with_config}; +use ipa_core::{helpers::query::IpaQueryConfig, test_fixture::ipa::IpaSecurityModel}; + +fn build_config() -> IpaQueryConfig { + IpaQueryConfig { + per_user_credit_cap: 8, + attribution_window_seconds: NonZeroU32::new(0), + with_dp: 0, + ..Default::default() + } +} + +#[test] +fn relaxed_dp_semi_honest() { + let encrypted_input = false; + let config = build_config(); + + test_ipa_with_config( + IpaSecurityModel::SemiHonest, + encrypted_input, + config, + encrypted_input, + ); +} + +#[test] +fn relaxed_dp_malicious() { + let encrypted_input = false; + let config = build_config(); + + test_ipa_with_config( + IpaSecurityModel::Malicious, + encrypted_input, + config, + encrypted_input, + ); +} + +#[test] +#[cfg(all(test, web_test))] +fn relaxed_dp_https_malicious_ipa() { + test_ipa(IpaSecurityModel::Malicious, true, true); +} diff --git a/pre-commit b/pre-commit index babada0ef..9b1c5cb37 100755 --- a/pre-commit +++ b/pre-commit @@ -86,7 +86,7 @@ check "Clippy checks" \ cargo clippy --features="cli test-fixture" --tests -- -D warnings check "Tests" \ - cargo test --features="cli test-fixture" + cargo test --features="cli test-fixture relaxed-dp" if [ -n "$EXEC_SLOW_TESTS" ] then @@ -117,6 +117,16 @@ then check "Arithmetic circuit benchmark" \ cargo bench --bench oneshot_arithmetic --no-default-features --features "enable-benches compact-gate" - 
check "Slow tests" \ - cargo test --release --test "*" --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate" + check "Slow tests: Compact Gate" \ + cargo test --release --test "compact_gate" --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate" + + check "Slow tests: Helper Networks" \ + cargo test --release --test "helper_networks" --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate" + + check "Slow tests: Hybrid tests" \ + cargo test --release --test "hybrid" --features "cli test-fixture" + + + check "Slow tests: IPA with Relaxed DP" \ + cargo test --release --test "ipa_with_relaxed_dp" --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate relaxed-dp" fi diff --git a/scripts/coverage-ci b/scripts/coverage-ci index 6a5836c55..2d448daa7 100755 --- a/scripts/coverage-ci +++ b/scripts/coverage-ci @@ -9,13 +9,16 @@ cargo llvm-cov clean --workspace cargo build --all-targets # Need to be kept in sync manually with tests we run inside check.yml. -cargo test --features "cli test-fixture" +cargo test --features "cli test-fixture relaxed-dp" # descriptive-gate does not require a feature flag. for gate in "compact-gate" ""; do cargo test --no-default-features --features "cli web-app real-world-infra test-fixture $gate" done +# integration tests run without relaxed dp, except for these +cargo test --release --test "ipa_with_relaxed_dp" --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate relaxed-dp" + cargo test --bench oneshot_ipa --no-default-features --features "enable-benches compact-gate" -- -n 62 -c 16 cargo test --bench criterion_arithmetic --no-default-features --features "enable-benches compact-gate" From bf04fc0b6fb560669adcb10c426bdc72cce52ad1 Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Fri, 20 Sep 2024 18:22:48 -0700 Subject: [PATCH 070/191] Moving to reveal based aggregation --- ipa-core/src/ff/boolean_array.rs | 6 +++- .../src/protocol/ipa_prf/prf_sharding/mod.rs | 34 +++++++------------ ipa-core/src/protocol/ipa_prf/shuffle/mod.rs | 5 +-- 3 files changed, 21 insertions(+), 24 deletions(-) diff --git a/ipa-core/src/ff/boolean_array.rs b/ipa-core/src/ff/boolean_array.rs index 3df41f269..43bfee4a2 100644 --- a/ipa-core/src/ff/boolean_array.rs +++ b/ipa-core/src/ff/boolean_array.rs @@ -5,7 +5,7 @@ use bitvec::{ slice::Iter, }; use generic_array::GenericArray; -use typenum::{U14, U18, U2, U32, U8}; +use typenum::{U12, U14, U18, U2, U32, U8}; use crate::{ error::{Error, LengthError}, @@ -862,6 +862,9 @@ macro_rules! 
boolean_array_impl_large { //impl store for U8 store_impl!(U8, 64); +//impl store for U12 +store_impl!(U12, 96); + //impl store for U14 store_impl!(U14, 112); @@ -890,6 +893,7 @@ boolean_array_impl_small!(boolean_array_16, BA16, 16, infallible); boolean_array_impl_small!(boolean_array_20, BA20, 20, fallible); boolean_array_impl_small!(boolean_array_32, BA32, 32, infallible); boolean_array_impl_small!(boolean_array_64, BA64, 64, infallible); +boolean_array_impl_small!(boolean_array_96, BA96, 96, infallible); boolean_array_impl_small!(boolean_array_112, BA112, 112, infallible); boolean_array_impl_large!(boolean_array_144, BA144, 144, infallible, U18); boolean_array_impl_large!(boolean_array_256, BA256, 256, infallible, U32); diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs index 16e9fea14..111cf7552 100644 --- a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs @@ -537,27 +537,19 @@ where let ctx = sh_ctx.narrow(&Step::Aggregate); - // New aggregation is still experimental, we need proofs that it is private, - // hence it is only enabled behind a feature flag. - if cfg!(feature = "reveal-aggregation") { - // If there was any error in attribution we stop the execution with an error - tracing::warn!("Using the experimental aggregation based on revealing breakdown keys"); - let validator = ctx.dzkp_validator( - MaliciousProtocolSteps { - protocol: &AggregationStep::AggregateChunk(0), - validate: &AggregationStep::AggregateChunkValidate(0), - }, - aggregate_values_proof_chunk(B, usize::try_from(TV::BITS).unwrap()), - ); - let user_contributions = flattened_user_results.try_collect::>().await?; - let result = - breakdown_reveal_aggregation::<_, _, _, HV, B>(validator.context(), user_contributions) - .await; - validator.validate().await?; - result - } else { - aggregate_contributions::<_, _, _, _, HV, B>(ctx, flattened_user_results, num_outputs).await - } + let validator = ctx.dzkp_validator( + MaliciousProtocolSteps { + protocol: &AggregationStep::AggregateChunk(0), + validate: &AggregationStep::AggregateChunkValidate(0), + }, + aggregate_values_proof_chunk(B, usize::try_from(TV::BITS).unwrap()), + ); + let user_contributions = flattened_user_results.try_collect::>().await?; + let result = + breakdown_reveal_aggregation::<_, _, _, HV, B>(validator.context(), user_contributions) + .await; + validator.validate().await?; + result } #[tracing::instrument(name = "attribute_cap", skip_all, fields(unique_match_keys = input.len()))] diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs b/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs index 4119878d4..2154374cf 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs @@ -11,7 +11,7 @@ use crate::{ error::Error, ff::{ boolean::Boolean, - boolean_array::{BooleanArray, BA112, BA144, BA64}, + boolean_array::{BooleanArray, BA112, BA144, BA64, BA96}, ArrayAccess, }, protocol::{ @@ -140,7 +140,8 @@ where .map(|item| attribution_outputs_to_shuffle_input::(&item)) .collect::>(); - let (shuffled, _) = shuffle_protocol(ctx, shuffle_input).await?; + //let (shuffled, _) = shuffle_protocol(ctx, shuffle_input).await?; + let shuffled = malicious_shuffle::<_, R, BA96, _>(ctx, shuffle_input).await?; Ok(shuffled .into_iter() From 89dfa930ea406f0eaff34afb4b5c2837fe94e6ee Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Fri, 20 Sep 2024 20:41:17 -0700 Subject: [PATCH 071/191] fixed compact gate... 
finally --- .../ipa_prf/aggregation/breakdown_reveal.rs | 1 + .../protocol/ipa_prf/aggregation/bucket.rs | 287 ------------------ .../src/protocol/ipa_prf/aggregation/mod.rs | 203 +------------ .../src/protocol/ipa_prf/aggregation/step.rs | 10 +- .../src/protocol/ipa_prf/prf_sharding/mod.rs | 14 +- .../src/protocol/ipa_prf/prf_sharding/step.rs | 2 + 6 files changed, 21 insertions(+), 496 deletions(-) delete mode 100644 ipa-core/src/protocol/ipa_prf/aggregation/bucket.rs diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs b/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs index b08a3ece9..f657deac9 100644 --- a/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs +++ b/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs @@ -75,6 +75,7 @@ where let attributions = shuffle_attributions(&ctx, attributed_values_padded).await?; let grouped_tvs = reveal_breakdowns(&ctx, attributions).await?; let num_rows = grouped_tvs.max_len; + let ctx = ctx.narrow(&AggregationStep::AggregateAggregatePlease); aggregate_values::<_, HV, B>(ctx, grouped_tvs.into_stream(), num_rows).await } diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/bucket.rs b/ipa-core/src/protocol/ipa_prf/aggregation/bucket.rs deleted file mode 100644 index dea77c2f5..000000000 --- a/ipa-core/src/protocol/ipa_prf/aggregation/bucket.rs +++ /dev/null @@ -1,287 +0,0 @@ -use embed_doc_image::embed_doc_image; - -use crate::{ - error::Error, - ff::boolean::Boolean, - helpers::repeat_n, - protocol::{ - basics::SecureMul, boolean::and::bool_and_8_bit, context::Context, - ipa_prf::aggregation::step::BucketStep, RecordId, - }, - secret_sharing::{replicated::semi_honest::AdditiveShare, BitDecomposed, FieldSimd}, -}; - -const MAX_BREAKDOWNS: usize = 512; // constrained by the compact step ability to generate dynamic steps - -#[derive(thiserror::Error, Debug)] -pub enum MoveToBucketError { - #[error("Bad value for the breakdown key: {0}")] - InvalidBreakdownKey(String), -} - -impl From for Error { - fn from(error: MoveToBucketError) -> Self { - match error { - e @ MoveToBucketError::InvalidBreakdownKey(_) => { - Error::InvalidQueryParameter(Box::new(e)) - } - } - } -} - -#[embed_doc_image("tree-aggregation", "images/tree_aggregation.png")] -/// This function moves a single value to a correct bucket using tree aggregation approach -/// -/// Here is how it works -/// The combined value, [`value`] forms the root of a binary tree as follows: -/// ![Tree propagation][tree-aggregation] -/// -/// This value is propagated through the tree, with each subsequent iteration doubling the number of multiplications. -/// In the first round, r=BK-1, multiply the most significant bit ,[`bd_key`]_r by the value to get [`bd_key`]_r.[`value`]. From that, -/// produce [`row_contribution`]_r,0 =[`value`]-[`bd_key`]_r.[`value`] and [`row_contribution`]_r,1=[`bd_key`]_r.[`value`]. -/// This takes the most significant bit of `bd_key` and places value in one of the two child nodes of the binary tree. -/// At each successive round, the next most significant bit is propagated from the leaf nodes of the tree into further leaf nodes: -/// [`row_contribution`]_r+1,q,0 =[`row_contribution`]_r,q - [`bd_key`]_r+1.[`row_contribution`]_r,q and [`row_contribution`]_r+1,q,1 =[`bd_key`]_r+1.[`row_contribution`]_r,q. -/// The work of each iteration therefore doubles relative to the one preceding. -/// -/// In case a malicious entity sends a out of range breakdown key (i.e. 
greater than the max count) to this function, we need to do some -/// extra processing to ensure contribution doesn't end up in a wrong bucket. However, this requires extra multiplications. -/// This would potentially not be needed in IPA (as the breakdown key is provided by the report collector, so a bad value only spoils their own result) but useful for PAM. -/// This can be by passing `robust` as true. -/// -/// ## Errors -/// If `breakdown_count` does not fit into `BK` bits or greater than or equal to $2^9$ -#[allow(dead_code)] -pub async fn move_single_value_to_bucket( - ctx: C, - record_id: RecordId, - bd_key: BitDecomposed>, - value: BitDecomposed>, - breakdown_count: usize, - robust: bool, -) -> Result>>, Error> -where - C: Context, - Boolean: FieldSimd, - AdditiveShare: SecureMul, -{ - let mut step: usize = 1 << bd_key.len(); - - if breakdown_count > step { - Err(MoveToBucketError::InvalidBreakdownKey(format!( - "Asking for more buckets ({breakdown_count}) than bits in the breakdown key ({}) allow", - bd_key.len() - )))?; - } - - if breakdown_count > MAX_BREAKDOWNS { - Err(MoveToBucketError::InvalidBreakdownKey( - "Our step implementation (BucketStep) cannot go past {MAX_BREAKDOWNS} breakdown keys" - .to_string(), - ))?; - } - - let mut row_contribution = vec![value; breakdown_count]; - - // To move a value to one of 2^bd_key_bits buckets requires 2^bd_key_bits - 1 multiplications - // They happen in a tree like fashion: - // 1 multiplication for the first bit - // 2 for the second bit - // 4 for the 3rd bit - // And so on. Simply ordering them sequentially is a functional way - // of enumerating them without creating more step transitions than necessary - let mut multiplication_channel = 0; - - for bit_of_bdkey in bd_key.iter().rev() { - let span = step >> 1; - if !robust && span > breakdown_count { - step = span; - continue; - } - - let contributions = ctx - .parallel_join((0..breakdown_count).step_by(step).enumerate().filter_map( - |(i, tree_index)| { - let bucket_c = ctx.narrow(&BucketStep::from(multiplication_channel + i)); - - let index_contribution = &row_contribution[tree_index]; - - (robust || tree_index + span < breakdown_count).then(|| { - bool_and_8_bit( - bucket_c, - record_id, - index_contribution, - repeat_n(bit_of_bdkey, index_contribution.len()), - ) - }) - }, - )) - .await?; - multiplication_channel += contributions.len(); - - for (index, bdbit_contribution) in contributions.into_iter().enumerate() { - let left_index = index * step; - let right_index = left_index + span; - - // bdbit_contribution is either zero or equal to row_contribution. So it - // is okay to do a carryless "subtraction" here. 
- for (r, b) in row_contribution[left_index] - .iter_mut() - .zip(bdbit_contribution.iter()) - { - *r -= b; - } - if right_index < breakdown_count { - for (r, b) in row_contribution[right_index] - .iter_mut() - .zip(bdbit_contribution) - { - *r = b; - } - } - } - step = span; - } - Ok(row_contribution) -} - -#[cfg(all(test, unit_test))] -pub mod tests { - use rand::thread_rng; - - use super::move_single_value_to_bucket; - use crate::{ - ff::{boolean::Boolean, boolean_array::BA8, Gf8Bit, Gf9Bit, U128Conversions}, - protocol::{context::Context, RecordId}, - rand::Rng, - secret_sharing::{BitDecomposed, SharedValue}, - test_executor::run, - test_fixture::{Reconstruct, Runner, TestWorld}, - }; - - const MAX_BREAKDOWN_COUNT: usize = 256; - const VALUE: u32 = 10; - - async fn move_to_bucket(count: usize, breakdown_key: usize, robust: bool) -> Vec { - let breakdown_key_bits = BitDecomposed::decompose(Gf8Bit::BITS, |i| { - Boolean::from((breakdown_key >> i) & 1 == 1) - }); - let value = - BitDecomposed::decompose(Gf8Bit::BITS, |i| Boolean::from((VALUE >> i) & 1 == 1)); - - TestWorld::default() - .semi_honest( - (breakdown_key_bits, value), - |ctx, (breakdown_key_share, value_share)| async move { - move_single_value_to_bucket::<_, 1>( - ctx.set_total_records(1), - RecordId::from(0), - breakdown_key_share, - value_share, - count, - robust, - ) - .await - .unwrap() - }, - ) - .await - .reconstruct() - .into_iter() - .map(|val| val.into_iter().collect()) - .collect() - } - - #[test] - fn semi_honest_move_in_range() { - run(|| async move { - let mut rng = thread_rng(); - let count = rng.gen_range(1..MAX_BREAKDOWN_COUNT); - let breakdown_key = rng.gen_range(0..count); - let mut expected = vec![BA8::ZERO; count]; - expected[breakdown_key] = BA8::truncate_from(VALUE); - - let result = move_to_bucket(count, breakdown_key, false).await; - assert_eq!(result, expected, "expected value at index {breakdown_key}"); - }); - } - - #[test] - fn semi_honest_move_in_range_robust() { - run(|| async move { - let mut rng = thread_rng(); - let count = rng.gen_range(1..MAX_BREAKDOWN_COUNT); - let breakdown_key = rng.gen_range(0..count); - let mut expected = vec![BA8::ZERO; count]; - expected[breakdown_key] = BA8::truncate_from(VALUE); - - let result = move_to_bucket(count, breakdown_key, true).await; - assert_eq!(result, expected, "expected value at index {breakdown_key}"); - }); - } - - #[test] - fn semi_honest_move_out_of_range() { - run(move || async move { - let mut rng: rand::rngs::ThreadRng = thread_rng(); - let count = rng.gen_range(2..MAX_BREAKDOWN_COUNT - 1); - let breakdown_key = rng.gen_range(count..MAX_BREAKDOWN_COUNT); - - let result = move_to_bucket(count, breakdown_key, false).await; - assert_eq!(result.len(), count); - assert_eq!( - result.into_iter().fold(0, |acc, v| acc + v.as_u128()), - u128::from(VALUE) - ); - }); - } - - #[test] - fn semi_honest_move_out_of_range_robust() { - run(move || async move { - let mut rng: rand::rngs::ThreadRng = thread_rng(); - let count = rng.gen_range(2..MAX_BREAKDOWN_COUNT - 1); - let breakdown_key = rng.gen_range(count..MAX_BREAKDOWN_COUNT); - - let result = move_to_bucket(count, breakdown_key, true).await; - assert_eq!(result.len(), count); - assert!(result.into_iter().all(|x| x == BA8::ZERO)); - }); - } - - #[test] - #[should_panic(expected = "Asking for more buckets")] - fn move_out_of_range_too_many_buckets_type() { - run(move || async move { - _ = move_to_bucket(MAX_BREAKDOWN_COUNT + 1, 0, false).await; - }); - } - - #[test] - #[should_panic(expected = "Asking for 
more buckets")] - fn move_out_of_range_too_many_buckets_steps() { - run(move || async move { - let breakdown_key_bits = BitDecomposed::decompose(Gf9Bit::BITS, |_| Boolean::FALSE); - let value = - BitDecomposed::decompose(Gf8Bit::BITS, |i| Boolean::from((VALUE >> i) & 1 == 1)); - - _ = TestWorld::default() - .semi_honest( - (breakdown_key_bits, value), - |ctx, (breakdown_key_share, value_share)| async move { - move_single_value_to_bucket::<_, 1>( - ctx.set_total_records(1), - RecordId::from(0), - breakdown_key_share, - value_share, - 513, - false, - ) - .await - .unwrap() - }, - ) - .await; - }); - } -} diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs b/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs index f7bc026cc..7c6bab49a 100644 --- a/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs @@ -1,30 +1,23 @@ -use std::{any::type_name, convert::Infallible, iter, pin::Pin}; +use std::{any::type_name, iter, pin::Pin}; -use futures::{stream, FutureExt, Stream, StreamExt, TryStreamExt}; +use futures::{Stream, StreamExt, TryStreamExt}; use tracing::Instrument; -use typenum::Const; use crate::{ - error::{Error, LengthError, UnwrapInfallible}, + error::{Error, LengthError}, ff::{boolean::Boolean, boolean_array::BooleanArray, U128Conversions}, helpers::{ - stream::{ - div_round_up, process_stream_by_chunks, ChunkBuffer, FixedLength, TryFlattenItersExt, - }, + stream::{ChunkBuffer, FixedLength}, TotalRecords, }, protocol::{ - basics::{BooleanArrayMul, BooleanProtocols, SecureMul}, + basics::BooleanProtocols, boolean::{step::ThirtyTwoBitStep, NBitStep}, - context::{ - dzkp_validator::{DZKPValidator, TARGET_PROOF_SIZE}, - Context, DZKPContext, MaliciousProtocolSteps, UpgradableContext, - }, + context::{dzkp_validator::TARGET_PROOF_SIZE, Context}, ipa_prf::{ - aggregation::step::{AggregateChunkStep, AggregateValuesStep, AggregationStep as Step}, + aggregation::step::{AggregateChunkStep, AggregateValuesStep}, boolean_ops::addition_sequential::{integer_add, integer_sat_add}, - prf_sharding::{AttributionOutputs, SecretSharedAttributionOutputs}, - BreakdownKey, AGG_CHUNK, + prf_sharding::AttributionOutputs, }, RecordId, }, @@ -35,7 +28,6 @@ use crate::{ }; pub(crate) mod breakdown_reveal; -mod bucket; pub(crate) mod step; type AttributionOutputsChunk = AttributionOutputs< @@ -91,185 +83,6 @@ where } } -// Aggregation -// -// The input to aggregation is a stream of tuples of (attributed breakdown key, attributed trigger -// value) for each record. -// -// The first stage of aggregation decodes the breakdown key to produce a vector of trigger value -// to be added to each output bucket. At most one element of this vector can be non-zero, -// corresponding to the breakdown key value. This stage is implemented by the -// `move_single_value_to_bucket` function. -// -// The second stage of aggregation sums these vectors across all records, to produce the final -// output histogram. -// -// The first stage of aggregation is vectorized over records, meaning that a chunk of N -// records is collected, and the `move_single_value_to_bucket` function is called to -// decode the breakdown keys for all of those records simultaneously. -// -// The second stage of aggregation is vectorized over histogram buckets, meaning that -// the values in all `B` output buckets are added simultaneously. 
-// -// An intermediate transpose occurs between the two stages of aggregation, to convert from the -// record-vectorized representation to the bucket-vectorized representation. -// -// The input to this transpose is `&[BitDecomposed>]`, indexed -// by buckets, bits of trigger value, and contribution rows. -// -// The output is `&[BitDecomposed>]`, indexed by -// contribution rows, bits of trigger value, and buckets. -#[tracing::instrument(name = "aggregate", skip_all, fields(streams = contributions_stream_len))] -pub async fn aggregate_contributions<'ctx, C, St, BK, TV, HV, const B: usize>( - ctx: C, - contributions_stream: St, - mut contributions_stream_len: usize, -) -> Result>, Error> -where - C: UpgradableContext + 'ctx, - St: Stream, Error>> + Send, - BK: BreakdownKey, - TV: BooleanArray + U128Conversions, - HV: BooleanArray + U128Conversions, - Boolean: FieldSimd, - Replicated: BooleanProtocols<::Context, B>, - Replicated: SecureMul<::Context>, - Replicated: BooleanArrayMul<::Context>, - Replicated: BooleanArrayMul<::Context>, - BitDecomposed>: - for<'a> TransposeFrom<&'a Vec>, Error = LengthError>, - BitDecomposed>: - for<'a> TransposeFrom<&'a Vec>, Error = LengthError>, - Vec>>: for<'a> TransposeFrom< - &'a [BitDecomposed>], - Error = Infallible, - >, - Vec>: - for<'a> TransposeFrom<&'a BitDecomposed>, Error = LengthError>, -{ - assert!(contributions_stream_len != 0); - - let move_to_bucket_chunk_size = TARGET_PROOF_SIZE / B / usize::try_from(TV::BITS).unwrap(); - let move_to_bucket_records = - TotalRecords::specified(div_round_up(contributions_stream_len, Const::))?; - let validator = ctx - .set_total_records(move_to_bucket_records) - .dzkp_validator( - MaliciousProtocolSteps { - protocol: &Step::MoveToBucket, - validate: &Step::MoveToBucketValidate, - }, - move_to_bucket_chunk_size, - ); - let bucket_ctx = validator.context(); - // move each value to the correct bucket - let row_contribution_chunk_stream = process_stream_by_chunks( - contributions_stream, - AttributionOutputs { - attributed_breakdown_key_bits: vec![], - capped_attributed_trigger_value: vec![], - }, - move |idx, chunk: AttributionOutputsChunk| { - let record_id = RecordId::from(idx); - let validate_ctx = bucket_ctx.clone(); - let ctx = bucket_ctx - .clone() - .set_total_records(TotalRecords::Indeterminate); - async move { - let result = bucket::move_single_value_to_bucket::<_, AGG_CHUNK>( - ctx.clone(), - record_id, - chunk.attributed_breakdown_key_bits, - chunk.capped_attributed_trigger_value, - B, - false, - ) - .instrument(tracing::debug_span!("move_to_bucket", chunk = idx)) - .await; - - validate_ctx.validate_record(record_id).await?; - - result - } - }, - ); - - let mut aggregation_input = row_contribution_chunk_stream - // Rather than transpose out of record-vectorized form and then transpose again back - // into bucket-vectorized form, we use a special transpose (the "aggregation - // intermediate transpose") that combines the two steps. - // - // Since the bucket-vectorized representation is separable by records, we do the - // transpose within the `Chunk` wrapper using `Chunk::map`, and then invoke - // `Chunk::into_iter` via `try_flatten_iters` to produce an unchunked stream of - // records, vectorized by buckets. 
- .then(|fut| { - fut.map(|res| { - res.map(|chunk| { - chunk.map(|data| Vec::transposed_from(data.as_slice()).unwrap_infallible()) - }) - }) - }) - .try_flatten_iters() - .boxed(); - - let agg_proof_chunk = aggregate_values_proof_chunk(B, usize::try_from(TV::BITS).unwrap()); - let chunks = iter::from_fn(|| { - if contributions_stream_len >= agg_proof_chunk { - contributions_stream_len -= agg_proof_chunk; - Some(agg_proof_chunk) - } else if contributions_stream_len > 0 { - let chunk = contributions_stream_len; - contributions_stream_len = 0; - Some(chunk) - } else { - None - } - }); - let mut intermediate_results = Vec::new(); - let mut chunk_counter = 0; - - for chunk in chunks { - chunk_counter += 1; - let stream = aggregation_input.by_ref().take(chunk); - let validator = ctx.clone().dzkp_validator( - MaliciousProtocolSteps { - protocol: &Step::aggregate_chunk(chunk_counter), - validate: &Step::aggregate_chunk_validate(chunk_counter), - }, - agg_proof_chunk, - ); - let result = - aggregate_values::<_, HV, B>(validator.context(), stream.boxed(), chunk).await?; - validator.validate().await?; - intermediate_results.push(Ok(result)); - } - - if intermediate_results.len() > 1 { - let stream_len = intermediate_results.len(); - let validator = ctx.dzkp_validator( - MaliciousProtocolSteps { - protocol: &Step::AggregateChunk(chunk_counter), - validate: &Step::AggregateChunkValidate(chunk_counter), - }, - agg_proof_chunk, - ); - let aggregated_result = aggregate_values::<_, HV, B>( - validator.context(), - stream::iter(intermediate_results).boxed(), - stream_len, - ) - .await?; - validator.validate().await?; - Ok(aggregated_result) - } else { - intermediate_results - .into_iter() - .next() - .expect("aggregation input must not be empty") - } -} - /// A vector of histogram contributions for each output bucket. 
/// /// Aggregation is vectorized over histogram buckets, so bit 0 for every histogram bucket is stored diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/step.rs b/ipa-core/src/protocol/ipa_prf/aggregation/step.rs index 65fff707d..8a5da06d8 100644 --- a/ipa-core/src/protocol/ipa_prf/aggregation/step.rs +++ b/ipa-core/src/protocol/ipa_prf/aggregation/step.rs @@ -10,14 +10,12 @@ pub(crate) enum AggregationStep { #[step(child = crate::protocol::ipa_prf::shuffle::step::OPRFShuffleStep)] Shuffle, RevealStep, - #[step(child = BucketStep)] - MoveToBucket, - #[step(child = crate::protocol::context::step::DzkpBatchStep)] - MoveToBucketValidate, - #[step(count = 32, child = AggregateChunkStep)] - AggregateChunk(usize), + #[step(child = AggregateChunkStep)] + AggregateAggregatePlease, #[step(count = 32, child = crate::protocol::context::step::DzkpSingleBatchStep)] AggregateChunkValidate(usize), + #[step(child = crate::protocol::context::step::DzkpBatchStep)] + AggregateValidate, } /// the number of steps must be kept in sync with `MAX_BREAKDOWNS` defined diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs index 111cf7552..00a90a642 100644 --- a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs @@ -12,7 +12,7 @@ use futures::{ FutureExt, Stream, StreamExt, TryStreamExt, }; -use super::aggregation::{aggregate_contributions, breakdown_reveal::breakdown_reveal_aggregation}; +use super::aggregation::breakdown_reveal::breakdown_reveal_aggregation; use crate::{ error::{Error, LengthError}, ff::{ @@ -33,7 +33,7 @@ use crate::{ Context, DZKPContext, DZKPUpgraded, MaliciousProtocolSteps, UpgradableContext, }, ipa_prf::{ - aggregation::{aggregate_values_proof_chunk, step::AggregationStep}, + aggregation::aggregate_values_proof_chunk, boolean_ops::{ addition_sequential::integer_add, comparison_and_subtraction_sequential::{compare_gt, integer_sub}, @@ -505,7 +505,6 @@ where * multiplications_per_record::(attribution_window_seconds)); // Tricky hacks to work around the limitations of our current infrastructure - let num_outputs = input_rows.len() - histogram[0]; let mut dzkp_validator = sh_ctx.clone().dzkp_validator( MaliciousProtocolSteps { protocol: &Step::Attribute, @@ -535,12 +534,11 @@ where attribution_window_seconds, ); - let ctx = sh_ctx.narrow(&Step::Aggregate); - - let validator = ctx.dzkp_validator( + // TODO: move this to the place it's actually used + let validator = sh_ctx.dzkp_validator( MaliciousProtocolSteps { - protocol: &AggregationStep::AggregateChunk(0), - validate: &AggregationStep::AggregateChunkValidate(0), + protocol: &Step::Aggregate, + validate: &Step::AggregateValidate, }, aggregate_values_proof_chunk(B, usize::try_from(TV::BITS).unwrap()), ); diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/step.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/step.rs index 710b0a7e3..03255d342 100644 --- a/ipa-core/src/protocol/ipa_prf/prf_sharding/step.rs +++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/step.rs @@ -12,6 +12,8 @@ pub(crate) enum AttributionStep { AttributeValidate, #[step(child = crate::protocol::ipa_prf::aggregation::step::AggregationStep)] Aggregate, + #[step(child = crate::protocol::context::step::DzkpSingleBatchStep)] + AggregateValidate, } #[derive(CompactStep)] From 78e980c923064a396beb11ae0661f10953e377b5 Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Fri, 20 Sep 2024 21:09:56 -0700 Subject: [PATCH 072/191] adding tracing --- 
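As a quick aside (illustrative, not part of the change itself): `#[tracing::instrument]`
wraps the annotated function in a tracing span. A minimal sketch of the pattern this
patch applies, using a hypothetical stand-in function; the attribute arguments mirror
the real ones added below:

    use tracing::instrument;

    // `skip_all` keeps the (potentially large) arguments out of the span;
    // `fields(...)` records only the values worth seeing in traces.
    #[instrument(name = "aggregate_values", skip_all, fields(num_rows = num_rows))]
    async fn aggregate_values(num_rows: usize) {
        // All work awaited in here is attributed to the span above.
    }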
ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs | 3 ++- ipa-core/src/protocol/ipa_prf/aggregation/mod.rs | 1 + ipa-core/src/protocol/ipa_prf/aggregation/step.rs | 4 +--- ipa-core/src/protocol/ipa_prf/oprf_padding/mod.rs | 1 + ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs | 1 - ipa-core/src/protocol/ipa_prf/shuffle/mod.rs | 1 - 6 files changed, 5 insertions(+), 6 deletions(-) diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs b/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs index f657deac9..cd9ea3eb0 100644 --- a/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs +++ b/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs @@ -48,6 +48,7 @@ use crate::{ /// 2. Reveal breakdown keys. This is the key difference to the previous /// aggregation (see [`reveal_breakdowns`]). /// 3. Add all values for each breakdown. +#[tracing::instrument(name = "breakdown_reveal_aggregation", skip_all, fields(total = attributed_values.len()))] pub async fn breakdown_reveal_aggregation( ctx: C, attributed_values: Vec>, @@ -75,7 +76,7 @@ where let attributions = shuffle_attributions(&ctx, attributed_values_padded).await?; let grouped_tvs = reveal_breakdowns(&ctx, attributions).await?; let num_rows = grouped_tvs.max_len; - let ctx = ctx.narrow(&AggregationStep::AggregateAggregatePlease); + let ctx = ctx.narrow(&AggregationStep::SumContributions); aggregate_values::<_, HV, B>(ctx, grouped_tvs.into_stream(), num_rows).await } diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs b/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs index 7c6bab49a..41f91e243 100644 --- a/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs @@ -116,6 +116,7 @@ pub fn aggregate_values_proof_chunk(input_width: usize, input_item_bits: usize) /// /// It might be possible to save some cost by using naive wrapping arithmetic. Another /// possibility would be to combine all carries into a single "overflow detected" bit. 
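 ///
 /// As a clear-value sketch of the arithmetic described above (illustrative only,
 /// not part of this change): partial sums are widened while below the output
 /// width, and saturate once they reach it.
 ///
 /// ```ignore
 /// // `a`, `b`: u8 partial sums below the output width; `c`, `d`: u16 at it.
 /// let widened = u16::from(a) + u16::from(b); // widening add cannot overflow
 /// let capped = c.saturating_add(d);          // saturating add at full width
 /// ```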
+#[tracing::instrument(name = "aggregate_values", skip_all, fields(num_rows = num_rows))] pub async fn aggregate_values<'ctx, 'fut, C, OV, const B: usize>( ctx: C, mut aggregated_stream: Pin> + Send + 'fut>>, diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/step.rs b/ipa-core/src/protocol/ipa_prf/aggregation/step.rs index 8a5da06d8..fcb619a45 100644 --- a/ipa-core/src/protocol/ipa_prf/aggregation/step.rs +++ b/ipa-core/src/protocol/ipa_prf/aggregation/step.rs @@ -11,9 +11,7 @@ pub(crate) enum AggregationStep { Shuffle, RevealStep, #[step(child = AggregateChunkStep)] - AggregateAggregatePlease, - #[step(count = 32, child = crate::protocol::context::step::DzkpSingleBatchStep)] - AggregateChunkValidate(usize), + SumContributions, #[step(child = crate::protocol::context::step::DzkpBatchStep)] AggregateValidate, } diff --git a/ipa-core/src/protocol/ipa_prf/oprf_padding/mod.rs b/ipa-core/src/protocol/ipa_prf/oprf_padding/mod.rs index ce2d1ceda..d4680f0ac 100644 --- a/ipa-core/src/protocol/ipa_prf/oprf_padding/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/oprf_padding/mod.rs @@ -274,6 +274,7 @@ where /// # Errors /// Will propagate errors from `apply_dp_padding_pass` +#[tracing::instrument(name = "apply_dp_padding", skip_all)] pub async fn apply_dp_padding( ctx: C, mut input: Vec, diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs index 00a90a642..a745d5ea7 100644 --- a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs @@ -534,7 +534,6 @@ where attribution_window_seconds, ); - // TODO: move this to the place it's actually used let validator = sh_ctx.dzkp_validator( MaliciousProtocolSteps { protocol: &Step::Aggregate, diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs b/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs index 2154374cf..582445190 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs @@ -140,7 +140,6 @@ where .map(|item| attribution_outputs_to_shuffle_input::(&item)) .collect::>(); - //let (shuffled, _) = shuffle_protocol(ctx, shuffle_input).await?; let shuffled = malicious_shuffle::<_, R, BA96, _>(ctx, shuffle_input).await?; Ok(shuffled From d08f94740ac9ec7de2d461addba2caeacf5a5d78 Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Fri, 20 Sep 2024 21:17:50 -0700 Subject: [PATCH 073/191] Reducing step count limit --- ipa-core/src/protocol/ipa_prf/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ipa-core/src/protocol/ipa_prf/mod.rs b/ipa-core/src/protocol/ipa_prf/mod.rs index 90e00cb64..0462e31d6 100644 --- a/ipa-core/src/protocol/ipa_prf/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/mod.rs @@ -744,7 +744,7 @@ mod compact_gate_tests { fn step_count_limit() { // This is an arbitrary limit intended to catch changes that unintentionally // blow up the step count. It can be increased, within reason. 
-        const STEP_COUNT_LIMIT: u32 = 200_000;
+        const STEP_COUNT_LIMIT: u32 = 20_000;
         assert!(
             ProtocolStep::STEP_COUNT < STEP_COUNT_LIMIT,
             "Step count of {actual} exceeds limit of {STEP_COUNT_LIMIT}.",

From 7b9f9da6835488a845c15ed6a2ae7997e564bd81 Mon Sep 17 00:00:00 2001
From: Ben Savage
Date: Sat, 21 Sep 2024 04:20:12 -0700
Subject: [PATCH 074/191] Clippy

---
 ipa-core/src/test_fixture/hybrid_event_gen.rs | 26 +++++++++----------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/ipa-core/src/test_fixture/hybrid_event_gen.rs b/ipa-core/src/test_fixture/hybrid_event_gen.rs
index 7eddb0ba4..bc1a463dd 100644
--- a/ipa-core/src/test_fixture/hybrid_event_gen.rs
+++ b/ipa-core/src/test_fixture/hybrid_event_gen.rs
@@ -185,10 +185,10 @@ mod tests {
         // we will accept a value within 5% of the expected value
         const EXPECTED_HISTOGRAM_WITH_TOLERANCE: [(i32, f64); 12] = [
             (0, 0.0),
-            (647634, 0.01),
-            (137626, 0.01),
-            (20652, 0.02),
-            (3085, 0.05),
+            (647_634, 0.01),
+            (137_626, 0.01),
+            (20_652, 0.02),
+            (3_085, 0.05),
             (463, 0.12),
             (70, 0.5),
             (10, 1.0),
@@ -229,13 +229,13 @@
             // Adding a constant value of 10 is a way of dealing with the high variability small values
             // which will vary a lot more (as a percent). Because 10 is an increasingly large percentage of
             // A smaller and smaller expected value
-            let max_tolerance = (*expected as f64) * tolerance + 10.0;
+            let max_tolerance = f64::from(*expected) * tolerance + 10.0;
             assert!(
-                (expected - actual).abs() as f64 <= max_tolerance,
+                f64::from((expected - actual).abs()) <= max_tolerance,
                 "{:?} is outside of the expected range: ({:?}..{:?})",
                 actual,
-                (*expected as f64) - max_tolerance,
-                (*expected as f64) + max_tolerance,
+                f64::from(*expected) - max_tolerance,
+                f64::from(*expected) + max_tolerance,
             );
         }
     }

     #[test]
     fn lots_of_repeat_conversions() {
         const EXPECTED_HISTOGRAM: [i32; 12] = [
-            0, 299296, 25640, 20542, 16421, 13133, 10503, 8417, 6730, 5391, 4289, 17206,
+            0, 299_296, 25_640, 20_542, 16_421, 13_133, 10_503, 8_417, 6_730, 5_391, 4_289, 17_206,
         ];
         const TEST_COUNT: usize = 1_000_000;
         const MAX_CONVS_PER_IMP: u32 = 10;
@@ -288,13 +288,13 @@
         }

         for (expected, actual) in zip(EXPECTED_HISTOGRAM.iter(), histogram) {
-            let max_tolerance = (*expected as f64) * 0.05 + 10.0;
+            let max_tolerance = f64::from(*expected) * 0.05 + 10.0;
             assert!(
-                (expected - actual).abs() as f64 <= max_tolerance,
+                f64::from((expected - actual).abs()) <= max_tolerance,
                 "{:?} is outside of the expected range: ({:?}..{:?})",
                 actual,
-                (*expected as f64) - max_tolerance,
-                (*expected as f64) + max_tolerance,
+                f64::from(*expected) - max_tolerance,
+                f64::from(*expected) + max_tolerance,
             );
         }
     }

From da97c99a8290bd7a8e60ea1d046190b67ca39642 Mon Sep 17 00:00:00 2001
From: Alex Koshelev
Date: Sat, 21 Sep 2024 22:56:17 -0700
Subject: [PATCH 075/191] Add a handler to kill a query

In our testing with external parties, it became quite evident that
bouncing helpers whenever a query fails makes for a poor experience for
the people running the test. We need the ability to kill queries that
failed on one or more helpers and to restart the runs.

This change does exactly that: it adds a `/query_id/kill` handler that
immediately aborts the task running an IPA query.
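For illustration, a minimal sketch of driving the new API from Rust
(paths and setup here are illustrative; `kill`, `QueryKilled`, and
`QueryKillStatus` are as introduced by this change):

    use ipa_core::{protocol::QueryId, query::{QueryKillStatus, QueryProcessor}};

    fn kill_query(qp: &QueryProcessor) {
        // `kill` unregisters the query and aborts its protocol task, if any.
        match qp.kill(QueryId) {
            Ok(killed) => println!("killed query {}", killed.0),
            Err(QueryKillStatus::NoSuchQuery(id)) => eprintln!("no such query: {id}"),
        }
    }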
--- ipa-core/src/app.rs | 4 + ipa-core/src/helpers/transport/handler.rs | 11 +- .../helpers/transport/in_memory/transport.rs | 3 +- ipa-core/src/helpers/transport/routing.rs | 1 + .../helpers/transport/stream/collection.rs | 20 +++ ipa-core/src/net/http_serde.rs | 74 ++++++++++ .../src/net/server/handlers/query/kill.rs | 104 ++++++++++++++ ipa-core/src/net/server/handlers/query/mod.rs | 2 + ipa-core/src/net/transport.rs | 43 +++++- ipa-core/src/query/mod.rs | 2 +- ipa-core/src/query/processor.rs | 127 ++++++++++++++++++ 11 files changed, 384 insertions(+), 7 deletions(-) create mode 100644 ipa-core/src/net/server/handlers/query/kill.rs diff --git a/ipa-core/src/app.rs b/ipa-core/src/app.rs index db603501c..da56e67e3 100644 --- a/ipa-core/src/app.rs +++ b/ipa-core/src/app.rs @@ -203,6 +203,10 @@ impl RequestHandler for Inner { let query_id = ext_query_id(&req)?; HelperResponse::from(qp.complete(query_id).await?) } + RouteId::KillQuery => { + let query_id = ext_query_id(&req)?; + HelperResponse::from(qp.kill(query_id)?) + } }) } } diff --git a/ipa-core/src/helpers/transport/handler.rs b/ipa-core/src/helpers/transport/handler.rs index 42981d097..525edb67e 100644 --- a/ipa-core/src/helpers/transport/handler.rs +++ b/ipa-core/src/helpers/transport/handler.rs @@ -12,7 +12,7 @@ use crate::{ }, query::{ NewQueryError, PrepareQueryError, ProtocolResult, QueryCompletionError, QueryInputError, - QueryStatus, QueryStatusError, + QueryKillStatus, QueryKilled, QueryStatus, QueryStatusError, }, sync::{Arc, Mutex, Weak}, }; @@ -135,6 +135,13 @@ impl From for HelperResponse { } } +impl From for HelperResponse { + fn from(value: QueryKilled) -> Self { + let v = serde_json::to_vec(&json!({"query_id": value.0, "status": "killed"})).unwrap(); + Self { body: v } + } +} + impl> From for HelperResponse { fn from(value: R) -> Self { let v = value.as_ref().to_bytes(); @@ -156,6 +163,8 @@ pub enum Error { #[error(transparent)] QueryStatus(#[from] QueryStatusError), #[error(transparent)] + QueryKill(#[from] QueryKillStatus), + #[error(transparent)] DeserializationFailure(#[from] serde_json::Error), #[error("MalformedRequest: {0}")] BadRequest(BoxError), diff --git a/ipa-core/src/helpers/transport/in_memory/transport.rs b/ipa-core/src/helpers/transport/in_memory/transport.rs index 3c1a9e926..cd7324e89 100644 --- a/ipa-core/src/helpers/transport/in_memory/transport.rs +++ b/ipa-core/src/helpers/transport/in_memory/transport.rs @@ -119,7 +119,8 @@ impl InMemoryTransport { | RouteId::PrepareQuery | RouteId::QueryInput | RouteId::QueryStatus - | RouteId::CompleteQuery => { + | RouteId::CompleteQuery + | RouteId::KillQuery => { handler .as_ref() .expect("Handler is set") diff --git a/ipa-core/src/helpers/transport/routing.rs b/ipa-core/src/helpers/transport/routing.rs index 4d8f44796..3d9c2bb5f 100644 --- a/ipa-core/src/helpers/transport/routing.rs +++ b/ipa-core/src/helpers/transport/routing.rs @@ -16,6 +16,7 @@ pub enum RouteId { QueryInput, QueryStatus, CompleteQuery, + KillQuery, } /// The header/metadata of the incoming request. diff --git a/ipa-core/src/helpers/transport/stream/collection.rs b/ipa-core/src/helpers/transport/stream/collection.rs index 09e4f5e63..f19fd7ce5 100644 --- a/ipa-core/src/helpers/transport/stream/collection.rs +++ b/ipa-core/src/helpers/transport/stream/collection.rs @@ -114,6 +114,26 @@ impl StreamCollection { let mut streams = self.inner.lock().unwrap(); streams.clear(); } + + /// Returns the number of streams inside this collection. + /// + /// ## Panics + /// if mutex is poisoned. 
+ #[cfg(test)] + #[must_use] + pub fn len(&self) -> usize { + self.inner.lock().unwrap().len() + } + + /// Returns `true` if this collection is empty. + /// + /// ## Panics + /// if mutex is poisoned. + #[must_use] + #[cfg(test)] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } } /// Describes the lifecycle of records stream inside [`StreamCollection`] diff --git a/ipa-core/src/net/http_serde.rs b/ipa-core/src/net/http_serde.rs index 927ae4a4d..6dabaf359 100644 --- a/ipa-core/src/net/http_serde.rs +++ b/ipa-core/src/net/http_serde.rs @@ -533,4 +533,78 @@ pub mod query { pub const AXUM_PATH: &str = "/:query_id/complete"; } + + pub mod kill { + use serde::{Deserialize, Serialize}; + + use crate::{ + helpers::{routing::RouteId, HelperResponse, NoStep, RouteParams}, + protocol::QueryId, + }; + + #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] + pub struct Request { + pub query_id: QueryId, + } + + impl RouteParams for Request { + type Params = String; + + fn resource_identifier(&self) -> RouteId { + RouteId::KillQuery + } + + fn query_id(&self) -> QueryId { + self.query_id + } + + fn gate(&self) -> NoStep { + NoStep + } + + fn extra(&self) -> Self::Params { + String::new() + } + } + + impl Request { + /// see above for the reason why this needs to be behind a feature flag + #[cfg(all(test, not(feature = "shuttle")))] + pub fn new(query_id: QueryId) -> Self { + Self { query_id } + } + + #[cfg(all(test, not(feature = "shuttle")))] + pub fn try_into_http_request( + self, + scheme: axum::http::uri::Scheme, + authority: axum::http::uri::Authority, + ) -> crate::net::http_serde::OutgoingRequest { + let uri = axum::http::uri::Uri::builder() + .scheme(scheme) + .authority(authority) + .path_and_query(format!( + "{}/{}/kill", + crate::net::http_serde::query::BASE_AXUM_PATH, + self.query_id.as_ref() + )) + .build()?; + Ok(hyper::Request::get(uri).body(axum::body::Body::empty())?) 
+ } + } + + #[derive(Clone, Debug, Serialize, Deserialize)] + pub struct ResponseBody { + pub query_id: QueryId, + pub status: String, + } + + impl From for ResponseBody { + fn from(value: HelperResponse) -> Self { + serde_json::from_slice(value.into_body().as_slice()).unwrap() + } + } + + pub const AXUM_PATH: &str = "/:query_id/kill"; + } } diff --git a/ipa-core/src/net/server/handlers/query/kill.rs b/ipa-core/src/net/server/handlers/query/kill.rs new file mode 100644 index 000000000..d6845a95d --- /dev/null +++ b/ipa-core/src/net/server/handlers/query/kill.rs @@ -0,0 +1,104 @@ +use axum::{extract::Path, routing::get, Extension, Json, Router}; +use hyper::StatusCode; + +use crate::{ + helpers::{ApiError, BodyStream, Transport}, + net::{ + http_serde::query::{kill, kill::Request}, + server::Error, + Error::QueryIdNotFound, + HttpTransport, + }, + protocol::QueryId, + query::QueryKillStatus, + sync::Arc, +}; + +async fn handler( + transport: Extension>, + Path(query_id): Path, +) -> Result, Error> { + let req = Request { query_id }; + let transport = Transport::clone_ref(&*transport); + match transport.dispatch(req, BodyStream::empty()).await { + Ok(state) => Ok(Json(kill::ResponseBody::from(state))), + Err(ApiError::QueryKill(QueryKillStatus::NoSuchQuery(query_id))) => Err( + Error::application(StatusCode::NOT_FOUND, QueryIdNotFound(query_id)), + ), + Err(e) => Err(Error::application(StatusCode::INTERNAL_SERVER_ERROR, e)), + } +} + +pub fn router(transport: Arc) -> Router { + Router::new() + .route(kill::AXUM_PATH, get(handler)) + .layer(Extension(transport)) +} + +#[cfg(all(test, unit_test))] +mod tests { + use axum::{ + body::Body, + http::uri::{Authority, Scheme}, + }; + use hyper::StatusCode; + + use crate::{ + helpers::{ + make_owned_handler, + routing::{Addr, RouteId}, + BodyStream, HelperIdentity, HelperResponse, + }, + net::{ + http_serde, + server::handlers::query::test_helpers::{assert_fails_with, assert_success_with}, + }, + protocol::QueryId, + query::QueryKilled, + }; + + #[tokio::test] + async fn calls_kill() { + let expected_query_id = QueryId; + + let handler = make_owned_handler( + move |addr: Addr, _data: BodyStream| async move { + let RouteId::KillQuery = addr.route else { + panic!("unexpected call: {addr:?}"); + }; + assert_eq!(addr.query_id, Some(expected_query_id)); + Ok(HelperResponse::from(QueryKilled(expected_query_id))) + }, + ); + + let req = http_serde::query::kill::Request::new(QueryId); + let req = req + .try_into_http_request(Scheme::HTTP, Authority::from_static("localhost")) + .unwrap(); + assert_success_with(req, handler).await; + } + + struct OverrideReq { + query_id: String, + } + + impl From for hyper::Request { + fn from(val: OverrideReq) -> Self { + let uri = format!( + "http://localhost{}/{}/kill", + http_serde::query::BASE_AXUM_PATH, + val.query_id + ); + hyper::Request::get(uri).body(Body::empty()).unwrap() + } + } + + #[tokio::test] + async fn malformed_query_id() { + let req = OverrideReq { + query_id: "not-a-query-id".into(), + }; + + assert_fails_with(req.into(), StatusCode::BAD_REQUEST).await; + } +} diff --git a/ipa-core/src/net/server/handlers/query/mod.rs b/ipa-core/src/net/server/handlers/query/mod.rs index 49f18e0a8..92890bd57 100644 --- a/ipa-core/src/net/server/handlers/query/mod.rs +++ b/ipa-core/src/net/server/handlers/query/mod.rs @@ -1,5 +1,6 @@ mod create; mod input; +mod kill; mod prepare; mod results; mod status; @@ -31,6 +32,7 @@ pub fn query_router(transport: Arc) -> Router { .merge(create::router(Arc::clone(&transport))) 
.merge(input::router(Arc::clone(&transport))) .merge(status::router(Arc::clone(&transport))) + .merge(kill::router(Arc::clone(&transport))) .merge(results::router(transport)) } diff --git a/ipa-core/src/net/transport.rs b/ipa-core/src/net/transport.rs index 81d4bdcce..1d657d9cd 100644 --- a/ipa-core/src/net/transport.rs +++ b/ipa-core/src/net/transport.rs @@ -135,7 +135,7 @@ impl HttpTransport { .expect("A Handler should be set by now") .handle(Addr::from_route(None, req), body); - if let RouteId::CompleteQuery = route_id { + if let RouteId::CompleteQuery | RouteId::KillQuery = route_id { ClearOnDrop { transport: Arc::clone(&self), inner: r, @@ -210,7 +210,8 @@ impl Transport for Arc { evt @ (RouteId::QueryInput | RouteId::ReceiveQuery | RouteId::QueryStatus - | RouteId::CompleteQuery) => { + | RouteId::CompleteQuery + | RouteId::KillQuery) => { unimplemented!( "attempting to send client-specific request {evt:?} to another helper" ) @@ -272,7 +273,10 @@ mod tests { use bytes::Bytes; use futures::stream::{poll_immediate, StreamExt}; - use futures_util::future::{join_all, try_join_all}; + use futures_util::{ + future::{join_all, try_join_all}, + stream, + }; use generic_array::GenericArray; use once_cell::sync::Lazy; use tokio::sync::mpsc::channel; @@ -283,18 +287,49 @@ mod tests { use crate::{ config::{NetworkConfig, ServerConfig}, ff::{FieldType, Fp31, Serializable}, - helpers::query::{QueryInput, QueryType::TestMultiply}, + helpers::{ + make_owned_handler, + query::{QueryInput, QueryType::TestMultiply}, + HandlerBox, + }, net::{ client::ClientIdentity, test::{get_test_identity, TestConfig, TestConfigBuilder, TestServer}, }, secret_sharing::{replicated::semi_honest::AdditiveShare, IntoShares}, + test_executor::run, test_fixture::Reconstruct, AppConfig, AppSetup, HelperApp, }; static STEP: Lazy = Lazy::new(|| Gate::from("http-transport")); + #[tokio::test] + async fn clean_on_kill() { + let noop_handler = make_owned_handler(|addr, stream| async move { + { + Ok(HelperResponse::ok()) + } + }); + let TestServer { mut transport, .. } = TestServer::builder() + .with_request_handler(Arc::clone(&noop_handler)) + .build() + .await; + + transport.record_streams.add_stream( + (QueryId, HelperIdentity::ONE, Gate::default()), + BodyStream::empty(), + ); + assert_eq!(1, transport.record_streams.len()); + + Transport::clone_ref(&transport) + .dispatch((RouteId::KillQuery, QueryId), BodyStream::empty()) + .await + .unwrap(); + + assert!(transport.record_streams.is_empty()); + } + #[tokio::test] async fn receive_stream() { let (tx, rx) = channel::>>(1); diff --git a/ipa-core/src/query/mod.rs b/ipa-core/src/query/mod.rs index aaa437b7a..6e6650862 100644 --- a/ipa-core/src/query/mod.rs +++ b/ipa-core/src/query/mod.rs @@ -8,7 +8,7 @@ use completion::Handle as CompletionHandle; pub use executor::Result as ProtocolResult; pub use processor::{ NewQueryError, PrepareQueryError, Processor as QueryProcessor, QueryCompletionError, - QueryInputError, QueryStatusError, + QueryInputError, QueryKillStatus, QueryKilled, QueryStatusError, }; pub use runner::OprfIpaQuery; pub use state::QueryStatus; diff --git a/ipa-core/src/query/processor.rs b/ipa-core/src/query/processor.rs index c399bb019..9734eab2d 100644 --- a/ipa-core/src/query/processor.rs +++ b/ipa-core/src/query/processor.rs @@ -5,6 +5,7 @@ use std::{ }; use futures::{future::try_join, stream}; +use serde::{Deserialize, Serialize}; use crate::{ error::Error as ProtocolError, @@ -328,6 +329,36 @@ impl Processor { Ok(handle.await?) 
} + + /// Terminates a query with the given id. If query is running, then it + /// is unregistered and its task is terminated. + /// + /// ## Errors + /// if query is not registered on this helper. + /// + /// ## Panics + /// If failed to obtain exclusive access to the query collection. + pub fn kill(&self, query_id: QueryId) -> Result { + let mut queries = self.queries.inner.lock().unwrap(); + let Some(state) = queries.remove(&query_id) else { + return Err(QueryKillStatus::NoSuchQuery(query_id)); + }; + + if let QueryState::Running(handle) = state { + handle.join_handle.abort(); + } + + Ok(QueryKilled(query_id)) + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct QueryKilled(pub QueryId); + +#[derive(thiserror::Error, Debug)] +pub enum QueryKillStatus { + #[error("failed to kill a query: {0} does not exist.")] + NoSuchQuery(QueryId), } #[cfg(all(test, unit_test))] @@ -549,6 +580,102 @@ mod tests { } } + mod kill { + use std::sync::Arc; + + use crate::{ + ff::FieldType, + helpers::{ + query::{ + QueryConfig, + QueryType::{TestAddInPrimeField, TestMultiply}, + }, + HandlerBox, HelperIdentity, InMemoryMpcNetwork, Transport, + }, + protocol::QueryId, + query::{ + processor::{tests::respond_ok, Processor}, + state::{QueryState, RunningQuery}, + QueryKillStatus, + }, + test_executor::run, + }; + + #[test] + fn non_existent_query() { + let processor = Processor::default(); + assert!(matches!( + processor.kill(QueryId), + Err(QueryKillStatus::NoSuchQuery(QueryId)) + )); + } + + #[test] + fn existing_query() { + run(|| async move { + let h2 = respond_ok(); + let h3 = respond_ok(); + let network = InMemoryMpcNetwork::new([ + None, + Some(HandlerBox::owning_ref(&h2)), + Some(HandlerBox::owning_ref(&h3)), + ]); + let identities = HelperIdentity::make_three(); + let processor = Processor::default(); + let transport = network.transport(identities[0]); + processor + .new_query( + Transport::clone_ref(&transport), + QueryConfig::new(TestMultiply, FieldType::Fp31, 1).unwrap(), + ) + .await + .unwrap(); + + processor.kill(QueryId).unwrap(); + + // start query again - it should work because the query was killed + processor + .new_query( + transport, + QueryConfig::new(TestAddInPrimeField, FieldType::Fp32BitPrime, 1).unwrap(), + ) + .await + .unwrap(); + }); + } + + #[test] + fn aborts_protocol_task() { + run(|| async move { + let processor = Processor::default(); + let (_tx, rx) = tokio::sync::oneshot::channel(); + let counter = Arc::new(1); + let task = tokio::spawn({ + let counter = Arc::clone(&counter); + async move { + loop { + tokio::task::yield_now().await; + let _ = *counter.as_ref(); + } + } + }); + processor.queries.inner.lock().unwrap().insert( + QueryId, + QueryState::Running(RunningQuery { + result: rx, + join_handle: task, + }), + ); + + assert_eq!(2, Arc::strong_count(&counter)); + processor.kill(QueryId).unwrap(); + while Arc::strong_count(&counter) > 1 { + tokio::task::yield_now().await; + } + }); + } + } + mod e2e { use std::time::Duration; From 7f2a82ea56290885a3a1bab44dfd6b1027b142b3 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Sat, 21 Sep 2024 23:32:43 -0700 Subject: [PATCH 076/191] Fix unused warnings --- ipa-core/src/net/http_serde.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ipa-core/src/net/http_serde.rs b/ipa-core/src/net/http_serde.rs index 6dabaf359..250dc6792 100644 --- a/ipa-core/src/net/http_serde.rs +++ b/ipa-core/src/net/http_serde.rs @@ -569,12 +569,12 @@ pub mod query { impl Request { /// see above for the reason why this 
needs to be behind a feature flag - #[cfg(all(test, not(feature = "shuttle")))] + #[cfg(all(test, unit_test))] pub fn new(query_id: QueryId) -> Self { Self { query_id } } - #[cfg(all(test, not(feature = "shuttle")))] + #[cfg(all(test, unit_test))] pub fn try_into_http_request( self, scheme: axum::http::uri::Scheme, From 30edb326f5ad2c47635623aff4b5bc17ce0cb453 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Sat, 21 Sep 2024 23:42:05 -0700 Subject: [PATCH 077/191] change GET to POST --- ipa-core/src/net/http_serde.rs | 2 +- ipa-core/src/net/server/handlers/query/kill.rs | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ipa-core/src/net/http_serde.rs b/ipa-core/src/net/http_serde.rs index 250dc6792..f6902250f 100644 --- a/ipa-core/src/net/http_serde.rs +++ b/ipa-core/src/net/http_serde.rs @@ -589,7 +589,7 @@ pub mod query { self.query_id.as_ref() )) .build()?; - Ok(hyper::Request::get(uri).body(axum::body::Body::empty())?) + Ok(hyper::Request::post(uri).body(axum::body::Body::empty())?) } } diff --git a/ipa-core/src/net/server/handlers/query/kill.rs b/ipa-core/src/net/server/handlers/query/kill.rs index d6845a95d..8a2188df7 100644 --- a/ipa-core/src/net/server/handlers/query/kill.rs +++ b/ipa-core/src/net/server/handlers/query/kill.rs @@ -1,4 +1,4 @@ -use axum::{extract::Path, routing::get, Extension, Json, Router}; +use axum::{extract::Path, routing::post, Extension, Json, Router}; use hyper::StatusCode; use crate::{ @@ -31,7 +31,7 @@ async fn handler( pub fn router(transport: Arc) -> Router { Router::new() - .route(kill::AXUM_PATH, get(handler)) + .route(kill::AXUM_PATH, post(handler)) .layer(Extension(transport)) } @@ -89,7 +89,7 @@ mod tests { http_serde::query::BASE_AXUM_PATH, val.query_id ); - hyper::Request::get(uri).body(Body::empty()).unwrap() + hyper::Request::post(uri).body(Body::empty()).unwrap() } } From 5ebf875a956f4cd893ecf6c130927d34c23f4b7f Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Sun, 22 Sep 2024 00:08:12 -0700 Subject: [PATCH 078/191] Fix clippy issues --- ipa-core/src/net/transport.rs | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/ipa-core/src/net/transport.rs b/ipa-core/src/net/transport.rs index 1d657d9cd..508bfc8d5 100644 --- a/ipa-core/src/net/transport.rs +++ b/ipa-core/src/net/transport.rs @@ -273,10 +273,7 @@ mod tests { use bytes::Bytes; use futures::stream::{poll_immediate, StreamExt}; - use futures_util::{ - future::{join_all, try_join_all}, - stream, - }; + use futures_util::future::{join_all, try_join_all}; use generic_array::GenericArray; use once_cell::sync::Lazy; use tokio::sync::mpsc::channel; @@ -290,14 +287,12 @@ mod tests { helpers::{ make_owned_handler, query::{QueryInput, QueryType::TestMultiply}, - HandlerBox, }, net::{ client::ClientIdentity, test::{get_test_identity, TestConfig, TestConfigBuilder, TestServer}, }, secret_sharing::{replicated::semi_honest::AdditiveShare, IntoShares}, - test_executor::run, test_fixture::Reconstruct, AppConfig, AppSetup, HelperApp, }; @@ -306,12 +301,12 @@ mod tests { #[tokio::test] async fn clean_on_kill() { - let noop_handler = make_owned_handler(|addr, stream| async move { + let noop_handler = make_owned_handler(|_, _| async move { { Ok(HelperResponse::ok()) } }); - let TestServer { mut transport, .. } = TestServer::builder() + let TestServer { transport, .. 
} = TestServer::builder()
            .with_request_handler(Arc::clone(&noop_handler))
            .build()
            .await;

From 30be8404901d10ac0a5768de94cff82d28f335fe Mon Sep 17 00:00:00 2001
From: Alex Koshelev
Date: Mon, 23 Sep 2024 11:54:15 -0700
Subject: [PATCH 079/191] Fix the stall inside attribution

---
 ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs
index a745d5ea7..267631da0 100644
--- a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs
+++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs
@@ -510,7 +510,9 @@ where
             protocol: &Step::Attribute,
             validate: &Step::AttributeValidate,
         },
-        chunk_size,
+        // The size of a single batch should not exceed the active work limit,
+        // otherwise it will stall.
+        std::cmp::min(sh_ctx.active_work().get(), chunk_size),
     );
     dzkp_validator.set_total_records(TotalRecords::specified(histogram[1]).unwrap());
     let ctx_for_row_number = set_up_contexts(&dzkp_validator.context(), histogram)?;

From b51e62e66d760a5144fa4265bd39f9e48508c396 Mon Sep 17 00:00:00 2001
From: Alex Koshelev
Date: Tue, 24 Sep 2024 18:18:45 -0700
Subject: [PATCH 080/191] Reduce number of iterations for breakdown reveal test

We have been seeing some flakiness with this test:
https://github.com/private-attribution/ipa/actions/runs/11018449181/job/30598750956?pr=1307
I attribute it to having too many iterations (by default it is set to 32).
Even running it locally takes a long time, and reliable detection is not
possible for routines this large.

---
 ipa-core/src/helpers/transport/stream/mod.rs           |  2 +-
 ipa-core/src/lib.rs                                    | 12 ++++++------
 .../protocol/ipa_prf/aggregation/breakdown_reveal.rs   |  8 ++++++--
 .../ipa_prf/boolean_ops/share_conversion_aby.rs        |  2 +-
 4 files changed, 14 insertions(+), 10 deletions(-)

diff --git a/ipa-core/src/helpers/transport/stream/mod.rs b/ipa-core/src/helpers/transport/stream/mod.rs
index 2f2d6ccc6..59c76cdf4 100644
--- a/ipa-core/src/helpers/transport/stream/mod.rs
+++ b/ipa-core/src/helpers/transport/stream/mod.rs
@@ -187,7 +187,7 @@ mod tests {
             let stream =
                 BodyStream::from_bytes_stream(stream::once(future::ready(Ok(Bytes::from(data)))));

-            stream.try_collect::<Vec<_>>().await.unwrap()
+            stream.try_collect::<Vec<_>>().await.unwrap();
         });
     }
 }
diff --git a/ipa-core/src/lib.rs b/ipa-core/src/lib.rs
index a4f625a4c..59cae0106 100644
--- a/ipa-core/src/lib.rs
+++ b/ipa-core/src/lib.rs
@@ -122,10 +122,10 @@ pub(crate) mod test_executor {
 pub(crate) mod test_executor {
     use std::future::Future;

-    pub fn run_with<F, Fut, T, const ITER: usize>(f: F) -> T
+    pub fn run_with<F, Fut, const ITER: usize>(f: F)
     where
         F: Fn() -> Fut + Send + Sync + 'static,
-        Fut: Future<Output = T>,
+        Fut: Future<Output = ()>,
     {
         tokio::runtime::Builder::new_multi_thread()
             // enable_all() is commonly used to build a Tokio runtime, but it enables both IO and time drivers.
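             // A minimal sketch of the resulting builder call (illustrative; it
             // mirrors the code in this function):
             //
             //     tokio::runtime::Builder::new_multi_thread()
             //         .enable_time() // time driver only, no IO driver
             //         .build()
             //         .unwrap()
             //         .block_on(f());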
@@ -134,16 +134,16 @@ pub(crate) mod test_executor { .enable_time() .build() .unwrap() - .block_on(f()) + .block_on(f()); } #[allow(dead_code)] - pub fn run(f: F) -> T + pub fn run(f: F) where F: Fn() -> Fut + Send + Sync + 'static, - Fut: Future, + Fut: Future, { - run_with::<_, _, _, 1>(f) + run_with::<_, _, 1>(f); } } diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs b/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs index cd9ea3eb0..0f5269bbf 100644 --- a/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs +++ b/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs @@ -210,7 +210,7 @@ pub mod tests { secret_sharing::{ replicated::semi_honest::AdditiveShare as Replicated, BitDecomposed, TransposeFrom, }, - test_executor::run, + test_executor::run_with, test_fixture::{Reconstruct, Runner, TestWorld}, }; @@ -224,7 +224,11 @@ pub mod tests { #[test] fn semi_honest_happy_path() { - run(|| async { + // if shuttle executor is enabled, run this test only once. + // it is a very expensive test to explore all possible states, + // sometimes github bails after 40 minutes of running it + // (workers there are really slow). + run_with::<_, _, 3>(|| async { let world = TestWorld::default(); let mut rng = rand::thread_rng(); let mut expectation = Vec::new(); diff --git a/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs b/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs index 05e3d3b62..a42bdbbbb 100644 --- a/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs +++ b/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs @@ -532,7 +532,7 @@ mod tests { .await .unwrap() }) - .await + .await; }); } From b628bf22c732874bb915eb936c0be3c7a2f58874 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Wed, 25 Sep 2024 20:41:44 -0700 Subject: [PATCH 081/191] Move some tests --- .../src/protocol/context/dzkp_validator.rs | 141 +++++++++++++++--- ipa-core/src/secret_sharing/vector/impls.rs | 101 ------------- 2 files changed, 123 insertions(+), 119 deletions(-) diff --git a/ipa-core/src/protocol/context/dzkp_validator.rs b/ipa-core/src/protocol/context/dzkp_validator.rs index 517d4db46..31821c592 100644 --- a/ipa-core/src/protocol/context/dzkp_validator.rs +++ b/ipa-core/src/protocol/context/dzkp_validator.rs @@ -827,31 +827,136 @@ mod tests { use bitvec::{order::Lsb0, prelude::BitArray, vec::BitVec}; use futures::{StreamExt, TryStreamExt}; use futures_util::stream::iter; - use proptest::{prop_compose, proptest, sample::select}; - use rand::{thread_rng, Rng}; + use proptest::{prelude::prop, prop_compose, proptest}; + use rand::{distributions::Standard, prelude::Distribution}; use crate::{ - error::Error, - ff::{boolean::Boolean, Fp61BitPrime}, - protocol::{ - basics::SecureMul, + error::Error, ff::{ + boolean::Boolean, boolean_array::{BooleanArray, BA16, BA20, BA256, BA32, BA64, BA8}, Fp61BitPrime + }, protocol::{ + basics::{select, BooleanArrayMul, SecureMul}, context::{ - dzkp_field::{DZKPCompatibleField, BLOCK_SIZE}, - dzkp_validator::{ + dzkp_field::{DZKPCompatibleField, BLOCK_SIZE}, dzkp_validator::{ Batch, DZKPValidator, Segment, SegmentEntry, BIT_ARRAY_LEN, TARGET_PROOF_SIZE, - }, - Context, UpgradableContext, TEST_DZKP_STEPS, + }, Context, DZKPUpgradedMaliciousContext, DZKPUpgradedSemiHonestContext, UpgradableContext, TEST_DZKP_STEPS }, Gate, RecordId, - }, - secret_sharing::{ - replicated::semi_honest::AdditiveShare as Replicated, IntoShares, SharedValue, - Vectorizable, - }, - 
seq_join::{seq_join, SeqJoin}, - test_fixture::{join3v, Reconstruct, Runner, TestWorld}, + }, rand::{thread_rng, Rng}, secret_sharing::{ + replicated::semi_honest::AdditiveShare as Replicated, + IntoShares, SharedValue, Vectorizable, + }, seq_join::{seq_join, SeqJoin}, sharding::NotSharded, test_fixture::{join3v, Reconstruct, Runner, TestWorld} }; + async fn test_simplest_circuit_semi_honest() + where + V: BooleanArray, + for<'a> Replicated: BooleanArrayMul>, + Standard: Distribution, + { + let world = TestWorld::default(); + let context = world.contexts(); + let mut rng = thread_rng(); + + let bit = rng.gen::(); + let a = rng.gen::(); + let b = rng.gen::(); + + let bit_shares = bit.share_with(&mut rng); + let a_shares = a.share_with(&mut rng); + let b_shares = b.share_with(&mut rng); + + let futures = zip(context.iter(), zip(bit_shares, zip(a_shares, b_shares))).map( + |(ctx, (bit_share, (a_share, b_share)))| async move { + let v = ctx.clone().dzkp_validator(TEST_DZKP_STEPS, 1); + let sh_ctx = v.context(); + + let result = select( + sh_ctx.set_total_records(1), + RecordId::from(0), + &bit_share, + &a_share, + &b_share, + ) + .await?; + + v.validate().await?; + + Ok::<_, Error>(result) + }, + ); + + let [ab0, ab1, ab2] = join3v(futures).await; + + let ab = [ab0, ab1, ab2].reconstruct(); + + assert_eq!(ab, if bit.into() { a } else { b }); + } + + #[tokio::test] + async fn simplest_circuit_semi_honest() { + test_simplest_circuit_semi_honest::().await; + test_simplest_circuit_semi_honest::().await; + test_simplest_circuit_semi_honest::().await; + test_simplest_circuit_semi_honest::().await; + test_simplest_circuit_semi_honest::().await; + test_simplest_circuit_semi_honest::().await; + } + + async fn test_simplest_circuit_malicious() + where + V: BooleanArray, + for<'a> Replicated: BooleanArrayMul>, + Standard: Distribution, + { + let world = TestWorld::default(); + let context = world.malicious_contexts(); + let mut rng = thread_rng(); + + let bit = rng.gen::(); + let a = rng.gen::(); + let b = rng.gen::(); + + let bit_shares = bit.share_with(&mut rng); + let a_shares = a.share_with(&mut rng); + let b_shares = b.share_with(&mut rng); + + let futures = zip(context.iter(), zip(bit_shares, zip(a_shares, b_shares))).map( + |(ctx, (bit_share, (a_share, b_share)))| async move { + let v = ctx.clone().dzkp_validator(TEST_DZKP_STEPS, 1); + let m_ctx = v.context(); + + let result = select( + m_ctx.set_total_records(1), + RecordId::from(0), + &bit_share, + &a_share, + &b_share, + ) + .await?; + + v.validate().await?; + + Ok::<_, Error>(result) + }, + ); + + let [ab0, ab1, ab2] = join3v(futures).await; + + let ab = [ab0, ab1, ab2].reconstruct(); + + assert_eq!(ab, if bit.into() { a } else { b }); + } + + #[tokio::test] + async fn simplest_circuit_malicious() { + test_simplest_circuit_malicious::().await; + test_simplest_circuit_malicious::().await; + test_simplest_circuit_malicious::().await; + test_simplest_circuit_malicious::().await; + test_simplest_circuit_malicious::().await; + test_simplest_circuit_malicious::().await; + } + #[tokio::test] async fn dzkp_malicious() { const COUNT: usize = 32; @@ -1022,7 +1127,7 @@ mod tests { } prop_compose! 
{ - fn arb_count_and_chunk()((log_count, log_multiplication_amount) in select(&[(5,5),(7,5),(5,8)])) -> (usize, usize) { + fn arb_count_and_chunk()((log_count, log_multiplication_amount) in prop::sample::select(&[(5,5),(7,5),(5,8)])) -> (usize, usize) { (1usize<(); - let a = rng.gen::<$vec>(); - let b = rng.gen::<$vec>(); - - let bit_shares = bit.share_with(&mut rng); - let a_shares = a.share_with(&mut rng); - let b_shares = b.share_with(&mut rng); - - let futures = zip(context.iter(), zip(bit_shares, zip(a_shares, b_shares))) - .map(|(ctx, (bit_share, (a_share, b_share)))| async move { - let v = ctx.clone().dzkp_validator(TEST_DZKP_STEPS, 1); - let m_ctx = v.context(); - - let result = select( - m_ctx.set_total_records(1), - RecordId::from(0), - &bit_share, - &a_share, - &b_share, - ) - .await?; - - v.validate().await?; - - Ok::<_, Error>(result) - }); - - let [ab0, ab1, ab2] = join3v(futures).await; - - let ab = [ab0, ab1, ab2].reconstruct(); - - assert_eq!(ab, if bit.into() { a } else { b }); - } - - #[tokio::test] - async fn simplest_circuit_semi_honest() { - let world = TestWorld::default(); - let context = world.contexts(); - let mut rng = thread_rng(); - - let bit = rng.gen::(); - let a = rng.gen::<$vec>(); - let b = rng.gen::<$vec>(); - - let bit_shares = bit.share_with(&mut rng); - let a_shares = a.share_with(&mut rng); - let b_shares = b.share_with(&mut rng); - - let futures = zip(context.iter(), zip(bit_shares, zip(a_shares, b_shares))) - .map(|(ctx, (bit_share, (a_share, b_share)))| async move { - let v = ctx.clone().dzkp_validator(TEST_DZKP_STEPS, 1); - let sh_ctx = v.context(); - - let result = select( - sh_ctx.set_total_records(1), - RecordId::from(0), - &bit_share, - &a_share, - &b_share, - ) - .await?; - - v.validate().await?; - - Ok::<_, Error>(result) - }); - - let [ab0, ab1, ab2] = join3v(futures).await; - - let ab = [ab0, ab1, ab2].reconstruct(); - - assert_eq!(ab, if bit.into() { a } else { b }); - } - } } }; } From 011f03a57276be331a86e4fa2b98fb70f8cb550e Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Wed, 25 Sep 2024 21:02:32 -0700 Subject: [PATCH 082/191] Parametrize MAC-malicious contexts with ShardBinding parameter This is just a first step to enable sharded circuits to be written with malicious security. It is far from being complete, but it gradually moves us towards that goal. This change just enables malicious contexts to support sharding. 
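In sketch form (both lines below appear verbatim in the diff), the context
becomes generic over its shard binding, and a defaulted type alias keeps
existing call sites compiling unchanged:

    pub struct Context<'a, B: ShardBinding> {
        inner: Base<'a, B>,
    }

    pub type MaliciousContext<'a, B = NotSharded> = malicious::Context<'a, B>;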
Because of type aliases, everything else from the outside of this module
still uses the non-sharded versions.
---
 ipa-core/src/protocol/context/malicious.rs | 79 +++++++++++-----------
 ipa-core/src/protocol/context/mod.rs       |  7 +-
 ipa-core/src/sharding.rs                   |  2 +-
 ipa-core/src/test_fixture/world.rs         |  2 +-
 4 files changed, 47 insertions(+), 43 deletions(-)

diff --git a/ipa-core/src/protocol/context/malicious.rs b/ipa-core/src/protocol/context/malicious.rs
index a93a8edfb..36ffb48bb 100644
--- a/ipa-core/src/protocol/context/malicious.rs
+++ b/ipa-core/src/protocol/context/malicious.rs
@@ -21,6 +21,7 @@ use crate::{
         validator::{self, BatchValidator},
         Base, Context as ContextTrait, InstrumentedSequentialSharedRandomness,
         SpecialAccessToUpgradedContext, UpgradableContext, UpgradedContext,
+        UpgradedMaliciousContext,
     },
     prss::{Endpoint as PrssEndpoint, FromPrss},
     Gate, RecordId,
@@ -30,7 +31,7 @@ use crate::{
         semi_honest::AdditiveShare as Replicated,
     },
     seq_join::SeqJoin,
-    sharding::NotSharded,
+    sharding::{NotSharded, ShardBinding},
     sync::Arc,
 };

@@ -49,28 +50,29 @@ pub(crate) const TEST_DZKP_STEPS: MaliciousProtocolSteps<
 };

 #[derive(Clone)]
-pub struct Context<'a> {
-    inner: Base<'a>,
+pub struct Context<'a, B: ShardBinding> {
+    inner: Base<'a, B>,
 }

-impl<'a> Context<'a> {
+impl<'a> Context<'a, NotSharded> {
     pub fn new(participant: &'a PrssEndpoint, gateway: &'a Gateway) -> Self {
-        Self::new_with_gate(participant, gateway, Gate::default())
+        Self::new_with_gate(participant, gateway, Gate::default(), NotSharded)
     }
+}

-    pub fn new_with_gate(participant: &'a PrssEndpoint, gateway: &'a Gateway, gate: Gate) -> Self {
+impl<'a, B: ShardBinding> Context<'a, B> {
+    pub fn new_with_gate(
+        participant: &'a PrssEndpoint,
+        gateway: &'a Gateway,
+        gate: Gate,
+        shard: B,
+    ) -> Self {
         Self {
-            inner: Base::new_complete(
-                participant,
-                gateway,
-                gate,
-                TotalRecords::Unspecified,
-                NotSharded,
-            ),
+            inner: Base::new_complete(participant, gateway, gate, TotalRecords::Unspecified, shard),
         }
     }

-    pub(crate) fn validator_context(self) -> Base<'a> {
+    pub(crate) fn validator_context(self) -> Base<'a, B> {
         // The DZKP validator uses communication channels internally. We don't want any TotalRecords
         // set by the protocol to apply to those channels.
Base { @@ -80,7 +82,7 @@ impl<'a> Context<'a> { } } -impl<'a> super::Context for Context<'a> { +impl<'a, B: ShardBinding> super::Context for Context<'a, B> { fn role(&self) -> Role { self.inner.role() } @@ -130,7 +132,7 @@ impl<'a> super::Context for Context<'a> { } } -impl<'a> UpgradableContext for Context<'a> { +impl<'a> UpgradableContext for Context<'a, NotSharded> { type Validator = BatchValidator<'a, F>; fn validator(self) -> Self::Validator { @@ -152,13 +154,13 @@ impl<'a> UpgradableContext for Context<'a> { } } -impl<'a> SeqJoin for Context<'a> { +impl<'a, B: ShardBinding> SeqJoin for Context<'a, B> { fn active_work(&self) -> NonZeroUsize { self.inner.active_work() } } -impl Debug for Context<'_> { +impl Debug for Context<'_, B> { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "MaliciousContext") } @@ -171,13 +173,13 @@ pub(super) type MacBatcher<'a, F> = Mutex { +pub struct Upgraded<'a, F: ExtendableField, B: ShardBinding> { batch: Weak>, - base_ctx: Context<'a>, + base_ctx: Context<'a, B>, } -impl<'a, F: ExtendableField> Upgraded<'a, F> { - pub(super) fn new(batch: &Arc>, ctx: Context<'a>) -> Self { +impl<'a, F: ExtendableField, B: ShardBinding> Upgraded<'a, F, B> { + pub(super) fn new(batch: &Arc>, ctx: Context<'a, B>) -> Self { Self { batch: Arc::downgrade(batch), base_ctx: ctx, @@ -227,7 +229,7 @@ impl<'a, F: ExtendableField> Upgraded<'a, F> { } #[async_trait] -impl<'a, F: ExtendableField> UpgradedContext for Upgraded<'a, F> { +impl<'a, F: ExtendableField, B: ShardBinding> UpgradedContext for Upgraded<'a, F, B> { type Field = F; async fn validate_record(&self, record_id: RecordId) -> Result<(), Error> { @@ -243,7 +245,7 @@ impl<'a, F: ExtendableField> UpgradedContext for Upgraded<'a, F> { } } -impl<'a, F: ExtendableField> super::Context for Upgraded<'a, F> { +impl<'a, F: ExtendableField, B: ShardBinding> super::Context for Upgraded<'a, F, B> { fn role(&self) -> Role { self.base_ctx.role() } @@ -295,7 +297,7 @@ impl<'a, F: ExtendableField> super::Context for Upgraded<'a, F> { } } -impl<'a, F: ExtendableField> SeqJoin for Upgraded<'a, F> { +impl<'a, F: ExtendableField, B: ShardBinding> SeqJoin for Upgraded<'a, F, B> { fn active_work(&self) -> NonZeroUsize { self.base_ctx.active_work() } @@ -305,7 +307,7 @@ impl<'a, F: ExtendableField> SeqJoin for Upgraded<'a, F> { /// protocols should be generic over `SecretShare` trait and not requiring this cast and taking /// `ProtocolContext<'a, S: SecretShare, F: Field>` as the context. If that is not possible, /// this implementation makes it easier to reinterpret the context as semi-honest. -impl<'a, F: ExtendableField> SpecialAccessToUpgradedContext for Upgraded<'a, F> { +impl<'a, F: ExtendableField> SpecialAccessToUpgradedContext for Upgraded<'a, F, NotSharded> { type Base = Base<'a>; fn base_context(self) -> Self::Base { @@ -313,7 +315,7 @@ impl<'a, F: ExtendableField> SpecialAccessToUpgradedContext for Upgraded<'a, } } -impl Debug for Upgraded<'_, F> { +impl Debug for Upgraded<'_, F, B> { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "MaliciousContext<{:?}>", type_name::()) } @@ -322,7 +324,8 @@ impl Debug for Upgraded<'_, F> { /// Upgrading a semi-honest replicated share using malicious context produces /// a MAC-secured share with the same vectorization factor. 
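 /// For example (illustrative only): upgrading a vectorized `Replicated<V, N>`
 /// sharing of `x` keeps all `N` lanes and attaches, per lane, a sharing of
 /// `r * x` under the MAC key `r`, so the vectorization factor is unchanged.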
#[async_trait] -impl<'a, V: ExtendableFieldSimd, const N: usize> Upgradable> for Replicated +impl<'a, V: ExtendableFieldSimd, const N: usize> Upgradable> + for Replicated where Replicated<::ExtendedField, N>: FromPrss, { @@ -330,7 +333,7 @@ where async fn upgrade( self, - ctx: Upgraded<'a, V>, + ctx: UpgradedMaliciousContext<'a, V>, record_id: RecordId, ) -> Result { let ctx = ctx.narrow(&UpgradeStep); @@ -364,7 +367,7 @@ where #[cfg(all(test, descriptive_gate))] #[async_trait] -impl<'a, V: ExtendableFieldSimd, const N: usize> Upgradable> +impl<'a, V: ExtendableFieldSimd, const N: usize> Upgradable> for (Replicated, Replicated) where Replicated<::ExtendedField, N>: FromPrss, @@ -373,7 +376,7 @@ where async fn upgrade( self, - ctx: Upgraded<'a, V>, + ctx: UpgradedMaliciousContext<'a, V>, record_id: RecordId, ) -> Result { let (l, r) = self; @@ -385,12 +388,12 @@ where #[cfg(all(test, descriptive_gate))] #[async_trait] -impl<'a, V: ExtendableField> Upgradable> for () { +impl<'a, V: ExtendableField> Upgradable> for () { type Output = (); async fn upgrade( self, - _context: Upgraded<'a, V>, + _context: UpgradedMaliciousContext<'a, V>, _record_id: RecordId, ) -> Result { Ok(()) @@ -399,28 +402,28 @@ impl<'a, V: ExtendableField> Upgradable> for () { #[cfg(all(test, descriptive_gate))] #[async_trait] -impl<'a, V, U> Upgradable> for Vec +impl<'a, V, U> Upgradable> for Vec where V: ExtendableField, - U: Upgradable, Output: Send> + Send + 'a, + U: Upgradable, Output: Send> + Send + 'a, { type Output = Vec; async fn upgrade( self, - ctx: Upgraded<'a, V>, + ctx: UpgradedMaliciousContext<'a, V>, record_id: RecordId, ) -> Result { /// Need a standalone function to avoid GAT issue that apparently can manifest /// even with `async_trait`. fn upgrade_vec<'a, V, U>( - ctx: Upgraded<'a, V>, + ctx: UpgradedMaliciousContext<'a, V>, record_id: RecordId, input: Vec, ) -> impl std::future::Future, Error>> + 'a where V: ExtendableField, - U: Upgradable> + 'a, + U: Upgradable> + 'a, { let mut upgraded = Vec::with_capacity(input.len()); async move { diff --git a/ipa-core/src/protocol/context/mod.rs b/ipa-core/src/protocol/context/mod.rs index 4a090bae3..5e1b6ebd7 100644 --- a/ipa-core/src/protocol/context/mod.rs +++ b/ipa-core/src/protocol/context/mod.rs @@ -18,15 +18,16 @@ pub use dzkp_malicious::DZKPUpgraded as DZKPUpgradedMaliciousContext; pub use dzkp_semi_honest::DZKPUpgraded as DZKPUpgradedSemiHonestContext; use futures::{stream, Stream, StreamExt}; use ipa_step::{Step, StepNarrow}; -pub use malicious::{ - Context as MaliciousContext, MaliciousProtocolSteps, Upgraded as UpgradedMaliciousContext, -}; +pub use malicious::MaliciousProtocolSteps; use prss::{InstrumentedIndexedSharedRandomness, InstrumentedSequentialSharedRandomness}; pub use semi_honest::Upgraded as UpgradedSemiHonestContext; pub use validator::Validator; pub type SemiHonestContext<'a, B = NotSharded> = semi_honest::Context<'a, B>; pub type ShardedSemiHonestContext<'a> = semi_honest::Context<'a, Sharded>; +pub type MaliciousContext<'a, B = NotSharded> = malicious::Context<'a, B>; +pub type UpgradedMaliciousContext<'a, F> = malicious::Upgraded<'a, F, NotSharded>; + #[cfg(all(feature = "in-memory-infra", any(test, feature = "test-fixture")))] pub(crate) use malicious::TEST_DZKP_STEPS; diff --git a/ipa-core/src/sharding.rs b/ipa-core/src/sharding.rs index 625f724e6..e4f9475b7 100644 --- a/ipa-core/src/sharding.rs +++ b/ipa-core/src/sharding.rs @@ -56,7 +56,7 @@ pub trait ShardConfiguration { } } -pub trait ShardBinding: Debug + Send + Sync + Clone {} 
+pub trait ShardBinding: Debug + Send + Sync + Clone + 'static {} #[derive(Debug, Copy, Clone)] pub struct NotSharded; diff --git a/ipa-core/src/test_fixture/world.rs b/ipa-core/src/test_fixture/world.rs index f92326c9b..cd6919580 100644 --- a/ipa-core/src/test_fixture/world.rs +++ b/ipa-core/src/test_fixture/world.rs @@ -780,7 +780,7 @@ impl ShardWorld { #[must_use] pub fn malicious_contexts(&self, gate: &Gate) -> [MaliciousContext<'_>; 3] { zip3_ref(&self.participants, &self.gateways).map(|(participant, gateway)| { - MaliciousContext::new_with_gate(participant, gateway, gate.clone()) + MaliciousContext::new_with_gate(participant, gateway, gate.clone(), NotSharded) }) } } From a4c6f03f8674328d1c879be10d25c89047c35606 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Wed, 25 Sep 2024 21:08:12 -0700 Subject: [PATCH 083/191] Make the DZKP batching proptest more comprehensive --- .../src/protocol/context/dzkp_validator.rs | 95 +++++++++++-------- 1 file changed, 56 insertions(+), 39 deletions(-) diff --git a/ipa-core/src/protocol/context/dzkp_validator.rs b/ipa-core/src/protocol/context/dzkp_validator.rs index 31821c592..87f442c77 100644 --- a/ipa-core/src/protocol/context/dzkp_validator.rs +++ b/ipa-core/src/protocol/context/dzkp_validator.rs @@ -827,27 +827,39 @@ mod tests { use bitvec::{order::Lsb0, prelude::BitArray, vec::BitVec}; use futures::{StreamExt, TryStreamExt}; use futures_util::stream::iter; - use proptest::{prelude::prop, prop_compose, proptest}; + use proptest::{prelude::Strategy, prop_oneof, proptest}; use rand::{distributions::Standard, prelude::Distribution}; use crate::{ - error::Error, ff::{ - boolean::Boolean, boolean_array::{BooleanArray, BA16, BA20, BA256, BA32, BA64, BA8}, Fp61BitPrime - }, protocol::{ + error::Error, + ff::{ + boolean::Boolean, + boolean_array::{BooleanArray, BA16, BA20, BA256, BA32, BA64, BA8}, + Fp61BitPrime, + }, + protocol::{ basics::{select, BooleanArrayMul, SecureMul}, context::{ - dzkp_field::{DZKPCompatibleField, BLOCK_SIZE}, dzkp_validator::{ + dzkp_field::{DZKPCompatibleField, BLOCK_SIZE}, + dzkp_validator::{ Batch, DZKPValidator, Segment, SegmentEntry, BIT_ARRAY_LEN, TARGET_PROOF_SIZE, - }, Context, DZKPUpgradedMaliciousContext, DZKPUpgradedSemiHonestContext, UpgradableContext, TEST_DZKP_STEPS + }, + Context, DZKPUpgradedMaliciousContext, DZKPUpgradedSemiHonestContext, + UpgradableContext, TEST_DZKP_STEPS, }, Gate, RecordId, - }, rand::{thread_rng, Rng}, secret_sharing::{ - replicated::semi_honest::AdditiveShare as Replicated, - IntoShares, SharedValue, Vectorizable, - }, seq_join::{seq_join, SeqJoin}, sharding::NotSharded, test_fixture::{join3v, Reconstruct, Runner, TestWorld} + }, + rand::{thread_rng, Rng}, + secret_sharing::{ + replicated::semi_honest::AdditiveShare as Replicated, IntoShares, SharedValue, + Vectorizable, + }, + seq_join::seq_join, + sharding::NotSharded, + test_fixture::{join3v, Reconstruct, Runner, TestWorld}, }; - async fn test_simplest_circuit_semi_honest() + async fn test_select_semi_honest() where V: BooleanArray, for<'a> Replicated: BooleanArrayMul>, @@ -893,16 +905,16 @@ mod tests { } #[tokio::test] - async fn simplest_circuit_semi_honest() { - test_simplest_circuit_semi_honest::().await; - test_simplest_circuit_semi_honest::().await; - test_simplest_circuit_semi_honest::().await; - test_simplest_circuit_semi_honest::().await; - test_simplest_circuit_semi_honest::().await; - test_simplest_circuit_semi_honest::().await; + async fn select_semi_honest() { + test_select_semi_honest::().await; + 
test_select_semi_honest::().await; + test_select_semi_honest::().await; + test_select_semi_honest::().await; + test_select_semi_honest::().await; + test_select_semi_honest::().await; } - async fn test_simplest_circuit_malicious() + async fn test_select_malicious() where V: BooleanArray, for<'a> Replicated: BooleanArrayMul>, @@ -948,17 +960,17 @@ mod tests { } #[tokio::test] - async fn simplest_circuit_malicious() { - test_simplest_circuit_malicious::().await; - test_simplest_circuit_malicious::().await; - test_simplest_circuit_malicious::().await; - test_simplest_circuit_malicious::().await; - test_simplest_circuit_malicious::().await; - test_simplest_circuit_malicious::().await; + async fn select_malicious() { + test_select_malicious::().await; + test_select_malicious::().await; + test_select_malicious::().await; + test_select_malicious::().await; + test_select_malicious::().await; + test_select_malicious::().await; } #[tokio::test] - async fn dzkp_malicious() { + async fn two_multiplies_malicious() { const COUNT: usize = 32; let mut rng = thread_rng(); @@ -1019,8 +1031,8 @@ mod tests { } /// test for testing `validated_seq_join` - /// similar to `complex_circuit` in `validator.rs` - async fn complex_circuit_dzkp( + /// similar to `complex_circuit` in `validator.rs` (which has a more detailed comment) + async fn chained_multiplies_dzkp( count: usize, max_multiplications_per_gate: usize, ) -> Result<(), Error> { @@ -1050,7 +1062,7 @@ mod tests { .map(|(ctx, input_shares)| async move { let v = ctx .set_total_records(count - 1) - .dzkp_validator(TEST_DZKP_STEPS, ctx.active_work().get()); + .dzkp_validator(TEST_DZKP_STEPS, max_multiplications_per_gate); let m_ctx = v.context(); let m_results = v @@ -1126,19 +1138,24 @@ mod tests { Ok(()) } - prop_compose! { - fn arb_count_and_chunk()((log_count, log_multiplication_amount) in prop::sample::select(&[(5,5),(7,5),(5,8)])) -> (usize, usize) { - (1usize< impl Strategy { + prop_oneof![1usize..=512, (1usize..=9).prop_map(|i| 1usize << i)] + } + + fn max_multiplications_per_gate_strategy() -> impl Strategy { + prop_oneof![1usize..=128, (1usize..=7).prop_map(|i| 1usize << i)] } proptest! 
{ #[test] - fn test_complex_circuit_dzkp((count, multiplication_amount) in arb_count_and_chunk()){ - let future = async { - let _ = complex_circuit_dzkp(count, multiplication_amount).await; - }; - tokio::runtime::Runtime::new().unwrap().block_on(future); + fn test_chained_multiplies_dzkp( + record_count in record_count_strategy(), + max_multiplications_per_gate in max_multiplications_per_gate_strategy(), + ) { + println!("record_count {record_count} batch {max_multiplications_per_gate}"); + tokio::runtime::Runtime::new().unwrap().block_on(async { + chained_multiplies_dzkp(record_count, max_multiplications_per_gate).await.unwrap(); + }); } } From 1a6e388449c6fbfe2b94f2166526d15969c89d23 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Wed, 25 Sep 2024 09:15:17 -0700 Subject: [PATCH 084/191] Make active_work at least records_per_batch --- ipa-core/src/protocol/context/batcher.rs | 4 ++++ ipa-core/src/protocol/context/dzkp_malicious.rs | 11 ++++++++++- ipa-core/src/protocol/context/malicious.rs | 15 ++++++++++++--- ipa-core/src/protocol/context/validator.rs | 2 +- ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs | 4 +--- 5 files changed, 28 insertions(+), 8 deletions(-) diff --git a/ipa-core/src/protocol/context/batcher.rs b/ipa-core/src/protocol/context/batcher.rs index f7021d988..2cd00be01 100644 --- a/ipa-core/src/protocol/context/batcher.rs +++ b/ipa-core/src/protocol/context/batcher.rs @@ -97,6 +97,10 @@ impl<'a, B> Batcher<'a, B> { self.total_records = self.total_records.overwrite(total_records.into()); } + pub fn records_per_batch(&self) -> usize { + self.records_per_batch + } + fn batch_offset(&self, record_id: RecordId) -> usize { let batch_index = usize::from(record_id) / self.records_per_batch; batch_index diff --git a/ipa-core/src/protocol/context/dzkp_malicious.rs b/ipa-core/src/protocol/context/dzkp_malicious.rs index 73cfda40c..fd37232b0 100644 --- a/ipa-core/src/protocol/context/dzkp_malicious.rs +++ b/ipa-core/src/protocol/context/dzkp_malicious.rs @@ -1,4 +1,5 @@ use std::{ + cmp::max, fmt::{Debug, Formatter}, num::NonZeroUsize, }; @@ -29,6 +30,7 @@ use crate::{ pub struct DZKPUpgraded<'a> { validator_inner: Weak>, base_ctx: MaliciousContext<'a>, + active_work: NonZeroUsize, } impl<'a> DZKPUpgraded<'a> { @@ -36,9 +38,16 @@ impl<'a> DZKPUpgraded<'a> { validator_inner: &Arc>, base_ctx: MaliciousContext<'a>, ) -> Self { + // Adjust active_work to be at least records_per_batch. If it is less, we will + // stall, since every record in the batch remains incomplete until the batch is + // validated. 
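+ // Editor's illustration (hypothetical numbers, not part of the original patch):
+ // with records_per_batch = 32 but active_work = 16, only records 0..16 enter the
+ // send window; none of them can finish until the whole batch of 32 is validated,
+ // and records 16..32 can never start, so the protocol deadlocks.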
+ let records_per_batch = validator_inner.batcher.lock().unwrap().records_per_batch(); + let active_work = + NonZeroUsize::new(max(base_ctx.active_work().get(), records_per_batch)).unwrap(); Self { validator_inner: Arc::downgrade(validator_inner), base_ctx, + active_work, } } @@ -130,7 +139,7 @@ impl<'a> super::Context for DZKPUpgraded<'a> { impl<'a> SeqJoin for DZKPUpgraded<'a> { fn active_work(&self) -> NonZeroUsize { - self.base_ctx.active_work() + self.active_work } } diff --git a/ipa-core/src/protocol/context/malicious.rs b/ipa-core/src/protocol/context/malicious.rs index a93a8edfb..ceee87fdf 100644 --- a/ipa-core/src/protocol/context/malicious.rs +++ b/ipa-core/src/protocol/context/malicious.rs @@ -1,5 +1,6 @@ use std::{ any::type_name, + cmp::max, fmt::{Debug, Formatter}, num::NonZeroUsize, }; @@ -174,13 +175,21 @@ pub(super) type MacBatcher<'a, F> = Mutex { batch: Weak>, base_ctx: Context<'a>, + active_work: NonZeroUsize, } impl<'a, F: ExtendableField> Upgraded<'a, F> { - pub(super) fn new(batch: &Arc>, ctx: Context<'a>) -> Self { + pub(super) fn new(batch: &Arc>, base_ctx: Context<'a>) -> Self { + // Adjust active_work to be at least records_per_batch. The MAC validator + // currently configures the batcher with records_per_batch = active_work, which + // makes this adjustment a no-op, but we do it this way to match the DZKP validator. + let records_per_batch = batch.lock().unwrap().records_per_batch(); + let active_work = + NonZeroUsize::new(max(base_ctx.active_work().get(), records_per_batch)).unwrap(); Self { batch: Arc::downgrade(batch), - base_ctx: ctx, + base_ctx, + active_work, } } @@ -297,7 +306,7 @@ impl<'a, F: ExtendableField> super::Context for Upgraded<'a, F> { impl<'a, F: ExtendableField> SeqJoin for Upgraded<'a, F> { fn active_work(&self) -> NonZeroUsize { - self.base_ctx.active_work() + self.active_work } } diff --git a/ipa-core/src/protocol/context/validator.rs b/ipa-core/src/protocol/context/validator.rs index 33303fb9b..e57ae3c6a 100644 --- a/ipa-core/src/protocol/context/validator.rs +++ b/ipa-core/src/protocol/context/validator.rs @@ -217,7 +217,7 @@ impl<'a, F: ExtendableField> BatchValidator<'a, F> { // TODO: Right now we set the batch work to be equal to active_work, // but it does not need to be. We can make this configurable if needed. 
- let records_per_batch = ctx.active_work().get().min(total_records.get());
+ let records_per_batch = ctx.active_work().get();

 Self {
 protocol_ctx: ctx.narrow(&Step::MaliciousProtocol),
diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs
index 267631da0..a745d5ea7 100644
--- a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs
+++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs
@@ -510,9 +510,7 @@ where
 protocol: &Step::Attribute,
 validate: &Step::AttributeValidate,
 },
- // The size of a single batch should not exceed the active work limit,
- // otherwise it will stall
- std::cmp::min(sh_ctx.active_work().get(), chunk_size),
+ chunk_size,
 );
 dzkp_validator.set_total_records(TotalRecords::specified(histogram[1]).unwrap());
 let ctx_for_row_number = set_up_contexts(&dzkp_validator.context(), histogram)?;

From a7b26f8753d1552a38287b0582c5e9ccb9c981d2 Mon Sep 17 00:00:00 2001
From: Alex Koshelev
Date: Thu, 26 Sep 2024 13:13:30 -0700
Subject: [PATCH 085/191] Print send config in tracing

It has proven useful in several investigations, so we should just enable
it.
---
 ipa-core/src/helpers/gateway/send.rs | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/ipa-core/src/helpers/gateway/send.rs b/ipa-core/src/helpers/gateway/send.rs
index fc73caf5d..07018fb14 100644
--- a/ipa-core/src/helpers/gateway/send.rs
+++ b/ipa-core/src/helpers/gateway/send.rs
@@ -203,10 +203,9 @@ impl GatewaySenders {
 match self.inner.entry(channel_id.clone()) {
 Entry::Occupied(entry) => Arc::clone(entry.get()),
 Entry::Vacant(entry) => {
- let sender = Self::new_sender(
- &SendChannelConfig::new::(config, total_records),
- channel_id.clone(),
- );
+ let config = SendChannelConfig::new::(config, total_records);
+ tracing::trace!("send configuration for {channel_id:?}: {config:?}");
+ let sender = Self::new_sender(&config, channel_id.clone());
 entry.insert(Arc::clone(&sender));

 tokio::spawn({
From 918f1430bda4e920955528a38297efc287e8c453 Mon Sep 17 00:00:00 2001
From: Alex Koshelev
Date: Thu, 26 Sep 2024 15:04:28 -0700
Subject: [PATCH 086/191] Allow active work to be overridden by contexts

Currently, send channels derive the active work from `GatewayConfig`.
#1306 highlighted a need to be able to reconfigure `active_work` based
on DZKP batch size. This PR makes that possible by allowing the
`get_mpc_sender` method to take an `active_work` parameter.
---
 ipa-core/src/helpers/gateway/mod.rs | 14 +++++++-
 .../src/helpers/gateway/stall_detection.rs | 4 ++-
 ipa-core/src/helpers/prss_protocol.rs | 12 +++++--
 ipa-core/src/protocol/context/mod.rs | 32 ++++++++++++++++---
 4 files changed, 54 insertions(+), 8 deletions(-)

diff --git a/ipa-core/src/helpers/gateway/mod.rs b/ipa-core/src/helpers/gateway/mod.rs
index b1c57b3cc..66f8326cb 100644
--- a/ipa-core/src/helpers/gateway/mod.rs
+++ b/ipa-core/src/helpers/gateway/mod.rs
@@ -150,12 +150,15 @@ impl Gateway {
 &self,
 channel_id: &HelperChannelId,
 total_records: TotalRecords,
+ active_work: NonZeroUsize,
 ) -> send::SendingEnd {
 let transport = &self.transports.mpc;
 let channel = self.inner.mpc_senders.get::(
 channel_id,
 transport,
- self.config,
+ // we override the active work provided in config if the caller
+ // wants to use a different value.
+ self.config.set_active_work(active_work),
 self.query_id,
 total_records,
 );
@@ -280,6 +283,15 @@ impl GatewayConfig {
 // we set active to be at least 2, so unwrap is fine.
self.active = NonZeroUsize::new(active).unwrap(); } + + /// Creates a new configuration by overriding the value of active work. + #[must_use] + pub fn set_active_work(&self, active_work: NonZeroUsize) -> Self { + Self { + active: active_work, + ..*self + } + } } #[cfg(all(test, unit_test))] diff --git a/ipa-core/src/helpers/gateway/stall_detection.rs b/ipa-core/src/helpers/gateway/stall_detection.rs index 49a879be4..43706f450 100644 --- a/ipa-core/src/helpers/gateway/stall_detection.rs +++ b/ipa-core/src/helpers/gateway/stall_detection.rs @@ -67,6 +67,7 @@ impl Observed { } mod gateway { + use std::num::NonZeroUsize; use delegate::delegate; @@ -153,12 +154,13 @@ mod gateway { &self, channel_id: &HelperChannelId, total_records: TotalRecords, + active_work: NonZeroUsize, ) -> SendingEnd { Observed::wrap( Weak::clone(self.get_sn()), self.inner() .gateway - .get_mpc_sender(channel_id, total_records), + .get_mpc_sender(channel_id, total_records, active_work), ) } diff --git a/ipa-core/src/helpers/prss_protocol.rs b/ipa-core/src/helpers/prss_protocol.rs index 8171ca019..f9284f9eb 100644 --- a/ipa-core/src/helpers/prss_protocol.rs +++ b/ipa-core/src/helpers/prss_protocol.rs @@ -21,8 +21,16 @@ pub async fn negotiate( let left_channel = ChannelId::new(gateway.role().peer(Direction::Left), gate.clone()); let right_channel = ChannelId::new(gateway.role().peer(Direction::Right), gate.clone()); - let left_sender = gateway.get_mpc_sender::(&left_channel, TotalRecords::ONE); - let right_sender = gateway.get_mpc_sender::(&right_channel, TotalRecords::ONE); + let left_sender = gateway.get_mpc_sender::( + &left_channel, + TotalRecords::ONE, + gateway.config().active_work(), + ); + let right_sender = gateway.get_mpc_sender::( + &right_channel, + TotalRecords::ONE, + gateway.config().active_work(), + ); let left_receiver = gateway.get_mpc_receiver::(&left_channel); let right_receiver = gateway.get_mpc_receiver::(&right_channel); diff --git a/ipa-core/src/protocol/context/mod.rs b/ipa-core/src/protocol/context/mod.rs index 4a090bae3..dfbc749be 100644 --- a/ipa-core/src/protocol/context/mod.rs +++ b/ipa-core/src/protocol/context/mod.rs @@ -162,6 +162,7 @@ pub struct Base<'a, B: ShardBinding = NotSharded> { inner: Inner<'a>, gate: Gate, total_records: TotalRecords, + active_work: NonZeroUsize, /// This indicates whether the system uses sharding or no. It's not ideal that we keep it here /// because it gets cloned often, a potential solution to that, if this shows up on flame graph, /// would be to move it to [`Inner`] struct. 
@@ -175,11 +176,30 @@ impl<'a, B: ShardBinding> Base<'a, B> { gate: Gate, total_records: TotalRecords, sharding: B, + ) -> Self { + Self::new_with_active_work( + participant, + gateway, + gate, + total_records, + gateway.config().active_work(), + sharding, + ) + } + + fn new_with_active_work( + participant: &'a PrssEndpoint, + gateway: &'a Gateway, + gate: Gate, + total_records: TotalRecords, + active_work: NonZeroUsize, + sharding: B, ) -> Self { Self { inner: Inner::new(participant, gateway), gate, total_records, + active_work, sharding, } } @@ -217,6 +237,7 @@ impl<'a, B: ShardBinding> Context for Base<'a, B> { inner: self.inner.clone(), gate: self.gate.narrow(step), total_records: self.total_records, + active_work: self.active_work, sharding: self.sharding.clone(), } } @@ -226,6 +247,7 @@ impl<'a, B: ShardBinding> Context for Base<'a, B> { inner: self.inner.clone(), gate: self.gate.clone(), total_records: self.total_records.overwrite(total_records), + active_work: self.active_work, sharding: self.sharding.clone(), } } @@ -254,9 +276,11 @@ impl<'a, B: ShardBinding> Context for Base<'a, B> { } fn send_channel(&self, role: Role) -> SendingEnd { - self.inner - .gateway - .get_mpc_sender(&ChannelId::new(role, self.gate.clone()), self.total_records) + self.inner.gateway.get_mpc_sender( + &ChannelId::new(role, self.gate.clone()), + self.total_records, + self.active_work, + ) } fn recv_channel(&self, role: Role) -> MpcReceivingEnd { @@ -322,7 +346,7 @@ impl ShardConfiguration for Base<'_, Sharded> { impl<'a, B: ShardBinding> SeqJoin for Base<'a, B> { fn active_work(&self) -> NonZeroUsize { - self.inner.gateway.config().active_work() + self.active_work } } From 706dcbea218d39e35e64fd7273561316bdf50546 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Thu, 26 Sep 2024 11:22:45 -0700 Subject: [PATCH 087/191] Improvements to batching proptest --- .../src/protocol/context/dzkp_validator.rs | 112 ++++++++++++++++-- 1 file changed, 100 insertions(+), 12 deletions(-) diff --git a/ipa-core/src/protocol/context/dzkp_validator.rs b/ipa-core/src/protocol/context/dzkp_validator.rs index 87f442c77..2ff635d02 100644 --- a/ipa-core/src/protocol/context/dzkp_validator.rs +++ b/ipa-core/src/protocol/context/dzkp_validator.rs @@ -825,9 +825,13 @@ mod tests { }; use bitvec::{order::Lsb0, prelude::BitArray, vec::BitVec}; - use futures::{StreamExt, TryStreamExt}; + use futures::{stream, StreamExt, TryStreamExt}; use futures_util::stream::iter; - use proptest::{prelude::Strategy, prop_oneof, proptest}; + use proptest::{ + prelude::{Just, Strategy}, + prop_compose, prop_oneof, proptest, + test_runner::Config as ProptestConfig, + }; use rand::{distributions::Standard, prelude::Distribution}; use crate::{ @@ -1030,6 +1034,51 @@ mod tests { } } + /// Similar to `test_select_malicious`, but operating on vectors + async fn multi_select_malicious(count: usize, max_multiplications_per_gate: usize) + where + V: BooleanArray, + for<'a> Replicated: BooleanArrayMul>, + Standard: Distribution, + { + let mut rng = thread_rng(); + + let bit: Vec = repeat_with(|| rng.gen::()).take(count).collect(); + let a: Vec = repeat_with(|| rng.gen()).take(count).collect(); + let b: Vec = repeat_with(|| rng.gen()).take(count).collect(); + + let [ab0, ab1, ab2]: [Vec>; 3] = TestWorld::default() + .malicious( + zip(bit.clone(), zip(a.clone(), b.clone())), + |ctx, inputs| async move { + let v = ctx + .set_total_records(count) + .dzkp_validator(TEST_DZKP_STEPS, max_multiplications_per_gate); + let m_ctx = v.context(); + + 
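+ // Editor's note: `validated_seq_join` both drives the multiplication futures
+ // and validates each batch once its `max_multiplications_per_gate` records
+ // complete, so this test needs no explicit `v.validate()` call.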
v.validated_seq_join(stream::iter(inputs).enumerate().map( + |(i, (bit_share, (a_share, b_share)))| { + let m_ctx = m_ctx.clone(); + async move { + select(m_ctx, RecordId::from(i), &bit_share, &a_share, &b_share) + .await + } + }, + )) + .try_collect() + .await + }, + ) + .await + .map(Result::unwrap); + + let ab: Vec = [ab0, ab1, ab2].reconstruct(); + + for i in 0..count { + assert_eq!(ab[i], if bit[i].into() { a[i] } else { b[i] }); + } + } + /// test for testing `validated_seq_join` /// similar to `complex_circuit` in `validator.rs` (which has a more detailed comment) async fn chained_multiplies_dzkp( @@ -1139,23 +1188,62 @@ mod tests { } fn record_count_strategy() -> impl Strategy { - prop_oneof![1usize..=512, (1usize..=9).prop_map(|i| 1usize << i)] + // The chained_multiplies test has count - 1 records, so 1 is not a valid input size. + // It is for multi_select though. + prop_oneof![2usize..=512, (1u32..=9).prop_map(|i| 1usize << i)] + } + + fn max_multiplications_per_gate_strategy(record_count: usize) -> impl Strategy { + let max_max_mults = record_count.min(128); + prop_oneof![ + 1usize..=max_max_mults, + (0u32..=max_max_mults.ilog2()).prop_map(|i| 1usize << i) + ] } - fn max_multiplications_per_gate_strategy() -> impl Strategy { - prop_oneof![1usize..=128, (1usize..=7).prop_map(|i| 1usize << i)] + prop_compose! { + fn batching() + (record_count in record_count_strategy()) + (record_count in Just(record_count), max_mults in max_multiplications_per_gate_strategy(record_count)) + -> (usize, usize) + { + (record_count, max_mults) + } } proptest! { + #![proptest_config(ProptestConfig::with_cases(50))] #[test] - fn test_chained_multiplies_dzkp( - record_count in record_count_strategy(), - max_multiplications_per_gate in max_multiplications_per_gate_strategy(), - ) { + fn batching_proptest((record_count, max_multiplications_per_gate) in batching()) { println!("record_count {record_count} batch {max_multiplications_per_gate}"); - tokio::runtime::Runtime::new().unwrap().block_on(async { - chained_multiplies_dzkp(record_count, max_multiplications_per_gate).await.unwrap(); - }); + if record_count / max_multiplications_per_gate >= 192 { + // TODO: #1269, or even if we don't fix that, don't hardcode the limit. + println!("skipping config because batch count exceeds limit of 192"); + } + // This condition is correct only for active_work = 16 and record size of 1 byte. + else if max_multiplications_per_gate != 1 && max_multiplications_per_gate % 16 != 0 { + // TODO: #1300, read_size | batch_size. + // Note: for active work < 2048, read size matches active work. + + // Besides read_size | batch_size, there is also a constraint + // something like active_work > read_size + batch_size - 1. + println!("skipping config due to read_size vs. 
batch_size constraints"); + } else { + tokio::runtime::Runtime::new().unwrap().block_on(async { + chained_multiplies_dzkp(record_count, max_multiplications_per_gate).await.unwrap(); + /* + multi_select_malicious::(record_count, max_multiplications_per_gate).await; + multi_select_malicious::(record_count, max_multiplications_per_gate).await; + multi_select_malicious::(record_count, max_multiplications_per_gate).await; + */ + multi_select_malicious::(record_count, max_multiplications_per_gate).await; + /* + multi_select_malicious::(record_count, max_multiplications_per_gate).await; + multi_select_malicious::(record_count, max_multiplications_per_gate).await; + multi_select_malicious::(record_count, max_multiplications_per_gate).await; + */ + }); + } } } From c5c0ccf8a35d5dbbb04a9161228eb5aef0d70bd8 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Thu, 26 Sep 2024 11:22:35 -0700 Subject: [PATCH 088/191] Revise based on PR feedback and offline discussion --- .../src/protocol/context/dzkp_malicious.rs | 25 ++++++++++++++----- .../src/protocol/context/dzkp_validator.rs | 4 ++- ipa-core/src/protocol/context/malicious.rs | 23 +++++++++-------- 3 files changed, 34 insertions(+), 18 deletions(-) diff --git a/ipa-core/src/protocol/context/dzkp_malicious.rs b/ipa-core/src/protocol/context/dzkp_malicious.rs index fd37232b0..70dd3d2af 100644 --- a/ipa-core/src/protocol/context/dzkp_malicious.rs +++ b/ipa-core/src/protocol/context/dzkp_malicious.rs @@ -1,5 +1,4 @@ use std::{ - cmp::max, fmt::{Debug, Formatter}, num::NonZeroUsize, }; @@ -38,12 +37,26 @@ impl<'a> DZKPUpgraded<'a> { validator_inner: &Arc>, base_ctx: MaliciousContext<'a>, ) -> Self { - // Adjust active_work to be at least records_per_batch. If it is less, we will - // stall, since every record in the batch remains incomplete until the batch is - // validated. let records_per_batch = validator_inner.batcher.lock().unwrap().records_per_batch(); - let active_work = - NonZeroUsize::new(max(base_ctx.active_work().get(), records_per_batch)).unwrap(); + let active_work = if records_per_batch == 1 { + // If records_per_batch is 1, let active_work be anything. This only happens + // in tests; there shouldn't be a risk of deadlocks with one record per + // batch; and UnorderedReceiver capacity (which is set from active_work) + // must be at least two. + base_ctx.active_work() + } else { + // Adjust active_work to match records_per_batch. If it is less, we will + // certainly stall, since every record in the batch remains incomplete until + // the batch is validated. It is possible that it can be larger, but making + // it the same seems safer for now. 
+ let active_work = NonZeroUsize::new(records_per_batch).unwrap(); + tracing::debug!( + "Changed active_work from {} to {} to match batch size", + base_ctx.active_work().get(), + active_work, + ); + active_work + }; Self { validator_inner: Arc::downgrade(validator_inner), base_ctx, diff --git a/ipa-core/src/protocol/context/dzkp_validator.rs b/ipa-core/src/protocol/context/dzkp_validator.rs index 2ff635d02..f5586336f 100644 --- a/ipa-core/src/protocol/context/dzkp_validator.rs +++ b/ipa-core/src/protocol/context/dzkp_validator.rs @@ -838,7 +838,7 @@ mod tests { error::Error, ff::{ boolean::Boolean, - boolean_array::{BooleanArray, BA16, BA20, BA256, BA32, BA64, BA8}, + boolean_array::{BooleanArray, BA16, BA20, BA256, BA3, BA32, BA64, BA8}, Fp61BitPrime, }, protocol::{ @@ -910,6 +910,7 @@ mod tests { #[tokio::test] async fn select_semi_honest() { + test_select_semi_honest::().await; test_select_semi_honest::().await; test_select_semi_honest::().await; test_select_semi_honest::().await; @@ -965,6 +966,7 @@ mod tests { #[tokio::test] async fn select_malicious() { + test_select_malicious::().await; test_select_malicious::().await; test_select_malicious::().await; test_select_malicious::().await; diff --git a/ipa-core/src/protocol/context/malicious.rs b/ipa-core/src/protocol/context/malicious.rs index ceee87fdf..18b9b8e29 100644 --- a/ipa-core/src/protocol/context/malicious.rs +++ b/ipa-core/src/protocol/context/malicious.rs @@ -1,6 +1,5 @@ use std::{ any::type_name, - cmp::max, fmt::{Debug, Formatter}, num::NonZeroUsize, }; @@ -175,21 +174,23 @@ pub(super) type MacBatcher<'a, F> = Mutex { batch: Weak>, base_ctx: Context<'a>, - active_work: NonZeroUsize, } impl<'a, F: ExtendableField> Upgraded<'a, F> { - pub(super) fn new(batch: &Arc>, base_ctx: Context<'a>) -> Self { - // Adjust active_work to be at least records_per_batch. The MAC validator - // currently configures the batcher with records_per_batch = active_work, which - // makes this adjustment a no-op, but we do it this way to match the DZKP validator. + pub(super) fn new(batch: &Arc>, ctx: Context<'a>) -> Self { + // The DZKP malicious context adjusts active_work to match records_per_batch. + // The MAC validator currently configures the batcher with records_per_batch = + // active_work. If the latter behavior changes, this code may need to be + // updated. 
let records_per_batch = batch.lock().unwrap().records_per_batch(); - let active_work = - NonZeroUsize::new(max(base_ctx.active_work().get(), records_per_batch)).unwrap(); + let active_work = ctx.active_work().get(); + assert_eq!( + records_per_batch, active_work, + "Expect MAC validation batch size ({records_per_batch}) to match active work ({active_work})", + ); Self { batch: Arc::downgrade(batch), - base_ctx, - active_work, + base_ctx: ctx, } } @@ -306,7 +307,7 @@ impl<'a, F: ExtendableField> super::Context for Upgraded<'a, F> { impl<'a, F: ExtendableField> SeqJoin for Upgraded<'a, F> { fn active_work(&self) -> NonZeroUsize { - self.active_work + self.base_ctx.active_work() } } From eea8b6a1ff6399df1e89e3b1702c08d9083139a1 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Thu, 26 Sep 2024 16:15:44 -0700 Subject: [PATCH 089/191] Add a unit test to make sure active work is adjusted correctly --- ipa-core/src/helpers/gateway/mod.rs | 49 ++++++++++++++++++++++++++--- 1 file changed, 45 insertions(+), 4 deletions(-) diff --git a/ipa-core/src/helpers/gateway/mod.rs b/ipa-core/src/helpers/gateway/mod.rs index 66f8326cb..e654f85f7 100644 --- a/ipa-core/src/helpers/gateway/mod.rs +++ b/ipa-core/src/helpers/gateway/mod.rs @@ -296,7 +296,10 @@ impl GatewayConfig { #[cfg(all(test, unit_test))] mod tests { - use std::iter::{repeat, zip}; + use std::{ + iter::{repeat, zip}, + num::NonZeroUsize, + }; use futures::{ future::{join, try_join, try_join_all}, @@ -305,12 +308,14 @@ mod tests { use crate::{ ff::{boolean_array::BA3, Fp31, Fp32BitPrime, Gf2, U128Conversions}, - helpers::{Direction, GatewayConfig, MpcMessage, Role, SendingEnd}, + helpers::{ + ChannelId, Direction, GatewayConfig, MpcMessage, Role, SendingEnd, TotalRecords, + }, protocol::{ context::{Context, ShardedContext}, - RecordId, + Gate, RecordId, }, - secret_sharing::replicated::semi_honest::AdditiveShare, + secret_sharing::{replicated::semi_honest::AdditiveShare, SharedValue}, sharding::ShardConfiguration, test_executor::run, test_fixture::{Reconstruct, Runner, TestWorld, TestWorldConfig, WithShards}, @@ -528,6 +533,42 @@ mod tests { }); } + #[test] + fn custom_active_work() { + run(|| async move { + let world = TestWorld::new_with(TestWorldConfig { + gateway_config: GatewayConfig { + active: 5.try_into().unwrap(), + ..Default::default() + }, + ..Default::default() + }); + let new_active_work = NonZeroUsize::new(3).unwrap(); + assert!(new_active_work < world.gateway(Role::H1).config().active_work()); + let sender = world.gateway(Role::H1).get_mpc_sender::( + &ChannelId::new(Role::H2, Gate::default()), + TotalRecords::specified(15).unwrap(), + new_active_work, + ); + try_join_all( + (0..new_active_work.get()) + .map(|record_id| sender.send(record_id.into(), BA3::ZERO)), + ) + .await + .unwrap(); + let recv = world.gateway(Role::H2).get_mpc_receiver::(&ChannelId { + peer: Role::H1, + gate: Gate::default(), + }); + // this will hang if the original active work is used + try_join_all( + (0..new_active_work.get()).map(|record_id| recv.receive(record_id.into())), + ) + .await + .unwrap(); + }); + } + async fn shard_comms_test(test_world: &TestWorld>) { let input = vec![BA3::truncate_from(0_u32), BA3::truncate_from(1_u32)]; From 2443420b0e3a6dcd0efa5935232d02f58fb08d50 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Thu, 26 Sep 2024 17:21:37 -0700 Subject: [PATCH 090/191] Improve code coverage --- ipa-core/src/net/http_serde.rs | 8 +++- .../src/net/server/handlers/query/kill.rs | 38 +++++++++++++++++-- 
 ipa-core/src/net/server/handlers/query/mod.rs | 13 +++++++
 ipa-core/src/query/processor.rs | 4 +-
 4 files changed, 56 insertions(+), 7 deletions(-)

diff --git a/ipa-core/src/net/http_serde.rs b/ipa-core/src/net/http_serde.rs
index f6902250f..1965c15ce 100644
--- a/ipa-core/src/net/http_serde.rs
+++ b/ipa-core/src/net/http_serde.rs
@@ -542,7 +542,6 @@ pub mod query {
 protocol::QueryId,
 };

- #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
 pub struct Request {
 pub query_id: QueryId,
 }
@@ -568,7 +567,12 @@ pub mod query {
 }

 impl Request {
- /// see above for the reason why this needs to be behind a feature flag
+ /// Currently, it is only possible to kill
+ /// a query by issuing an HTTP request manually.
+ /// The report collector may support this API
+ /// eventually, but for now only tests exercise
+ /// this path, hence the methods here are hidden
+ /// behind feature flags.
 #[cfg(all(test, unit_test))]
 pub fn new(query_id: QueryId) -> Self {
 Self { query_id }
diff --git a/ipa-core/src/net/server/handlers/query/kill.rs b/ipa-core/src/net/server/handlers/query/kill.rs
index 8a2188df7..aae68b993 100644
--- a/ipa-core/src/net/server/handlers/query/kill.rs
+++ b/ipa-core/src/net/server/handlers/query/kill.rs
@@ -47,14 +47,16 @@ mod tests {
 helpers::{
 make_owned_handler,
 routing::{Addr, RouteId},
- BodyStream, HelperIdentity, HelperResponse,
+ ApiError, BodyStream, HelperIdentity, HelperResponse,
 },
 net::{
 http_serde,
- server::handlers::query::test_helpers::{assert_fails_with, assert_success_with},
+ server::handlers::query::test_helpers::{
+ assert_fails_with, assert_fails_with_handler, assert_success_with,
+ },
 },
 protocol::QueryId,
- query::QueryKilled,
+ query::{QueryKillStatus, QueryKilled},
 };

 #[tokio::test]
@@ -78,6 +80,36 @@ mod tests {
 assert_success_with(req, handler).await;
 }

+ #[tokio::test]
+ async fn no_such_query() {
+ let handler = make_owned_handler(
+ move |_addr: Addr, _data: BodyStream| async move {
+ Err(QueryKillStatus::NoSuchQuery(QueryId).into())
+ },
+ );
+
+ let req = http_serde::query::kill::Request::new(QueryId)
+ .try_into_http_request(Scheme::HTTP, Authority::from_static("localhost"))
+ .unwrap();
+ assert_fails_with_handler(req, handler, StatusCode::NOT_FOUND).await;
+ }
+
+ #[tokio::test]
+ async fn unknown_error() {
+ let handler = make_owned_handler(
+ move |_addr: Addr, _data: BodyStream| async move {
+ Err(ApiError::DeserializationFailure(
+ serde_json::from_str::<()>("not-a-json").unwrap_err(),
+ ))
+ },
+ );
+
+ let req = http_serde::query::kill::Request::new(QueryId)
+ .try_into_http_request(Scheme::HTTP, Authority::from_static("localhost"))
+ .unwrap();
+ assert_fails_with_handler(req, handler, StatusCode::INTERNAL_SERVER_ERROR).await;
+ }
+
 struct OverrideReq {
 query_id: String,
 }
diff --git a/ipa-core/src/net/server/handlers/query/mod.rs b/ipa-core/src/net/server/handlers/query/mod.rs
index 92890bd57..616308eea 100644
--- a/ipa-core/src/net/server/handlers/query/mod.rs
+++ b/ipa-core/src/net/server/handlers/query/mod.rs
@@ -141,6 +141,19 @@ pub mod test_helpers {
 assert_eq!(resp.status(), expected_status);
 }

+ pub async fn assert_fails_with_handler(
+ req: hyper::Request,
+ handler: Arc>,
+ expected_status: StatusCode,
+ ) {
+ let test_server = TestServer::builder()
+ .with_request_handler(handler)
+ .build()
+ .await;
+ let resp = test_server.server.handle_req(req).await;
+ assert_eq!(resp.status(), expected_status);
+ }
+
 pub async fn assert_success_with(
 req: hyper::Request,
 handler: Arc>,
diff --git a/ipa-core/src/query/processor.rs 
b/ipa-core/src/query/processor.rs index 9734eab2d..a8694012e 100644 --- a/ipa-core/src/query/processor.rs +++ b/ipa-core/src/query/processor.rs @@ -5,7 +5,7 @@ use std::{ }; use futures::{future::try_join, stream}; -use serde::{Deserialize, Serialize}; +use serde::Serialize; use crate::{ error::Error as ProtocolError, @@ -352,7 +352,7 @@ impl Processor { } } -#[derive(Clone, Debug, Serialize, Deserialize)] +#[derive(Clone, Serialize)] pub struct QueryKilled(pub QueryId); #[derive(thiserror::Error, Debug)] From 61f37abba209e7bc7ab956422002bdc575013065 Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Thu, 26 Sep 2024 18:06:55 -0700 Subject: [PATCH 091/191] Relaxing this a bit --- ipa-core/src/test_fixture/hybrid_event_gen.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ipa-core/src/test_fixture/hybrid_event_gen.rs b/ipa-core/src/test_fixture/hybrid_event_gen.rs index bc1a463dd..792b1cc37 100644 --- a/ipa-core/src/test_fixture/hybrid_event_gen.rs +++ b/ipa-core/src/test_fixture/hybrid_event_gen.rs @@ -186,8 +186,8 @@ mod tests { const EXPECTED_HISTOGRAM_WITH_TOLERANCE: [(i32, f64); 12] = [ (0, 0.0), (647_634, 0.01), - (137_626, 0.01), - (20_652, 0.02), + (137_626, 0.02), + (20_652, 0.03), (3_085, 0.05), (463, 0.12), (70, 0.5), From 966160fcdd12effd9687a159768e25cc6df179a6 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Thu, 26 Sep 2024 21:20:45 -0700 Subject: [PATCH 092/191] Fix a bug in the batcher --- ipa-core/src/protocol/context/batcher.rs | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/ipa-core/src/protocol/context/batcher.rs b/ipa-core/src/protocol/context/batcher.rs index 2cd00be01..cdbfddbce 100644 --- a/ipa-core/src/protocol/context/batcher.rs +++ b/ipa-core/src/protocol/context/batcher.rs @@ -114,7 +114,7 @@ impl<'a, B> Batcher<'a, B> { while self.batches.len() <= batch_offset { let (validation_result, _) = watch::channel::(false); let state = BatchState { - batch: (self.batch_constructor)(self.first_batch + batch_offset), + batch: (self.batch_constructor)(self.first_batch + self.batches.len()), validation_result, pending_count: 0, pending_records: bitvec![0; self.records_per_batch], @@ -296,6 +296,23 @@ mod tests { ); } + #[test] + fn makes_batches_out_of_order() { + // Regression test for a bug where, when adding batches i..j to fill in a gap in + // the batch deque prior to out-of-order requested batch j, the batcher passed + // batch index `j` to the constructor for all of them, as opposed to the correct + // sequence of indices i..=j. 
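+ // Editor's note: concretely, the fix above replaced
+ //   (self.batch_constructor)(self.first_batch + batch_offset)
+ // with
+ //   (self.batch_constructor)(self.first_batch + self.batches.len()),
+ // so each gap-filling batch is constructed with its own index.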
+
+ let batcher = Batcher::new(1, 2, Box::new(std::convert::identity));
+ let mut batcher = batcher.lock().unwrap();
+
+ batcher.get_batch(RecordId::from(1));
+ batcher.get_batch(RecordId::from(0));
+
+ assert_eq!(batcher.get_batch(RecordId::from(0)).batch, 0);
+ assert_eq!(batcher.get_batch(RecordId::from(1)).batch, 1);
+ }
+
 #[tokio::test]
 async fn validates_batches() {
 let batcher = Batcher::new(2, 4, Box::new(|_| Vec::new()));
From d2512f1a2cfa24c2af58d7700f534586b3cc89ef Mon Sep 17 00:00:00 2001
From: Andy Leiserson
Date: Fri, 27 Sep 2024 09:43:32 -0700
Subject: [PATCH 093/191] Restore the batch size kludge for attribution

---
 ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs
index a745d5ea7..9a1f8f278 100644
--- a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs
+++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs
@@ -510,7 +510,9 @@ where
 protocol: &Step::Attribute,
 validate: &Step::AttributeValidate,
 },
- chunk_size,
+ // TODO: this should not be necessary, but probably can't be removed
+ // until we align read_size with the batch size.
+ std::cmp::min(sh_ctx.active_work().get(), chunk_size),
 );
 dzkp_validator.set_total_records(TotalRecords::specified(histogram[1]).unwrap());
 let ctx_for_row_number = set_up_contexts(&dzkp_validator.context(), histogram)?;
From 75161ad275d02f6add55c016fa81f5cc3b4c943d Mon Sep 17 00:00:00 2001
From: Alex Koshelev
Date: Fri, 27 Sep 2024 10:59:49 -0700
Subject: [PATCH 094/191] Increase number of steps for DZKP proof to 600 from
 192

Running Draft on 50M records revealed that we don't have a sufficient
number of steps for zero-knowledge proof runs. I tried to keep the step
limit under 20k, so I did some cleanup as well, but I think I failed and
I had to bump it.
---
 ipa-core/src/protocol/context/step.rs | 2 +-
 ipa-core/src/protocol/ipa_prf/aggregation/step.rs | 8 --------
 ipa-core/src/protocol/ipa_prf/mod.rs | 2 +-
 ipa-core/src/protocol/step.rs | 10 ----------
 4 files changed, 2 insertions(+), 20 deletions(-)

diff --git a/ipa-core/src/protocol/context/step.rs b/ipa-core/src/protocol/context/step.rs
index d650340f8..aeb6bd76f 100644
--- a/ipa-core/src/protocol/context/step.rs
+++ b/ipa-core/src/protocol/context/step.rs
@@ -31,7 +31,7 @@ pub(crate) enum ValidateStep {
 // This really is only for DZKPs and not for MACs. The MAC protocol uses record IDs to
 // count batches. DZKP probably should do the same to avoid the fixed upper limit.
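// Editor's note: the batch count consumed here is roughly
// ceil(total_records / records_per_batch); e.g. a 50M-record run with a
// hypothetical batch size of 100k records needs 500 batch steps, over the
// old limit of 192 but within the new 600.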
#[derive(CompactStep)] -#[step(count = 192, child = DzkpValidationProtocolStep)] +#[step(count = 600, child = DzkpValidationProtocolStep)] pub(crate) struct DzkpBatchStep(pub usize); // This is used when we don't do batched verification, to avoid paying for x256 as many diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/step.rs b/ipa-core/src/protocol/ipa_prf/aggregation/step.rs index fcb619a45..3c7b5da95 100644 --- a/ipa-core/src/protocol/ipa_prf/aggregation/step.rs +++ b/ipa-core/src/protocol/ipa_prf/aggregation/step.rs @@ -12,16 +12,8 @@ pub(crate) enum AggregationStep { RevealStep, #[step(child = AggregateChunkStep)] SumContributions, - #[step(child = crate::protocol::context::step::DzkpBatchStep)] - AggregateValidate, } -/// the number of steps must be kept in sync with `MAX_BREAKDOWNS` defined -/// [here](https://tinyurl.com/mwnbbnj6) -#[derive(CompactStep)] -#[step(count = 512, child = crate::protocol::boolean::step::EightBitStep, name = "b")] -pub struct BucketStep(usize); - #[derive(CompactStep)] #[step(count = 32, child = AggregateValuesStep, name = "depth")] pub(crate) struct AggregateChunkStep(usize); diff --git a/ipa-core/src/protocol/ipa_prf/mod.rs b/ipa-core/src/protocol/ipa_prf/mod.rs index 0462e31d6..cc3fa2633 100644 --- a/ipa-core/src/protocol/ipa_prf/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/mod.rs @@ -744,7 +744,7 @@ mod compact_gate_tests { fn step_count_limit() { // This is an arbitrary limit intended to catch changes that unintentionally // blow up the step count. It can be increased, within reason. - const STEP_COUNT_LIMIT: u32 = 20_000; + const STEP_COUNT_LIMIT: u32 = 24_000; assert!( ProtocolStep::STEP_COUNT < STEP_COUNT_LIMIT, "Step count of {actual} exceeds limit of {STEP_COUNT_LIMIT}.", diff --git a/ipa-core/src/protocol/step.rs b/ipa-core/src/protocol/step.rs index c19bdb53b..8346557d2 100644 --- a/ipa-core/src/protocol/step.rs +++ b/ipa-core/src/protocol/step.rs @@ -31,16 +31,6 @@ impl<'de> serde::Deserialize<'de> for ProtocolGate { #[derive(CompactStep)] pub enum DeadCodeStep { - #[step(child = crate::protocol::basics::step::CheckZeroStep)] - CheckZero, - #[step(child = crate::protocol::basics::mul::step::MaliciousMultiplyStep)] - MaliciousMultiply, - #[step(child = crate::protocol::context::step::UpgradeStep)] - UpgradeShare, - #[step(child = crate::protocol::context::step::MaliciousProtocolStep)] - MaliciousProtocol, - #[step(child = crate::protocol::context::step::ValidateStep)] - MaliciousValidation, #[step(child = crate::protocol::ipa_prf::boolean_ops::step::SaturatedSubtractionStep)] SaturatedSubtraction, #[step(child = crate::protocol::ipa_prf::prf_sharding::step::FeatureLabelDotProductStep)] From e8ad98f15cfff538dd142039659c3ee5b357c8a6 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Fri, 27 Sep 2024 11:55:58 -0700 Subject: [PATCH 095/191] Remove `active_work` field from ZKP malicious context We don't need it because active work is supplied by base context --- .../src/protocol/context/dzkp_malicious.rs | 10 ++++--- ipa-core/src/protocol/context/malicious.rs | 7 +++++ ipa-core/src/protocol/context/mod.rs | 26 ++++++------------- 3 files changed, 21 insertions(+), 22 deletions(-) diff --git a/ipa-core/src/protocol/context/dzkp_malicious.rs b/ipa-core/src/protocol/context/dzkp_malicious.rs index 70dd3d2af..9f28239ba 100644 --- a/ipa-core/src/protocol/context/dzkp_malicious.rs +++ b/ipa-core/src/protocol/context/dzkp_malicious.rs @@ -29,7 +29,6 @@ use crate::{ pub struct DZKPUpgraded<'a> { validator_inner: Weak>, base_ctx: MaliciousContext<'a>, - 
active_work: NonZeroUsize, } impl<'a> DZKPUpgraded<'a> { @@ -59,8 +58,11 @@ impl<'a> DZKPUpgraded<'a> { }; Self { validator_inner: Arc::downgrade(validator_inner), - base_ctx, - active_work, + // This overrides the active work for this context and all children + // created from it by using narrow, clone, etc. + // This allows all steps participating in malicious validation + // to use the same active work window and prevent deadlocks + base_ctx: base_ctx.set_active_work(active_work), } } @@ -152,7 +154,7 @@ impl<'a> super::Context for DZKPUpgraded<'a> { impl<'a> SeqJoin for DZKPUpgraded<'a> { fn active_work(&self) -> NonZeroUsize { - self.active_work + self.base_ctx.active_work() } } diff --git a/ipa-core/src/protocol/context/malicious.rs b/ipa-core/src/protocol/context/malicious.rs index 18b9b8e29..8c287b1f2 100644 --- a/ipa-core/src/protocol/context/malicious.rs +++ b/ipa-core/src/protocol/context/malicious.rs @@ -78,6 +78,13 @@ impl<'a> Context<'a> { ..self.inner } } + + #[must_use] + pub fn set_active_work(self, new_active_work: NonZeroUsize) -> Self { + Self { + inner: self.inner.set_active_work(new_active_work), + } + } } impl<'a> super::Context for Context<'a> { diff --git a/ipa-core/src/protocol/context/mod.rs b/ipa-core/src/protocol/context/mod.rs index dfbc749be..eead81a16 100644 --- a/ipa-core/src/protocol/context/mod.rs +++ b/ipa-core/src/protocol/context/mod.rs @@ -177,30 +177,20 @@ impl<'a, B: ShardBinding> Base<'a, B> { total_records: TotalRecords, sharding: B, ) -> Self { - Self::new_with_active_work( - participant, - gateway, + Self { + inner: Inner::new(participant, gateway), gate, total_records, - gateway.config().active_work(), + active_work: gateway.config().active_work(), sharding, - ) + } } - fn new_with_active_work( - participant: &'a PrssEndpoint, - gateway: &'a Gateway, - gate: Gate, - total_records: TotalRecords, - active_work: NonZeroUsize, - sharding: B, - ) -> Self { + #[must_use] + pub fn set_active_work(self, new_active_work: NonZeroUsize) -> Self { Self { - inner: Inner::new(participant, gateway), - gate, - total_records, - active_work, - sharding, + active_work: new_active_work, + ..self.clone() } } } From 03c369216c32b08c0d611fbd85cd0622faf77c22 Mon Sep 17 00:00:00 2001 From: Christian Berkhoff Date: Mon, 30 Sep 2024 13:50:23 -0700 Subject: [PATCH 096/191] Pinning shuffle futures that are too large for Shuttle --- ipa-core/src/protocol/ipa_prf/shuffle/base.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/base.rs b/ipa-core/src/protocol/ipa_prf/shuffle/base.rs index a48cbd128..d999db229 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/base.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/base.rs @@ -62,9 +62,9 @@ where let zs = generate_random_tables_with_peers(shares_len, &ctx_z); match ctx.role() { - Role::H1 => run_h1(&ctx, shares_len, shares, zs).await, - Role::H2 => run_h2(&ctx, shares_len, shares, zs).await, - Role::H3 => run_h3(&ctx, shares_len, zs).await, + Role::H1 => Box::pin(run_h1(&ctx, shares_len, shares, zs)).await, + Role::H2 => Box::pin(run_h2(&ctx, shares_len, shares, zs)).await, + Role::H3 => Box::pin(run_h3(&ctx, shares_len, zs)).await, } } From d42df8b0edde3371f621bcde46f4857b28bf5975 Mon Sep 17 00:00:00 2001 From: Christian Berkhoff Date: Tue, 1 Oct 2024 10:32:09 -0700 Subject: [PATCH 097/191] Lowering future size threshold to 8kb --- .clippy.toml | 2 ++ ipa-core/src/protocol/context/dzkp_validator.rs | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git 
a/.clippy.toml b/.clippy.toml index 5c572f532..9ed6287e0 100644 --- a/.clippy.toml +++ b/.clippy.toml @@ -7,3 +7,5 @@ disallowed-methods = [ { path = "std::mem::ManuallyDrop::new", reason = "Not running the destructors on futures created inside seq_join module will cause UB in IPA. Make sure you don't leak any of those." }, { path = "std::vec::Vec::leak", reason = "Not running the destructors on futures created inside seq_join module will cause UB in IPA. Make sure you don't leak any of those." }, ] + +future-size-threshold = 8192 \ No newline at end of file diff --git a/ipa-core/src/protocol/context/dzkp_validator.rs b/ipa-core/src/protocol/context/dzkp_validator.rs index f5586336f..6510d216f 100644 --- a/ipa-core/src/protocol/context/dzkp_validator.rs +++ b/ipa-core/src/protocol/context/dzkp_validator.rs @@ -972,7 +972,7 @@ mod tests { test_select_malicious::().await; test_select_malicious::().await; test_select_malicious::().await; - test_select_malicious::().await; + Box::pin(test_select_malicious::()).await; } #[tokio::test] From b54d3da79773a00fe5a9bfb8a6b81b0de9109b38 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Tue, 1 Oct 2024 17:15:42 -0700 Subject: [PATCH 098/191] Address feedback --- ipa-core/src/protocol/context/malicious.rs | 23 +++++++++++----------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/ipa-core/src/protocol/context/malicious.rs b/ipa-core/src/protocol/context/malicious.rs index 36ffb48bb..2c998c8ab 100644 --- a/ipa-core/src/protocol/context/malicious.rs +++ b/ipa-core/src/protocol/context/malicious.rs @@ -21,7 +21,6 @@ use crate::{ validator::{self, BatchValidator}, Base, Context as ContextTrait, InstrumentedSequentialSharedRandomness, SpecialAccessToUpgradedContext, UpgradableContext, UpgradedContext, - UpgradedMaliciousContext, }, prss::{Endpoint as PrssEndpoint, FromPrss}, Gate, RecordId, @@ -324,7 +323,7 @@ impl Debug for Upgraded<'_, F, B> { /// Upgrading a semi-honest replicated share using malicious context produces /// a MAC-secured share with the same vectorization factor. 
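/// Editor's sketch of a hypothetical call site (assumes an upgraded
/// malicious context `m_ctx` and a semi-honest share `x`):
/// ```ignore
/// let x_mac = x.upgrade(m_ctx, RecordId::FIRST).await?;
/// ```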
#[async_trait] -impl<'a, V: ExtendableFieldSimd, const N: usize> Upgradable> +impl<'a, V: ExtendableFieldSimd, const N: usize> Upgradable> for Replicated where Replicated<::ExtendedField, N>: FromPrss, @@ -333,7 +332,7 @@ where async fn upgrade( self, - ctx: UpgradedMaliciousContext<'a, V>, + ctx: Upgraded<'a, V, NotSharded>, record_id: RecordId, ) -> Result { let ctx = ctx.narrow(&UpgradeStep); @@ -367,7 +366,7 @@ where #[cfg(all(test, descriptive_gate))] #[async_trait] -impl<'a, V: ExtendableFieldSimd, const N: usize> Upgradable> +impl<'a, V: ExtendableFieldSimd, const N: usize> Upgradable> for (Replicated, Replicated) where Replicated<::ExtendedField, N>: FromPrss, @@ -376,7 +375,7 @@ where async fn upgrade( self, - ctx: UpgradedMaliciousContext<'a, V>, + ctx: Upgraded<'a, V, NotSharded>, record_id: RecordId, ) -> Result { let (l, r) = self; @@ -388,12 +387,12 @@ where #[cfg(all(test, descriptive_gate))] #[async_trait] -impl<'a, V: ExtendableField> Upgradable> for () { +impl<'a, V: ExtendableField> Upgradable> for () { type Output = (); async fn upgrade( self, - _context: UpgradedMaliciousContext<'a, V>, + _context: Upgraded<'a, V, NotSharded>, _record_id: RecordId, ) -> Result { Ok(()) @@ -402,28 +401,28 @@ impl<'a, V: ExtendableField> Upgradable> for () #[cfg(all(test, descriptive_gate))] #[async_trait] -impl<'a, V, U> Upgradable> for Vec +impl<'a, V, U> Upgradable> for Vec where V: ExtendableField, - U: Upgradable, Output: Send> + Send + 'a, + U: Upgradable, Output: Send> + Send + 'a, { type Output = Vec; async fn upgrade( self, - ctx: UpgradedMaliciousContext<'a, V>, + ctx: Upgraded<'a, V, NotSharded>, record_id: RecordId, ) -> Result { /// Need a standalone function to avoid GAT issue that apparently can manifest /// even with `async_trait`. 
fn upgrade_vec<'a, V, U>( - ctx: UpgradedMaliciousContext<'a, V>, + ctx: Upgraded<'a, V, NotSharded>, record_id: RecordId, input: Vec, ) -> impl std::future::Future, Error>> + 'a where V: ExtendableField, - U: Upgradable> + 'a, + U: Upgradable> + 'a, { let mut upgraded = Vec::with_capacity(input.len()); async move { From 46792f67ad45be58557de9b59b1abf535edde15e Mon Sep 17 00:00:00 2001 From: Christian Berkhoff Date: Tue, 1 Oct 2024 20:31:40 -0700 Subject: [PATCH 099/191] Updating tonic to prevent RUSTSEC-2024-0376 --- ipa-core/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ipa-core/Cargo.toml b/ipa-core/Cargo.toml index 4a495d50e..49348843d 100644 --- a/ipa-core/Cargo.toml +++ b/ipa-core/Cargo.toml @@ -94,7 +94,7 @@ bytes = "1.4" clap = { version = "4.3.2", optional = true, features = ["derive"] } comfy-table = { version = "7.0", optional = true } config = "0.14" -console-subscriber = { version = "0.2", optional = true } +console-subscriber = { version = "0.4", optional = true } criterion = { version = "0.5.1", optional = true, default-features = false, features = [ "async_tokio", "plotters", From bdcd5a6563ca7444baf4fda7496824daf810c889 Mon Sep 17 00:00:00 2001 From: Christian Berkhoff Date: Tue, 1 Oct 2024 20:35:04 -0700 Subject: [PATCH 100/191] Boxing large future --- ipa-core/src/protocol/context/dzkp_validator.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ipa-core/src/protocol/context/dzkp_validator.rs b/ipa-core/src/protocol/context/dzkp_validator.rs index 6510d216f..835a32e9d 100644 --- a/ipa-core/src/protocol/context/dzkp_validator.rs +++ b/ipa-core/src/protocol/context/dzkp_validator.rs @@ -916,7 +916,7 @@ mod tests { test_select_semi_honest::().await; test_select_semi_honest::().await; test_select_semi_honest::().await; - test_select_semi_honest::().await; + Box::pin(test_select_semi_honest::()).await; } async fn test_select_malicious() From c8a02812aa16c8ba67fc30267ffc6a66fef826ba Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Wed, 2 Oct 2024 12:41:10 -0700 Subject: [PATCH 101/191] Plumb padding parameters to breakdown_reveal_aggregation Use the relaxed parameters for most tests, and disable padding for the end-to-end malicious test when running under shuttle. 
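
Editor's note: a minimal sketch of the call-site pattern this change
enables. The names below (`breakdown_reveal_aggregation` with its
`<_, BA5, BA3, BA8, 32>` generics, `PaddingParameters::relaxed`,
`PaddingParameters::no_padding`) are taken from the diff; the
surrounding wiring is illustrative only.

```
// Choose cheaper padding under shuttle to keep runs within its step limit.
let padding_params = if cfg!(feature = "shuttle") {
    PaddingParameters::no_padding()
} else {
    PaddingParameters::relaxed()
};
let aggregated = breakdown_reveal_aggregation::<_, BA5, BA3, BA8, 32>(
    ctx,
    attributed_values,
    &padding_params,
)
.await?;
```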
--- .../ipa_prf/aggregation/breakdown_reveal.rs | 21 ++++++++----- ipa-core/src/protocol/ipa_prf/mod.rs | 14 +++++++-- .../src/protocol/ipa_prf/oprf_padding/mod.rs | 8 ++--- .../src/protocol/ipa_prf/prf_sharding/mod.rs | 31 +++++++++++++++---- 4 files changed, 53 insertions(+), 21 deletions(-) diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs b/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs index 0f5269bbf..a643d397d 100644 --- a/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs +++ b/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs @@ -52,6 +52,7 @@ use crate::{ pub async fn breakdown_reveal_aggregation( ctx: C, attributed_values: Vec>, + padding_params: &PaddingParameters, ) -> Result>, Error> where C: Context, @@ -63,13 +64,12 @@ where BitDecomposed>: for<'a> TransposeFrom<&'a [Replicated; B], Error = Infallible>, { - let dp_padding_params = PaddingParameters::default(); // Apply DP padding for Breakdown Reveal Aggregation let attributed_values_padded = apply_dp_padding::<_, AttributionOutputs, Replicated>, B>( ctx.narrow(&AggregationStep::PaddingDp), attributed_values, - dp_padding_params, + padding_params, ) .await?; @@ -205,6 +205,7 @@ pub mod tests { }, protocol::ipa_prf::{ aggregation::breakdown_reveal::breakdown_reveal_aggregation, + oprf_padding::PaddingParameters, prf_sharding::{AttributionOutputsTestInput, SecretSharedAttributionOutputs}, }, secret_sharing::{ @@ -257,12 +258,16 @@ pub mod tests { }) .collect(); let r: Vec> = - breakdown_reveal_aggregation::<_, BA5, BA3, BA8, 32>(ctx, aos) - .map_ok(|d: BitDecomposed>| { - Vec::transposed_from(&d).unwrap() - }) - .await - .unwrap(); + breakdown_reveal_aggregation::<_, BA5, BA3, BA8, 32>( + ctx, + aos, + &PaddingParameters::relaxed(), + ) + .map_ok(|d: BitDecomposed>| { + Vec::transposed_from(&d).unwrap() + }) + .await + .unwrap(); r }) .await diff --git a/ipa-core/src/protocol/ipa_prf/mod.rs b/ipa-core/src/protocol/ipa_prf/mod.rs index cc3fa2633..a0f051e50 100644 --- a/ipa-core/src/protocol/ipa_prf/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/mod.rs @@ -269,7 +269,7 @@ where let padded_input_rows = apply_dp_padding::<_, OPRFIPAInputRow, B>( ctx.narrow(&Step::PaddingDp), input_rows, - dp_padding_params, + &dp_padding_params, ) .await?; @@ -297,6 +297,7 @@ where prfd_inputs, attribution_window_seconds, &row_count_histogram, + &dp_padding_params, ) .await?; @@ -449,7 +450,14 @@ pub mod tests { ]; // trigger value of 2 attributes to earlier source row with breakdown 1 and trigger // value of 5 attributes to source row with breakdown 2. let dp_params = DpMechanism::NoDp; - let padding_params = PaddingParameters::relaxed(); + let padding_params = if cfg!(feature = "shuttle") { + // To reduce runtime. There is also a hard upper limit in the shuttle + // config (`max_steps`), that may need to be increased to support larger + // runs. + PaddingParameters::no_padding() + } else { + PaddingParameters::relaxed() + }; let mut result: Vec<_> = world .semi_honest(records.into_iter(), |ctx, input_rows| async move { @@ -489,7 +497,7 @@ pub mod tests { ]; // trigger value of 2 attributes to earlier source row with breakdown 1 and trigger // value of 5 attributes to source row with breakdown 2. 
let dp_params = DpMechanism::NoDp; - let padding_params = PaddingParameters::relaxed(); + let padding_params = PaddingParameters::no_padding(); let mut result: Vec<_> = world .malicious(records.into_iter(), |ctx, input_rows| async move { diff --git a/ipa-core/src/protocol/ipa_prf/oprf_padding/mod.rs b/ipa-core/src/protocol/ipa_prf/oprf_padding/mod.rs index d4680f0ac..207dd2a43 100644 --- a/ipa-core/src/protocol/ipa_prf/oprf_padding/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/oprf_padding/mod.rs @@ -278,7 +278,7 @@ where pub async fn apply_dp_padding( ctx: C, mut input: Vec, - padding_params: PaddingParameters, + padding_params: &PaddingParameters, ) -> Result, Error> where C: Context, @@ -291,7 +291,7 @@ where ctx.narrow(&PaddingDpStep::PaddingDpPass1), input, Role::H3, - &padding_params, + padding_params, ) .await?; @@ -300,7 +300,7 @@ where ctx.narrow(&PaddingDpStep::PaddingDpPass2), input, Role::H2, - &padding_params, + padding_params, ) .await?; @@ -309,7 +309,7 @@ where ctx.narrow(&PaddingDpStep::PaddingDpPass3), input, Role::H1, - &padding_params, + padding_params, ) .await?; diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs index 9a1f8f278..682841c8b 100644 --- a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs @@ -39,6 +39,7 @@ use crate::{ comparison_and_subtraction_sequential::{compare_gt, integer_sub}, expand_shared_array_in_place, }, + oprf_padding::PaddingParameters, prf_sharding::step::{ AttributionPerRowStep as PerRowStep, AttributionStep as Step, AttributionWindowStep as WindowStep, @@ -469,6 +470,7 @@ pub async fn attribute_cap_aggregate< input_rows: Vec>, attribution_window_seconds: Option, histogram: &[usize], + padding_parameters: &PaddingParameters, ) -> Result>, Error> where C: UpgradableContext + 'ctx, @@ -544,9 +546,12 @@ where aggregate_values_proof_chunk(B, usize::try_from(TV::BITS).unwrap()), ); let user_contributions = flattened_user_results.try_collect::>().await?; - let result = - breakdown_reveal_aggregation::<_, _, _, HV, B>(validator.context(), user_contributions) - .await; + let result = breakdown_reveal_aggregation::<_, _, _, HV, B>( + validator.context(), + user_contributions, + padding_parameters, + ) + .await; validator.validate().await?; result } @@ -891,7 +896,9 @@ pub mod tests { Field, U128Conversions, }, helpers::repeat_n, - protocol::ipa_prf::prf_sharding::attribute_cap_aggregate, + protocol::ipa_prf::{ + oprf_padding::PaddingParameters, prf_sharding::attribute_cap_aggregate, + }, rand::Rng, secret_sharing::{ replicated::semi_honest::AdditiveShare as Replicated, IntoShares, SharedValue, @@ -1077,7 +1084,11 @@ pub mod tests { .malicious(records.into_iter(), |ctx, input_rows| async move { Vec::transposed_from( &attribute_cap_aggregate::<_, BA5, BA3, BA16, BA20, 5, 32>( - ctx, input_rows, None, &histogram, + ctx, + input_rows, + None, + &histogram, + &PaddingParameters::relaxed(), ) .await .unwrap(), @@ -1138,6 +1149,7 @@ pub mod tests { input_rows, NonZeroU32::new(ATTRIBUTION_WINDOW_SECONDS), &histogram, + &PaddingParameters::relaxed(), ) .await .unwrap(), @@ -1175,6 +1187,7 @@ pub mod tests { input_rows, None, histogram_ref, + &PaddingParameters::relaxed(), ) .await .unwrap() @@ -1261,7 +1274,13 @@ pub mod tests { BA20, { SaturatingSumType::BITS as usize }, 256, - >(ctx, input_rows, None, &HISTOGRAM) + >( + ctx, + input_rows, + None, + &HISTOGRAM, + &PaddingParameters::relaxed(), + ) .await .unwrap(), ) From 
e259e5b9b97c04ca05244e91070460cd2576e01b Mon Sep 17 00:00:00 2001
From: Alex Koshelev
Date: Tue, 1 Oct 2024 22:07:23 -0700
Subject: [PATCH 102/191] Enforce active work to be a power of two

This is the second attempt to mitigate send buffer misalignment. The previous
one (#1307) didn't handle all the edge cases and was abandoned in favour of
this PR.

What I believe makes this change work is the new requirement for active work
to be a power of two. With this constraint, it is much easier to align the
read size with it. Given that `total_capacity = active * record_size`, we can
represent `read_size` as a multiple of `record_size` too:
`read_size = X * record_size`. If X is a power of two and active_work is a
power of two, then they will always be aligned with each other. For example,
if active work is 16, read size is 10 bytes and record size is 3 bytes, then:

```
total_capacity = 16*3
read_size = X*3 (close to 10)
X = 2 (power of two that satisfies the requirement)
```

When picking the read size, we round down to avoid buffer overflows. In the
example above, X=3 would bring the read size closest to the desired 10 bytes,
but 3 is not a power of two; the next power of two, X=4, overshoots the target
(12 > 10), so we round down and pick X=2 instead.
---
 ipa-core/src/helpers/gateway/mod.rs  | 211 ++++++++++++++++++++++++++-
 ipa-core/src/helpers/gateway/send.rs | 103 +++++++++++--
 2 files changed, 296 insertions(+), 18 deletions(-)

diff --git a/ipa-core/src/helpers/gateway/mod.rs b/ipa-core/src/helpers/gateway/mod.rs
index e654f85f7..55d1b1ffc 100644
--- a/ipa-core/src/helpers/gateway/mod.rs
+++ b/ipa-core/src/helpers/gateway/mod.rs
@@ -73,6 +73,7 @@ pub struct State {
 pub struct GatewayConfig {
     /// The number of items that can be active at the one time.
     /// This is used to determine the size of sending and receiving buffers.
+    /// Any value that is not a power of two will be rejected.
     pub active: NonZeroUsize,

     /// Number of bytes packed and sent together in one batch down to the network layer. This
@@ -84,6 +85,10 @@ pub struct GatewayConfig {
     /// payload may not be exactly this, but it will be the closest multiple of record size to this
     /// number. For instance, having 14 bytes records and batch size of 4096 will result in
     /// 4088 bytes being sent in a batch.
+    ///
+    /// The actual size for read chunks may be bigger or smaller, depending on the record size
+    /// sent through each channel. Read size will be aligned with the [`Self::active_work`] value
+    /// to prevent deadlocks.
     pub read_size: NonZeroUsize,

     /// Time to wait before checking gateway progress. If no progress has been made between
@@ -279,7 +284,8 @@ impl GatewayConfig {
                 // capabilities (see #ipa/1171) to allow that currently.
                 usize::from(value.size),
             ),
-        );
+        )
+        .next_power_of_two();
         // we set active to be at least 2, so unwrap is fine.
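+        // Illustrative example (cf. the `gateway_config_active_work_power_of_two`
+        // test added below): a query of size 5 yields an active work of 5, which
+        // `next_power_of_two` then rounds up to 8.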
        self.active = NonZeroUsize::new(active).unwrap();
    }
@@ -299,23 +305,35 @@ mod tests {
    use std::{
        iter::{repeat, zip},
        num::NonZeroUsize,
+        sync::Arc,
    };

    use futures::{
        future::{join, try_join, try_join_all},
+        stream,
        stream::StreamExt,
    };
+    use proptest::proptest;

    use crate::{
-        ff::{boolean_array::BA3, Fp31, Fp32BitPrime, Gf2, U128Conversions},
+        ff::{
+            boolean_array::{BA20, BA256, BA3, BA4, BA5, BA6, BA7, BA8},
+            FieldType, Fp31, Fp32BitPrime, Gf2, U128Conversions,
+        },
        helpers::{
-            ChannelId, Direction, GatewayConfig, MpcMessage, Role, SendingEnd, TotalRecords,
+            gateway::QueryConfig,
+            query::{QuerySize, QueryType},
+            ChannelId, Direction, GatewayConfig, MpcMessage, MpcReceivingEnd, Role, SendingEnd,
+            TotalRecords,
        },
        protocol::{
            context::{Context, ShardedContext},
            Gate, RecordId,
        },
-        secret_sharing::{replicated::semi_honest::AdditiveShare, SharedValue},
+        secret_sharing::{
+            replicated::semi_honest::AdditiveShare, SharedValue, SharedValueArray, StdArray,
+        },
+        seq_join::seq_join,
        sharding::ShardConfiguration,
        test_executor::run,
        test_fixture::{Reconstruct, Runner, TestWorld, TestWorldConfig, WithShards},
@@ -569,6 +587,87 @@ mod tests {
        });
    }

+    macro_rules! send_recv_test {
+        (
+            message: $message:expr,
+            read_size: $read_size:expr,
+            active_work: $active_work:expr,
+            total_records: $total_records:expr,
+            $test_fn: ident
+        ) => {
+            #[test]
+            fn $test_fn() {
+                run(|| async {
+                    send_recv($read_size, $active_work, $total_records, $message).await;
+                });
+            }
+        };
+    }
+
+    send_recv_test! {
+        message: BA20::ZERO,
+        read_size: 5,
+        active_work: 8,
+        total_records: 25,
+        test_ba20_5_8_25
+    }
+
+    send_recv_test! {
+        message: StdArray::::ZERO_ARRAY,
+        read_size: 2048,
+        active_work: 16,
+        total_records: 43,
+        test_ba256_by_16_2048_16_43
+    }
+
+    send_recv_test! {
+        message: StdArray::::ZERO_ARRAY,
+        read_size: 2048,
+        active_work: 32,
+        total_records: 50,
+        test_ba8_by_16_2048_32_50
+    }
+
+    proptest!
{ + #[test] + fn send_recv_randomized( + total_records in 1_usize..500, + active in 2_usize..1000, + read_size in (1_usize..32768), + record_size in 1_usize..=8, + ) { + let active = active.next_power_of_two(); + run(move || async move { + match record_size { + 1 => send_recv(read_size, active, total_records, StdArray::::ZERO_ARRAY).await, + 2 => send_recv(read_size, active, total_records, StdArray::::ZERO_ARRAY).await, + 3 => send_recv(read_size, active, total_records, BA3::ZERO).await, + 4 => send_recv(read_size, active, total_records, BA4::ZERO).await, + 5 => send_recv(read_size, active, total_records, BA5::ZERO).await, + 6 => send_recv(read_size, active, total_records, BA6::ZERO).await, + 7 => send_recv(read_size, active, total_records, BA7::ZERO).await, + 8 => send_recv(read_size, active, total_records, StdArray::::ZERO_ARRAY).await, + _ => unreachable!(), + } + }); + } + } + + /// ensures when active work is set from query input, it is always a power of two + #[test] + fn gateway_config_active_work_power_of_two() { + let mut config = GatewayConfig { + active: 2.try_into().unwrap(), + ..Default::default() + }; + config.set_active_work_from_query_config(&QueryConfig { + size: QuerySize::try_from(5).unwrap(), + field_type: FieldType::Fp31, + query_type: QueryType::TestAddInPrimeField, + }); + assert_eq!(8, config.active_work().get()); + } + async fn shard_comms_test(test_world: &TestWorld>) { let input = vec![BA3::truncate_from(0_u32), BA3::truncate_from(1_u32)]; @@ -606,4 +705,108 @@ mod tests { let world_ptr = world as *mut _; (world, world_ptr) } + + /// This serves the purpose of randomized testing of our send channels by providing + /// variable sizes for read size, active work and record size + async fn send_recv(read_size: usize, active_work: usize, total_records: usize, sample: M) + where + M: MpcMessage + Clone + PartialEq, + { + fn duplex_channel( + world: &TestWorld, + left: Role, + right: Role, + total_records: usize, + active_work: usize, + ) -> (SendingEnd, MpcReceivingEnd) { + ( + world.gateway(left).get_mpc_sender::( + &ChannelId::new(right, Gate::default()), + TotalRecords::specified(total_records).unwrap(), + active_work.try_into().unwrap(), + ), + world + .gateway(right) + .get_mpc_receiver::(&ChannelId::new(left, Gate::default())), + ) + } + + async fn circuit( + send_channel: SendingEnd, + recv_channel: MpcReceivingEnd, + active_work: usize, + total_records: usize, + msg: M, + ) where + M: MpcMessage + Clone + PartialEq, + { + let send_notify = Arc::new(tokio::sync::Notify::new()); + + // perform "multiplication-like" operation (send + subsequent receive) + // and "validate": block the future until we have at least `active_work` + // futures pending and unblock them all at the same time + seq_join( + active_work.try_into().unwrap(), + stream::iter(std::iter::repeat(msg).take(total_records).enumerate()).map( + |(record_id, msg)| { + let send_channel = &send_channel; + let recv_channel = &recv_channel; + let send_notify = Arc::clone(&send_notify); + async move { + send_channel + .send(record_id.into(), msg.clone()) + .await + .unwrap(); + let r = recv_channel.receive(record_id.into()).await.unwrap(); + // this simulates validate_record API by forcing futures to wait + // until the entire batch is validated by the last future in that batch + if record_id % active_work == active_work - 1 + || record_id == total_records - 1 + { + send_notify.notify_waiters(); + } else { + send_notify.notified().await; + } + assert_eq!(msg, r); + } + }, + ), + ) + .collect::>() + .await; 
+    }
+
+        let config = TestWorldConfig {
+            gateway_config: GatewayConfig {
+                active: active_work.try_into().unwrap(),
+                read_size: read_size.try_into().unwrap(),
+                ..Default::default()
+            },
+            ..Default::default()
+        };
+
+        let world = TestWorld::new_with(&config);
+        let (h1_send_channel, h1_recv_channel) =
+            duplex_channel(&world, Role::H1, Role::H2, total_records, active_work);
+        let (h2_send_channel, h2_recv_channel) =
+            duplex_channel(&world, Role::H2, Role::H1, total_records, active_work);
+
+        join(
+            circuit(
+                h1_send_channel,
+                h1_recv_channel,
+                active_work,
+                total_records,
+                sample.clone(),
+            ),
+            circuit(
+                h2_send_channel,
+                h2_recv_channel,
+                active_work,
+                total_records,
+                sample,
+            ),
+        )
+        .await;
+    }
 }
diff --git a/ipa-core/src/helpers/gateway/send.rs b/ipa-core/src/helpers/gateway/send.rs
index 07018fb14..dab292353 100644
--- a/ipa-core/src/helpers/gateway/send.rs
+++ b/ipa-core/src/helpers/gateway/send.rs
@@ -248,28 +248,50 @@ impl Stream for GatewaySendStream {

 impl SendChannelConfig {
     fn new(gateway_config: GatewayConfig, total_records: TotalRecords) -> Self {
-        debug_assert!(M::Size::USIZE > 0, "Message size cannot be 0");
+        Self::new_with(gateway_config, total_records, M::Size::USIZE)
+    }

+    fn new_with(
+        gateway_config: GatewayConfig,
+        total_records: TotalRecords,
+        record_size: usize,
+    ) -> Self {
+        debug_assert!(record_size > 0, "Message size cannot be 0");
+        debug_assert!(
+            gateway_config.active.is_power_of_two(),
+            "Active work {} must be a power of two",
+            gateway_config.active.get()
+        );

-        let record_size = M::Size::USIZE;
         let total_capacity = gateway_config.active.get() * record_size;
-        Self {
+        // define read size as a multiplier of record size. The multiplier must be
+        // a power of two to align perfectly with total capacity.
+        let read_size_multiplier = {
+            // next_power_of_two returns a value that is greater than or equal to the input.
+            // That is not what we want here: if read_size / record_size is already a power
+            // of two, then the subsequent division would get us further away from the desired
+            // target. For example: if read_size / record_size = 4, then prev_power_of_two = 2.
+            // In such cases, we want to stay where we are, so we add +1 for that.
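+            // Worked example (illustrative), using the numbers from the commit
+            // message: read_size = 10 and record_size = 3 give 10 / 3 = 3, and
+            // (3 + 1).next_power_of_two() / 2 = 4 / 2 = 2, so the read size
+            // becomes 2 * 3 = 6 bytes, the largest power-of-two multiple of the
+            // record size that does not exceed the requested read size.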
+ let prev_power_of_two = + (gateway_config.read_size.get() / record_size + 1).next_power_of_two() / 2; + std::cmp::max(1, prev_power_of_two) + }; + + let this = Self { total_capacity: total_capacity.try_into().unwrap(), record_size: record_size.try_into().unwrap(), - read_size: if total_records.is_indeterminate() - || gateway_config.read_size.get() <= record_size - { + read_size: if total_records.is_indeterminate() { record_size } else { - std::cmp::min( - total_capacity, - // closest multiple of record_size to read_size - gateway_config.read_size.get() / record_size * record_size, - ) + std::cmp::min(total_capacity, read_size_multiplier * record_size) } .try_into() .unwrap(), total_records, - } + }; + + debug_assert!(this.total_capacity.get() >= record_size * gateway_config.active.get()); + + this } } @@ -277,6 +299,7 @@ impl SendChannelConfig { mod test { use std::num::NonZeroUsize; + use proptest::proptest; use typenum::Unsigned; use crate::{ @@ -379,15 +402,67 @@ mod test { fn config_read_size_closest_multiple_to_record_size() { assert_eq!( 6, - send_config::(TotalRecords::Specified(2.try_into().unwrap())) + send_config::(TotalRecords::Specified(2.try_into().unwrap())) .read_size .get() ); assert_eq!( 6, - send_config::(TotalRecords::Specified(2.try_into().unwrap())) + send_config::(TotalRecords::Specified(2.try_into().unwrap())) .read_size .get() ); } + + #[test] + fn config_read_size_record_size_misalignment() { + ensure_config(Some(15), 90, 16, 3); + } + + fn ensure_config( + total_records: Option, + active: usize, + read_size: usize, + record_size: usize, + ) { + let gateway_config = GatewayConfig { + active: active.next_power_of_two().try_into().unwrap(), + read_size: read_size.try_into().unwrap(), + // read_size: read_size.next_power_of_two().try_into().unwrap(), + ..Default::default() + }; + let config = SendChannelConfig::new_with( + gateway_config, + total_records.map_or(TotalRecords::Indeterminate, |v| { + TotalRecords::specified(v).unwrap() + }), + record_size, + ); + + // total capacity checks + assert!(config.total_capacity.get() > 0); + assert!(config.total_capacity.get() >= config.read_size.get()); + assert_eq!(0, config.total_capacity.get() % config.record_size.get()); + assert_eq!( + config.total_capacity.get(), + record_size * gateway_config.active.get() + ); + + // read size checks + assert!(config.read_size.get() > 0); + assert!(config.read_size.get() >= config.record_size.get()); + assert_eq!(0, config.total_capacity.get() % config.read_size.get()); + } + + proptest! { + #[test] + fn config_prop( + total_records in proptest::option::of(1_usize..1 << 32), + active in 1_usize..100_000, + read_size in 1_usize..32768, + record_size in 1_usize..4096, + ) { + ensure_config(total_records, active, read_size, record_size); + } + } } From d9a29f3483cf5d2b84b8cdeab78c1c0ff4cb9d15 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Wed, 2 Oct 2024 12:50:55 -0700 Subject: [PATCH 103/191] Add a type that enforces power of two constraint While working on changing the gateway and parameters, I ran into several issues where the power of two constraint was not enforced and breakages were hard to find. A better model for me is to gate the active work at the config level, prohibiting invalid constructions at the caller side. 
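As a rough sketch of the intended call-site ergonomics (a minimal example
assuming the `TryFrom<usize>` and `FromStr` impls added in this patch; exact
paths abbreviated):

```
use crate::utils::NonZeroU32PowerOfTwo;

// Construction is fallible, so invalid values are rejected up front
// instead of surfacing as hard-to-find breakages inside the gateway.
let active = NonZeroU32PowerOfTwo::try_from(32_usize).unwrap();
assert_eq!(32_u32, active.into());

// CLI and config inputs go through the same validation when parsed.
assert!("1024".parse::<NonZeroU32PowerOfTwo>().is_ok());
assert!("3".parse::<NonZeroU32PowerOfTwo>().is_err());
```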
--- ipa-core/src/app.rs | 7 +- ipa-core/src/helpers/gateway/mod.rs | 29 +++-- ipa-core/src/helpers/gateway/send.rs | 28 +++-- .../src/helpers/gateway/stall_detection.rs | 4 +- ipa-core/src/helpers/prss_protocol.rs | 4 +- .../src/protocol/context/dzkp_malicious.rs | 7 +- ipa-core/src/protocol/context/malicious.rs | 7 +- ipa-core/src/protocol/context/mod.rs | 9 +- ipa-core/src/query/processor.rs | 6 +- ipa-core/src/test_fixture/circuit.rs | 10 +- ipa-core/src/utils/mod.rs | 5 + ipa-core/src/utils/power_of_two.rs | 110 ++++++++++++++++++ 12 files changed, 185 insertions(+), 41 deletions(-) create mode 100644 ipa-core/src/utils/power_of_two.rs diff --git a/ipa-core/src/app.rs b/ipa-core/src/app.rs index da56e67e3..f84aed06e 100644 --- a/ipa-core/src/app.rs +++ b/ipa-core/src/app.rs @@ -1,4 +1,4 @@ -use std::{num::NonZeroUsize, sync::Weak}; +use std::sync::Weak; use async_trait::async_trait; @@ -13,17 +13,18 @@ use crate::{ protocol::QueryId, query::{NewQueryError, QueryProcessor, QueryStatus}, sync::Arc, + utils::NonZeroU32PowerOfTwo, }; #[derive(Default)] pub struct AppConfig { - active_work: Option, + active_work: Option, key_registry: Option>, } impl AppConfig { #[must_use] - pub fn with_active_work(mut self, active_work: Option) -> Self { + pub fn with_active_work(mut self, active_work: Option) -> Self { self.active_work = active_work; self } diff --git a/ipa-core/src/helpers/gateway/mod.rs b/ipa-core/src/helpers/gateway/mod.rs index 55d1b1ffc..15d2580d2 100644 --- a/ipa-core/src/helpers/gateway/mod.rs +++ b/ipa-core/src/helpers/gateway/mod.rs @@ -30,6 +30,7 @@ use crate::{ protocol::QueryId, sharding::ShardIndex, sync::{Arc, Mutex}, + utils::NonZeroU32PowerOfTwo, }; /// Alias for the currently configured transport. @@ -73,8 +74,7 @@ pub struct State { pub struct GatewayConfig { /// The number of items that can be active at the one time. /// This is used to determine the size of sending and receiving buffers. - /// Any value that is not a power of two will be rejected - pub active: NonZeroUsize, + pub active: NonZeroU32PowerOfTwo, /// Number of bytes packed and sent together in one batch down to the network layer. This /// shouldn't be too small to keep the network throughput, but setting it large enough may @@ -155,7 +155,7 @@ impl Gateway { &self, channel_id: &HelperChannelId, total_records: TotalRecords, - active_work: NonZeroUsize, + active_work: NonZeroU32PowerOfTwo, ) -> send::SendingEnd { let transport = &self.transports.mpc; let channel = self.inner.mpc_senders.get::( @@ -265,6 +265,11 @@ impl GatewayConfig { /// The configured amount of active work. #[must_use] pub fn active_work(&self) -> NonZeroUsize { + self.active.to_non_zero_usize() + } + + #[must_use] + pub fn active_work_as_power_of_two(&self) -> NonZeroU32PowerOfTwo { self.active } @@ -287,12 +292,12 @@ impl GatewayConfig { ) .next_power_of_two(); // we set active to be at least 2, so unwrap is fine. - self.active = NonZeroUsize::new(active).unwrap(); + self.active = NonZeroU32PowerOfTwo::try_from(active).unwrap(); } /// Creates a new configuration by overriding the value of active work. 
#[must_use] - pub fn set_active_work(&self, active_work: NonZeroUsize) -> Self { + pub fn set_active_work(&self, active_work: NonZeroU32PowerOfTwo) -> Self { Self { active: active_work, ..*self @@ -304,7 +309,6 @@ impl GatewayConfig { mod tests { use std::{ iter::{repeat, zip}, - num::NonZeroUsize, sync::Arc, }; @@ -337,6 +341,7 @@ mod tests { sharding::ShardConfiguration, test_executor::run, test_fixture::{Reconstruct, Runner, TestWorld, TestWorldConfig, WithShards}, + utils::NonZeroU32PowerOfTwo, }; /// Verifies that [`Gateway`] send buffer capacity is adjusted to the message size. @@ -556,13 +561,19 @@ mod tests { run(|| async move { let world = TestWorld::new_with(TestWorldConfig { gateway_config: GatewayConfig { - active: 5.try_into().unwrap(), + active: 8.try_into().unwrap(), ..Default::default() }, ..Default::default() }); - let new_active_work = NonZeroUsize::new(3).unwrap(); - assert!(new_active_work < world.gateway(Role::H1).config().active_work()); + let new_active_work = NonZeroU32PowerOfTwo::try_from(4).unwrap(); + assert!( + new_active_work + < world + .gateway(Role::H1) + .config() + .active_work_as_power_of_two() + ); let sender = world.gateway(Role::H1).get_mpc_sender::( &ChannelId::new(Role::H2, Gate::default()), TotalRecords::specified(15).unwrap(), diff --git a/ipa-core/src/helpers/gateway/send.rs b/ipa-core/src/helpers/gateway/send.rs index dab292353..de9dcabd5 100644 --- a/ipa-core/src/helpers/gateway/send.rs +++ b/ipa-core/src/helpers/gateway/send.rs @@ -255,12 +255,7 @@ impl SendChannelConfig { total_records: TotalRecords, record_size: usize, ) -> Self { - debug_assert!(record_size > 0, "Message size cannot be 0"); - debug_assert!( - gateway_config.active.is_power_of_two(), - "Active work {} must be a power of two", - gateway_config.active.get() - ); + assert!(record_size > 0, "Message size cannot be 0"); let total_capacity = gateway_config.active.get() * record_size; // define read size as a multiplier of record size. The multiplier must be @@ -289,7 +284,8 @@ impl SendChannelConfig { total_records, }; - debug_assert!(this.total_capacity.get() >= record_size * gateway_config.active.get()); + assert!(this.total_capacity.get() >= record_size * gateway_config.active.get()); + assert_eq!(0, this.total_capacity.get() % this.read_size.get()); this } @@ -304,7 +300,7 @@ mod test { use crate::{ ff::{ - boolean_array::{BA16, BA20, BA256, BA3, BA7}, + boolean_array::{BA16, BA20, BA256, BA3, BA32, BA7}, Serializable, }, helpers::{gateway::send::SendChannelConfig, GatewayConfig, TotalRecords}, @@ -419,6 +415,21 @@ mod test { ensure_config(Some(15), 90, 16, 3); } + #[test] + fn config_read_size_multiple_of_record_size() { + // 4 bytes * 8 = 32 bytes total capacity. 
+ // desired read size is 15 bytes, and the closest multiple of BA32 + // to it that is a power of two is 2 (4 gets us over 15 byte target) + assert_eq!(8, send_config::(50.into()).read_size.get()); + + // here, read size is already a power of two + assert_eq!(16, send_config::(50.into()).read_size.get()); + + // read size can be ridiculously small, config adjusts it to fit + // at least one record + assert_eq!(3, send_config::(50.into()).read_size.get()); + } + fn ensure_config( total_records: Option, active: usize, @@ -428,7 +439,6 @@ mod test { let gateway_config = GatewayConfig { active: active.next_power_of_two().try_into().unwrap(), read_size: read_size.try_into().unwrap(), - // read_size: read_size.next_power_of_two().try_into().unwrap(), ..Default::default() }; let config = SendChannelConfig::new_with( diff --git a/ipa-core/src/helpers/gateway/stall_detection.rs b/ipa-core/src/helpers/gateway/stall_detection.rs index 43706f450..4a844386f 100644 --- a/ipa-core/src/helpers/gateway/stall_detection.rs +++ b/ipa-core/src/helpers/gateway/stall_detection.rs @@ -67,7 +67,6 @@ impl Observed { } mod gateway { - use std::num::NonZeroUsize; use delegate::delegate; @@ -81,6 +80,7 @@ mod gateway { protocol::QueryId, sharding::ShardIndex, sync::Arc, + utils::NonZeroU32PowerOfTwo, }; pub struct InstrumentedGateway { @@ -154,7 +154,7 @@ mod gateway { &self, channel_id: &HelperChannelId, total_records: TotalRecords, - active_work: NonZeroUsize, + active_work: NonZeroU32PowerOfTwo, ) -> SendingEnd { Observed::wrap( Weak::clone(self.get_sn()), diff --git a/ipa-core/src/helpers/prss_protocol.rs b/ipa-core/src/helpers/prss_protocol.rs index f9284f9eb..850d6c733 100644 --- a/ipa-core/src/helpers/prss_protocol.rs +++ b/ipa-core/src/helpers/prss_protocol.rs @@ -24,12 +24,12 @@ pub async fn negotiate( let left_sender = gateway.get_mpc_sender::( &left_channel, TotalRecords::ONE, - gateway.config().active_work(), + gateway.config().active_work_as_power_of_two(), ); let right_sender = gateway.get_mpc_sender::( &right_channel, TotalRecords::ONE, - gateway.config().active_work(), + gateway.config().active_work_as_power_of_two(), ); let left_receiver = gateway.get_mpc_receiver::(&left_channel); let right_receiver = gateway.get_mpc_receiver::(&right_channel); diff --git a/ipa-core/src/protocol/context/dzkp_malicious.rs b/ipa-core/src/protocol/context/dzkp_malicious.rs index 9f28239ba..80762fb52 100644 --- a/ipa-core/src/protocol/context/dzkp_malicious.rs +++ b/ipa-core/src/protocol/context/dzkp_malicious.rs @@ -61,8 +61,11 @@ impl<'a> DZKPUpgraded<'a> { // This overrides the active work for this context and all children // created from it by using narrow, clone, etc. // This allows all steps participating in malicious validation - // to use the same active work window and prevent deadlocks - base_ctx: base_ctx.set_active_work(active_work), + // to use the same active work window and prevent deadlocks. + // + // This also checks that active work is a power of two and + // panics if it is not. 
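+            // (The `try_into` below goes through `NonZeroU32PowerOfTwo::try_from`,
+            // which is where the power-of-two check actually happens.)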
+ base_ctx: base_ctx.set_active_work(active_work.get().try_into().unwrap()), } } diff --git a/ipa-core/src/protocol/context/malicious.rs b/ipa-core/src/protocol/context/malicious.rs index 8c287b1f2..b11f6f5a8 100644 --- a/ipa-core/src/protocol/context/malicious.rs +++ b/ipa-core/src/protocol/context/malicious.rs @@ -80,7 +80,7 @@ impl<'a> Context<'a> { } #[must_use] - pub fn set_active_work(self, new_active_work: NonZeroUsize) -> Self { + pub fn set_active_work(self, new_active_work: NonZeroU32PowerOfTwo) -> Self { Self { inner: self.inner.set_active_work(new_active_work), } @@ -171,7 +171,10 @@ impl Debug for Context<'_> { } } -use crate::sync::{Mutex, Weak}; +use crate::{ + sync::{Mutex, Weak}, + utils::NonZeroU32PowerOfTwo, +}; pub(super) type MacBatcher<'a, F> = Mutex>>; diff --git a/ipa-core/src/protocol/context/mod.rs b/ipa-core/src/protocol/context/mod.rs index eead81a16..abf6f8476 100644 --- a/ipa-core/src/protocol/context/mod.rs +++ b/ipa-core/src/protocol/context/mod.rs @@ -44,6 +44,7 @@ use crate::{ secret_sharing::replicated::malicious::ExtendableField, seq_join::SeqJoin, sharding::{NotSharded, ShardBinding, ShardConfiguration, ShardIndex, Sharded}, + utils::NonZeroU32PowerOfTwo, }; /// Context used by each helper to perform secure computation. Provides access to shared randomness @@ -162,7 +163,7 @@ pub struct Base<'a, B: ShardBinding = NotSharded> { inner: Inner<'a>, gate: Gate, total_records: TotalRecords, - active_work: NonZeroUsize, + active_work: NonZeroU32PowerOfTwo, /// This indicates whether the system uses sharding or no. It's not ideal that we keep it here /// because it gets cloned often, a potential solution to that, if this shows up on flame graph, /// would be to move it to [`Inner`] struct. @@ -181,13 +182,13 @@ impl<'a, B: ShardBinding> Base<'a, B> { inner: Inner::new(participant, gateway), gate, total_records, - active_work: gateway.config().active_work(), + active_work: gateway.config().active_work_as_power_of_two(), sharding, } } #[must_use] - pub fn set_active_work(self, new_active_work: NonZeroUsize) -> Self { + pub fn set_active_work(self, new_active_work: NonZeroU32PowerOfTwo) -> Self { Self { active_work: new_active_work, ..self.clone() @@ -336,7 +337,7 @@ impl ShardConfiguration for Base<'_, Sharded> { impl<'a, B: ShardBinding> SeqJoin for Base<'a, B> { fn active_work(&self) -> NonZeroUsize { - self.active_work + self.active_work.to_non_zero_usize() } } diff --git a/ipa-core/src/query/processor.rs b/ipa-core/src/query/processor.rs index a8694012e..679b740fd 100644 --- a/ipa-core/src/query/processor.rs +++ b/ipa-core/src/query/processor.rs @@ -1,7 +1,6 @@ use std::{ collections::hash_map::Entry, fmt::{Debug, Formatter}, - num::NonZeroUsize, }; use futures::{future::try_join, stream}; @@ -22,6 +21,7 @@ use crate::{ CompletionHandle, ProtocolResult, }, sync::Arc, + utils::NonZeroU32PowerOfTwo, }; /// `Processor` accepts and tracks requests to initiate new queries on this helper party @@ -44,7 +44,7 @@ use crate::{ pub struct Processor { queries: RunningQueries, key_registry: Arc>, - active_work: Option, + active_work: Option, } impl Default for Processor { @@ -118,7 +118,7 @@ impl Processor { #[must_use] pub fn new( key_registry: KeyRegistry, - active_work: Option, + active_work: Option, ) -> Self { Self { queries: RunningQueries::default(), diff --git a/ipa-core/src/test_fixture/circuit.rs b/ipa-core/src/test_fixture/circuit.rs index 5a1ecd67e..17920591f 100644 --- a/ipa-core/src/test_fixture/circuit.rs +++ b/ipa-core/src/test_fixture/circuit.rs @@ -1,4 
+1,4 @@
-use std::{array, num::NonZeroUsize};
+use std::array;

 use futures::{future::join3, stream, StreamExt};
 use ipa_step::StepNarrow;
@@ -17,7 +17,7 @@ use crate::{
     secret_sharing::{replicated::semi_honest::AdditiveShare as Replicated, FieldSimd, IntoShares},
     seq_join::seq_join,
     test_fixture::{ReconstructArr, TestWorld, TestWorldConfig},
-    utils::array::zip3,
+    utils::{array::zip3, NonZeroU32PowerOfTwo},
 };

 pub struct Inputs, const N: usize> {
@@ -76,7 +76,7 @@ pub async fn arithmetic(
     [F; N]: IntoShares>,
     Standard: Distribution,
 {
-    let active = NonZeroUsize::new(active_work).unwrap();
+    let active = NonZeroU32PowerOfTwo::try_from(active_work.next_power_of_two()).unwrap();
     let config = TestWorldConfig {
         gateway_config: GatewayConfig {
             active,
@@ -85,7 +85,7 @@ pub async fn arithmetic(
         initial_gate: Some(Gate::default().narrow(&ProtocolStep::Test)),
         ..Default::default()
     };
-    let world = TestWorld::new_with(config);
+    let world = TestWorld::new_with(&config);
     // Re-use contexts for the entire execution because record identifiers are contiguous.
     let contexts = world.contexts();
@@ -96,7 +96,7 @@ pub async fn arithmetic(
             // accumulated. This gives the best performance for vectorized operation.
             let ctx = ctx.set_total_records(TotalRecords::Indeterminate);
             seq_join(
-                active,
+                config.gateway_config.active_work(),
                 stream::iter((0..(width / u32::try_from(N).unwrap())).zip(col_data)).map(
                     move |(record, Inputs { a, b })| {
                         circuit(ctx.clone(), RecordId::from(record), depth, a, b)
diff --git a/ipa-core/src/utils/mod.rs b/ipa-core/src/utils/mod.rs
index a3600e899..e8dfd95ae 100644
--- a/ipa-core/src/utils/mod.rs
+++ b/ipa-core/src/utils/mod.rs
@@ -1,2 +1,7 @@
 pub mod array;
 pub mod arraychunks;
+#[cfg(target_pointer_width = "64")]
+mod power_of_two;
+
+#[cfg(target_pointer_width = "64")]
+pub use power_of_two::NonZeroU32PowerOfTwo;
diff --git a/ipa-core/src/utils/power_of_two.rs b/ipa-core/src/utils/power_of_two.rs
new file mode 100644
index 000000000..abce8055e
--- /dev/null
+++ b/ipa-core/src/utils/power_of_two.rs
@@ -0,0 +1,110 @@
+use std::{fmt::Display, num::NonZeroUsize, str::FromStr};
+
+#[derive(Debug, thiserror::Error)]
+#[error("{0} is not a power of two or not within the 1..u32::MAX range")]
+pub struct ConvertError(I);
+
+impl PartialEq for ConvertError {
+    fn eq(&self, other: &Self) -> bool {
+        self.0 == other.0
+    }
+}
+
+/// This construction guarantees the value to be a power of two and
+/// within the range 1..2^32-1
+#[derive(Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq)]
+pub struct NonZeroU32PowerOfTwo(u32);
+
+impl Display for NonZeroU32PowerOfTwo {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", u32::from(*self))
+    }
+}
+
+impl TryFrom for NonZeroU32PowerOfTwo {
+    type Error = ConvertError;
+
+    fn try_from(value: usize) -> Result {
+        if value > 0 && value < usize::try_from(u32::MAX).unwrap() && value.is_power_of_two() {
+            Ok(NonZeroU32PowerOfTwo(u32::try_from(value).unwrap()))
+        } else {
+            Err(ConvertError(value))
+        }
+    }
+}
+
+impl From for usize {
+    fn from(value: NonZeroU32PowerOfTwo) -> Self {
+        // we are using 64 bit registers
+        usize::try_from(value.0).unwrap()
+    }
+}
+
+impl From for u32 {
+    fn from(value: NonZeroU32PowerOfTwo) -> Self {
+        value.0
+    }
+}
+
+impl FromStr for NonZeroU32PowerOfTwo {
+    type Err = ConvertError;
+
+    fn from_str(s: &str) -> Result {
+        let v = s.parse::().map_err(|_| ConvertError(s.to_owned()))?;
+        NonZeroU32PowerOfTwo::try_from(v).map_err(|_| ConvertError(s.to_owned()))
+    }
+}
+
+impl 
NonZeroU32PowerOfTwo { + #[must_use] + pub fn to_non_zero_usize(self) -> NonZeroUsize { + let v = usize::from(self); + NonZeroUsize::new(v).unwrap_or_else(|| unreachable!()) + } + + #[must_use] + pub fn get(self) -> usize { + usize::from(self) + } +} + +#[cfg(test)] +mod tests { + use super::{ConvertError, NonZeroU32PowerOfTwo}; + + #[test] + fn rejects_invalid_values() { + assert!(matches!( + NonZeroU32PowerOfTwo::try_from(0), + Err(ConvertError(0)) + )); + assert!(matches!( + NonZeroU32PowerOfTwo::try_from(3), + Err(ConvertError(3)) + )); + + assert!(matches!( + NonZeroU32PowerOfTwo::try_from(1_usize << 33), + Err(ConvertError(_)) + )); + } + + #[test] + fn accepts_valid() { + assert_eq!(4, u32::from(NonZeroU32PowerOfTwo::try_from(4).unwrap())); + assert_eq!(16, u32::from(NonZeroU32PowerOfTwo::try_from(16).unwrap())); + } + + #[test] + fn parse_from_str() { + assert_eq!(NonZeroU32PowerOfTwo(4), "4".parse().unwrap()); + assert_eq!( + ConvertError("0".to_owned()), + "0".parse::().unwrap_err() + ); + assert_eq!( + ConvertError("3".to_owned()), + "3".parse::().unwrap_err() + ); + } +} From 46087f93f56c3ccf3d00719ec4cdbc8582e17fb9 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Wed, 2 Oct 2024 14:56:03 -0700 Subject: [PATCH 104/191] Fix batch size to be a power of two in tests --- ipa-core/src/bin/helper.rs | 5 ++--- ipa-core/src/helpers/gateway/send.rs | 1 + ipa-core/src/lib.rs | 2 +- ipa-core/src/protocol/basics/mul/dzkp_malicious.rs | 2 +- ipa-core/src/protocol/context/dzkp_validator.rs | 7 ++----- ipa-core/src/protocol/ipa_prf/mod.rs | 6 ++++-- ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs | 4 ++-- ipa-core/src/test_fixture/world.rs | 2 +- 8 files changed, 14 insertions(+), 15 deletions(-) diff --git a/ipa-core/src/bin/helper.rs b/ipa-core/src/bin/helper.rs index 790245587..884745180 100644 --- a/ipa-core/src/bin/helper.rs +++ b/ipa-core/src/bin/helper.rs @@ -2,7 +2,6 @@ use std::{ fs, io::BufReader, net::TcpListener, - num::NonZeroUsize, os::fd::{FromRawFd, RawFd}, path::{Path, PathBuf}, process, @@ -18,7 +17,7 @@ use ipa_core::{ error::BoxError, helpers::HelperIdentity, net::{ClientIdentity, HttpShardTransport, HttpTransport, MpcHelperClient}, - AppConfig, AppSetup, + AppConfig, AppSetup, NonZeroU32PowerOfTwo, }; use tracing::{error, info}; @@ -93,7 +92,7 @@ struct ServerArgs { /// Override the amount of active work processed in parallel #[arg(long)] - active_work: Option, + active_work: Option, } #[derive(Debug, Subcommand)] diff --git a/ipa-core/src/helpers/gateway/send.rs b/ipa-core/src/helpers/gateway/send.rs index de9dcabd5..b1dc11155 100644 --- a/ipa-core/src/helpers/gateway/send.rs +++ b/ipa-core/src/helpers/gateway/send.rs @@ -436,6 +436,7 @@ mod test { read_size: usize, record_size: usize, ) { + #[allow(clippy::needless_update)] // stall detection feature wants default value let gateway_config = GatewayConfig { active: active.next_power_of_two().try_into().unwrap(), read_size: read_size.try_into().unwrap(), diff --git a/ipa-core/src/lib.rs b/ipa-core/src/lib.rs index 59cae0106..f88ea718e 100644 --- a/ipa-core/src/lib.rs +++ b/ipa-core/src/lib.rs @@ -32,8 +32,8 @@ mod seq_join; mod serde; pub mod sharding; mod utils; - pub use app::{AppConfig, HelperApp, Setup as AppSetup}; +pub use utils::NonZeroU32PowerOfTwo; extern crate core; #[cfg(all(feature = "shuttle", test))] diff --git a/ipa-core/src/protocol/basics/mul/dzkp_malicious.rs b/ipa-core/src/protocol/basics/mul/dzkp_malicious.rs index 23a96c982..e024c4483 100644 --- a/ipa-core/src/protocol/basics/mul/dzkp_malicious.rs 
+++ b/ipa-core/src/protocol/basics/mul/dzkp_malicious.rs
@@ -101,7 +101,7 @@ mod test {

         let res = world
             .malicious((a, b), |ctx, (a, b)| async move {
-                let validator = ctx.dzkp_validator(TEST_DZKP_STEPS, 10);
+                let validator = ctx.dzkp_validator(TEST_DZKP_STEPS, 8);
                 let mctx = validator.context();
                 let result = a
                     .multiply(&b, mctx.set_total_records(1), RecordId::from(0))
diff --git a/ipa-core/src/protocol/context/dzkp_validator.rs b/ipa-core/src/protocol/context/dzkp_validator.rs
index 835a32e9d..a16ca32fd 100644
--- a/ipa-core/src/protocol/context/dzkp_validator.rs
+++ b/ipa-core/src/protocol/context/dzkp_validator.rs
@@ -1197,10 +1197,7 @@ mod tests {

     fn max_multiplications_per_gate_strategy(record_count: usize) -> impl Strategy {
         let max_max_mults = record_count.min(128);
-        prop_oneof![
-            1usize..=max_max_mults,
-            (0u32..=max_max_mults.ilog2()).prop_map(|i| 1usize << i)
-        ]
+        (0u32..=max_max_mults.ilog2()).prop_map(|i| 1usize << i)
     }

     prop_compose! {
@@ -1546,7 +1543,7 @@ mod tests {

         let [h1_batch, h2_batch, h3_batch] = world
             .malicious((a, b), |ctx, (a, b)| async move {
-                let mut validator = ctx.dzkp_validator(TEST_DZKP_STEPS, 10);
+                let mut validator = ctx.dzkp_validator(TEST_DZKP_STEPS, 8);
                 let mctx = validator.context();
                 let _ = a
                     .multiply(&b, mctx.set_total_records(1), RecordId::from(0))
diff --git a/ipa-core/src/protocol/ipa_prf/mod.rs b/ipa-core/src/protocol/ipa_prf/mod.rs
index cc3fa2633..754f179b6 100644
--- a/ipa-core/src/protocol/ipa_prf/mod.rs
+++ b/ipa-core/src/protocol/ipa_prf/mod.rs
@@ -308,8 +308,10 @@ where
 // We expect 2*256 = 512 gates in total for two additions per conversion. The vectorization factor
 // is CONV_CHUNK. Let `len` equal the number of converted shares. The total amount of
 // multiplications is CONV_CHUNK*512*len. We want CONV_CHUNK*512*len ≈ 50M, or len ≈ 381, for a
-// reasonably-sized proof.
-const CONV_PROOF_CHUNK: usize = 400;
+// reasonably-sized proof. There is also a constraint that proof chunks must be powers of two,
+// so we pick the largest power of two below 381. 256 gives us around 33M multiplications
+// per batch.
+const CONV_PROOF_CHUNK: usize = 256;

 #[tracing::instrument(name = "compute_prf_for_inputs", skip_all)]
 async fn compute_prf_for_inputs(
diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs
index 9a1f8f278..e3c8cf49c 100644
--- a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs
+++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs
@@ -512,7 +512,7 @@ where
         },
         // TODO: this should not be necessary, but probably can't be removed
         // until we align read_size with the batch size.
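+        // Aligning the batch size with a power of two keeps it in step with
+        // active work, which itself must now be a power of two.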
- std::cmp::min(sh_ctx.active_work().get(), chunk_size), + std::cmp::min(sh_ctx.active_work().get(), chunk_size.next_power_of_two()), ); dzkp_validator.set_total_records(TotalRecords::specified(histogram[1]).unwrap()); let ctx_for_row_number = set_up_contexts(&dzkp_validator.context(), histogram)?; @@ -541,7 +541,7 @@ where protocol: &Step::Aggregate, validate: &Step::AggregateValidate, }, - aggregate_values_proof_chunk(B, usize::try_from(TV::BITS).unwrap()), + aggregate_values_proof_chunk(B, usize::try_from(TV::BITS).unwrap()).next_power_of_two(), ); let user_contributions = flattened_user_results.try_collect::>().await?; let result = diff --git a/ipa-core/src/test_fixture/world.rs b/ipa-core/src/test_fixture/world.rs index f92326c9b..bdcd7448e 100644 --- a/ipa-core/src/test_fixture/world.rs +++ b/ipa-core/src/test_fixture/world.rs @@ -676,7 +676,7 @@ impl Runner for TestWorld { R: Future + Send, { self.malicious(input, |ctx, share| async { - let v = ctx.dzkp_validator(TEST_DZKP_STEPS, 10); + let v = ctx.dzkp_validator(TEST_DZKP_STEPS, 8); let m_ctx = v.context(); let m_result = helper_fn(m_ctx, share).await; v.validate().await.unwrap(); From 831986c3c45c1c2fb6d43035804694185bb3ed8a Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Wed, 2 Oct 2024 16:28:08 -0700 Subject: [PATCH 105/191] Fix quicksort batch size and align it with a power of two --- ipa-core/src/protocol/ipa_prf/quicksort.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ipa-core/src/protocol/ipa_prf/quicksort.rs b/ipa-core/src/protocol/ipa_prf/quicksort.rs index 16c3f1c8d..30218e5c2 100644 --- a/ipa-core/src/protocol/ipa_prf/quicksort.rs +++ b/ipa-core/src/protocol/ipa_prf/quicksort.rs @@ -176,7 +176,7 @@ where }, // TODO: use something like this when validating in chunks // `TARGET_PROOF_SIZE / usize::try_from(K::BITS).unwrap() / SORT_CHUNK`` - total_records_usize, + total_records_usize.next_power_of_two(), ); let c = v.context(); let cmp_ctx = c.narrow(&QuicksortPassStep::Compare); From 27f65d01123503805484070c7432a977b2db77cc Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Wed, 2 Oct 2024 16:46:00 -0700 Subject: [PATCH 106/191] Fix oneshot bench --- ipa-core/benches/oneshot/ipa.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/ipa-core/benches/oneshot/ipa.rs b/ipa-core/benches/oneshot/ipa.rs index 3d19836ee..b880c28d6 100644 --- a/ipa-core/benches/oneshot/ipa.rs +++ b/ipa-core/benches/oneshot/ipa.rs @@ -86,6 +86,7 @@ impl Args { self.active_work .map(NonZeroUsize::get) .unwrap_or_else(|| self.query_size.clamp(16, 1024)) + .next_power_of_two() } fn attribution_window(&self) -> Option { From 0ba1ec47be5df4621c72c19095984c132d1d8c71 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Wed, 2 Oct 2024 21:03:17 -0700 Subject: [PATCH 107/191] Support sharding in MAC and ZKP validations The next step to be able to run sharded protocols is to change the validators API and support both sharded and non-sharded contexts. 
This is a fairly straightforward change that just propagates `ShardBinding` trait bound through MAC and ZKP contexts and validators --- ipa-core/src/protocol/basics/mod.rs | 14 +++++++--- .../src/protocol/basics/mul/dzkp_malicious.rs | 14 ++++++---- ipa-core/src/protocol/basics/mul/mod.rs | 8 ++++-- ipa-core/src/protocol/basics/reveal.rs | 7 +++-- .../src/protocol/context/dzkp_malicious.rs | 21 +++++++------- .../src/protocol/context/dzkp_validator.rs | 28 +++++++++---------- ipa-core/src/protocol/context/malicious.rs | 14 +++++----- ipa-core/src/protocol/context/mod.rs | 2 +- ipa-core/src/protocol/context/validator.rs | 26 ++++++++--------- ipa-core/src/test_fixture/world.rs | 6 ++-- 10 files changed, 76 insertions(+), 64 deletions(-) diff --git a/ipa-core/src/protocol/basics/mod.rs b/ipa-core/src/protocol/basics/mod.rs index 278aded76..cd2e92be6 100644 --- a/ipa-core/src/protocol/basics/mod.rs +++ b/ipa-core/src/protocol/basics/mod.rs @@ -89,7 +89,10 @@ impl<'a, B: ShardBinding> BooleanProtocols> { } -impl<'a> BooleanProtocols> for AdditiveShare {} +impl<'a, B: ShardBinding> BooleanProtocols> + for AdditiveShare +{ +} // Used for aggregation tests impl<'a, B: ShardBinding> BooleanProtocols, 8> @@ -107,7 +110,7 @@ impl<'a, B: ShardBinding> BooleanProtocols, { } -impl<'a> BooleanProtocols, PRF_CHUNK> +impl<'a, B: ShardBinding> BooleanProtocols, PRF_CHUNK> for AdditiveShare { } @@ -124,7 +127,7 @@ impl<'a, B: ShardBinding> BooleanProtocols, { } -impl<'a> BooleanProtocols, AGG_CHUNK> +impl<'a, B: ShardBinding> BooleanProtocols, AGG_CHUNK> for AdditiveShare { } @@ -159,7 +162,10 @@ impl<'a, B: ShardBinding> BooleanProtocols, { } -impl<'a> BooleanProtocols, 32> for AdditiveShare {} +impl<'a, B: ShardBinding> BooleanProtocols, 32> + for AdditiveShare +{ +} const_assert_eq!( AGG_CHUNK, diff --git a/ipa-core/src/protocol/basics/mul/dzkp_malicious.rs b/ipa-core/src/protocol/basics/mul/dzkp_malicious.rs index 23a96c982..4a618d507 100644 --- a/ipa-core/src/protocol/basics/mul/dzkp_malicious.rs +++ b/ipa-core/src/protocol/basics/mul/dzkp_malicious.rs @@ -13,6 +13,7 @@ use crate::{ RecordId, }, secret_sharing::{replicated::semi_honest::AdditiveShare as Replicated, Vectorizable}, + sharding::{NotSharded, ShardBinding}, }; /// This function implements an MPC multiply using the standard strategy, i.e. via computing the @@ -27,13 +28,14 @@ use crate::{ /// back via the error response /// ## Panics /// Panics if the mutex is found to be poisoned -pub async fn zkp_multiply<'a, F, const N: usize>( - ctx: DZKPUpgradedMaliciousContext<'a>, +pub async fn zkp_multiply<'a, B, F, const N: usize>( + ctx: DZKPUpgradedMaliciousContext<'a, B>, record_id: RecordId, a: &Replicated, b: &Replicated, ) -> Result, Error> where + B: ShardBinding, F: Field + DZKPCompatibleField, { // Shared randomness used to mask the values that are sent. @@ -62,17 +64,17 @@ where /// Implement secure multiplication for malicious contexts with replicated secret sharing. 
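+/// The implementation is generic over [`ShardBinding`], so the same code path
+/// serves both sharded and non-sharded malicious contexts.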
#[async_trait] -impl<'a, F: Field + DZKPCompatibleField, const N: usize> - SecureMul> for Replicated +impl<'a, B: ShardBinding, F: Field + DZKPCompatibleField, const N: usize> + SecureMul> for Replicated { async fn multiply<'fut>( &self, rhs: &Self, - ctx: DZKPUpgradedMaliciousContext<'a>, + ctx: DZKPUpgradedMaliciousContext<'a, B>, record_id: RecordId, ) -> Result where - DZKPUpgradedMaliciousContext<'a>: 'fut, + DZKPUpgradedMaliciousContext<'a, NotSharded>: 'fut, { zkp_multiply(ctx, record_id, self, rhs).await } diff --git a/ipa-core/src/protocol/basics/mul/mod.rs b/ipa-core/src/protocol/basics/mul/mod.rs index 260d7a4b8..89f5e107a 100644 --- a/ipa-core/src/protocol/basics/mul/mod.rs +++ b/ipa-core/src/protocol/basics/mul/mod.rs @@ -123,17 +123,19 @@ macro_rules! boolean_array_mul { } } - impl<'a> BooleanArrayMul> for Replicated<$vec> { + impl<'a, B: sharding::ShardBinding> BooleanArrayMul> + for Replicated<$vec> + { type Vectorized = Replicated; fn multiply<'fut>( - ctx: DZKPUpgradedMaliciousContext<'a>, + ctx: DZKPUpgradedMaliciousContext<'a, B>, record_id: RecordId, a: &'fut Self::Vectorized, b: &'fut Self::Vectorized, ) -> impl Future> + Send + 'fut where - DZKPUpgradedMaliciousContext<'a>: 'fut, + DZKPUpgradedMaliciousContext<'a, B>: 'fut, { use crate::protocol::basics::mul::dzkp_malicious::zkp_multiply; zkp_multiply(ctx, record_id, a, b) diff --git a/ipa-core/src/protocol/basics/reveal.rs b/ipa-core/src/protocol/basics/reveal.rs index 75867046f..19363e1af 100644 --- a/ipa-core/src/protocol/basics/reveal.rs +++ b/ipa-core/src/protocol/basics/reveal.rs @@ -321,20 +321,21 @@ where } } -impl<'a, const N: usize> Reveal> for Replicated +impl<'a, B, const N: usize> Reveal> for Replicated where + B: ShardBinding, Boolean: Vectorizable, { type Output = >::Array; async fn generic_reveal<'fut>( &'fut self, - ctx: DZKPUpgradedMaliciousContext<'a>, + ctx: DZKPUpgradedMaliciousContext<'a, B>, record_id: RecordId, excluded: Option, ) -> Result, Error> where - DZKPUpgradedMaliciousContext<'a>: 'fut, + DZKPUpgradedMaliciousContext<'a, B>: 'fut, { malicious_reveal(ctx, record_id, excluded, self).await } diff --git a/ipa-core/src/protocol/context/dzkp_malicious.rs b/ipa-core/src/protocol/context/dzkp_malicious.rs index 9f28239ba..2023f427a 100644 --- a/ipa-core/src/protocol/context/dzkp_malicious.rs +++ b/ipa-core/src/protocol/context/dzkp_malicious.rs @@ -20,21 +20,22 @@ use crate::{ Gate, RecordId, }, seq_join::SeqJoin, + sharding::ShardBinding, sync::{Arc, Weak}, }; /// Represents protocol context in malicious setting when using zero-knowledge proofs, /// i.e. secure against one active adversary in 3 party MPC ring. 
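+/// The context is generic over [`ShardBinding`], which lets the same ZKP
+/// machinery run with and without sharding.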
#[derive(Clone)] -pub struct DZKPUpgraded<'a> { - validator_inner: Weak>, - base_ctx: MaliciousContext<'a>, +pub struct DZKPUpgraded<'a, B: ShardBinding> { + validator_inner: Weak>, + base_ctx: MaliciousContext<'a, B>, } -impl<'a> DZKPUpgraded<'a> { +impl<'a, B: ShardBinding> DZKPUpgraded<'a, B> { pub(super) fn new( - validator_inner: &Arc>, - base_ctx: MaliciousContext<'a>, + validator_inner: &Arc>, + base_ctx: MaliciousContext<'a, B>, ) -> Self { let records_per_batch = validator_inner.batcher.lock().unwrap().records_per_batch(); let active_work = if records_per_batch == 1 { @@ -82,7 +83,7 @@ impl<'a> DZKPUpgraded<'a> { } #[async_trait] -impl<'a> DZKPContext for DZKPUpgraded<'a> { +impl<'a, B: ShardBinding> DZKPContext for DZKPUpgraded<'a, B> { async fn validate_record(&self, record_id: RecordId) -> Result<(), Error> { let validator_inner = self.validator_inner.upgrade().expect("validator is active"); @@ -100,7 +101,7 @@ impl<'a> DZKPContext for DZKPUpgraded<'a> { } } -impl<'a> super::Context for DZKPUpgraded<'a> { +impl<'a, B: ShardBinding> super::Context for DZKPUpgraded<'a, B> { fn role(&self) -> Role { self.base_ctx.role() } @@ -152,13 +153,13 @@ impl<'a> super::Context for DZKPUpgraded<'a> { } } -impl<'a> SeqJoin for DZKPUpgraded<'a> { +impl<'a, B: ShardBinding> SeqJoin for DZKPUpgraded<'a, B> { fn active_work(&self) -> NonZeroUsize { self.base_ctx.active_work() } } -impl Debug for DZKPUpgraded<'_> { +impl Debug for DZKPUpgraded<'_, B> { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "DZKPMaliciousContext") } diff --git a/ipa-core/src/protocol/context/dzkp_validator.rs b/ipa-core/src/protocol/context/dzkp_validator.rs index 835a32e9d..ec9393486 100644 --- a/ipa-core/src/protocol/context/dzkp_validator.rs +++ b/ipa-core/src/protocol/context/dzkp_validator.rs @@ -520,7 +520,7 @@ impl Batch { /// ## Panics /// If `usize` to `u128` conversion fails. - pub(super) async fn validate(self, ctx: Base<'_>) -> Result<(), Error> { + pub(super) async fn validate(self, ctx: Base<'_, B>) -> Result<(), Error> { let proof_ctx = ctx.narrow(&Step::GenerateProof); if self.is_empty() { @@ -701,26 +701,26 @@ type DzkpBatcher<'a> = Batcher<'a, Batch>; /// The DZKP validator, and all associated contexts, each hold a reference to a single /// instance of `MaliciousDZKPValidatorInner`. -pub(super) struct MaliciousDZKPValidatorInner<'a> { +pub(super) struct MaliciousDZKPValidatorInner<'a, B: ShardBinding> { pub(super) batcher: Mutex>, - pub(super) validate_ctx: Base<'a>, + pub(super) validate_ctx: Base<'a, B>, } /// `MaliciousDZKPValidator` corresponds to pub struct `Malicious` and implements the trait `DZKPValidator` /// The implementation of `validate` of the `DZKPValidator` trait depends on generic `DF` -pub struct MaliciousDZKPValidator<'a> { +pub struct MaliciousDZKPValidator<'a, B: ShardBinding> { // This is an `Option` because we want to consume it in `DZKPValidator::validate`, // but we also want to implement `Drop`. Note that the `is_verified` check in `Drop` // does nothing when `batcher_ref` is already `None`. 
- inner_ref: Option>>, - protocol_ctx: MaliciousDZKPUpgraded<'a>, + inner_ref: Option>>, + protocol_ctx: MaliciousDZKPUpgraded<'a, B>, } #[async_trait] -impl<'a> DZKPValidator for MaliciousDZKPValidator<'a> { - type Context = MaliciousDZKPUpgraded<'a>; +impl<'a, B: ShardBinding> DZKPValidator for MaliciousDZKPValidator<'a, B> { + type Context = MaliciousDZKPUpgraded<'a, B>; - fn context(&self) -> MaliciousDZKPUpgraded<'a> { + fn context(&self) -> MaliciousDZKPUpgraded<'a, B> { self.protocol_ctx.clone() } @@ -774,11 +774,11 @@ impl<'a> DZKPValidator for MaliciousDZKPValidator<'a> { } } -impl<'a> MaliciousDZKPValidator<'a> { +impl<'a, B: ShardBinding> MaliciousDZKPValidator<'a, B> { #[must_use] #[allow(clippy::needless_pass_by_value)] pub fn new( - ctx: MaliciousContext<'a>, + ctx: MaliciousContext<'a, B>, steps: MaliciousProtocolSteps, max_multiplications_per_gate: usize, ) -> Self @@ -808,7 +808,7 @@ impl<'a> MaliciousDZKPValidator<'a> { } } -impl<'a> Drop for MaliciousDZKPValidator<'a> { +impl<'a, B: ShardBinding> Drop for MaliciousDZKPValidator<'a, B> { fn drop(&mut self) { if self.inner_ref.is_some() { self.is_verified().unwrap(); @@ -922,7 +922,7 @@ mod tests { async fn test_select_malicious() where V: BooleanArray, - for<'a> Replicated: BooleanArrayMul>, + for<'a> Replicated: BooleanArrayMul>, Standard: Distribution, { let world = TestWorld::default(); @@ -1040,7 +1040,7 @@ mod tests { async fn multi_select_malicious(count: usize, max_multiplications_per_gate: usize) where V: BooleanArray, - for<'a> Replicated: BooleanArrayMul>, + for<'a> Replicated: BooleanArrayMul>, Standard: Distribution, { let mut rng = thread_rng(); diff --git a/ipa-core/src/protocol/context/malicious.rs b/ipa-core/src/protocol/context/malicious.rs index 401a6cb0e..def33e950 100644 --- a/ipa-core/src/protocol/context/malicious.rs +++ b/ipa-core/src/protocol/context/malicious.rs @@ -138,14 +138,14 @@ impl<'a, B: ShardBinding> super::Context for Context<'a, B> { } } -impl<'a> UpgradableContext for Context<'a, NotSharded> { - type Validator = BatchValidator<'a, F>; +impl<'a, B: ShardBinding> UpgradableContext for Context<'a, B> { + type Validator = BatchValidator<'a, F, B>; fn validator(self) -> Self::Validator { BatchValidator::new(self) } - type DZKPValidator = MaliciousDZKPValidator<'a>; + type DZKPValidator = MaliciousDZKPValidator<'a, B>; fn dzkp_validator( self, @@ -174,18 +174,18 @@ impl Debug for Context<'_, B> { use crate::sync::{Mutex, Weak}; -pub(super) type MacBatcher<'a, F> = Mutex>>; +pub(super) type MacBatcher<'a, F, B> = Mutex>>; /// Represents protocol context in malicious setting, i.e. secure against one active adversary /// in 3 party MPC ring. #[derive(Clone)] pub struct Upgraded<'a, F: ExtendableField, B: ShardBinding> { - batch: Weak>, + batch: Weak>, base_ctx: Context<'a, B>, } impl<'a, F: ExtendableField, B: ShardBinding> Upgraded<'a, F, B> { - pub(super) fn new(batch: &Arc>, ctx: Context<'a, B>) -> Self { + pub(super) fn new(batch: &Arc>, ctx: Context<'a, B>) -> Self { // The DZKP malicious context adjusts active_work to match records_per_batch. // The MAC validator currently configures the batcher with records_per_batch = // active_work. 
If the latter behavior changes, this code may need to be @@ -231,7 +231,7 @@ impl<'a, F: ExtendableField, B: ShardBinding> Upgraded<'a, F, B> { self.with_batch(record_id, |v| v.r_share().clone()) } - fn with_batch) -> T, T>( + fn with_batch) -> T, T>( &self, record_id: RecordId, action: C, diff --git a/ipa-core/src/protocol/context/mod.rs b/ipa-core/src/protocol/context/mod.rs index abd53b6ee..0651b74a4 100644 --- a/ipa-core/src/protocol/context/mod.rs +++ b/ipa-core/src/protocol/context/mod.rs @@ -26,7 +26,7 @@ pub type SemiHonestContext<'a, B = NotSharded> = semi_honest::Context<'a, B>; pub type ShardedSemiHonestContext<'a> = semi_honest::Context<'a, Sharded>; pub type MaliciousContext<'a, B = NotSharded> = malicious::Context<'a, B>; -pub type UpgradedMaliciousContext<'a, F> = malicious::Upgraded<'a, F, NotSharded>; +pub type UpgradedMaliciousContext<'a, F, B = NotSharded> = malicious::Upgraded<'a, F, B>; #[cfg(all(feature = "in-memory-infra", any(test, feature = "test-fixture")))] pub(crate) use malicious::TEST_DZKP_STEPS; diff --git a/ipa-core/src/protocol/context/validator.rs b/ipa-core/src/protocol/context/validator.rs index e57ae3c6a..a71b395c3 100644 --- a/ipa-core/src/protocol/context/validator.rs +++ b/ipa-core/src/protocol/context/validator.rs @@ -199,18 +199,18 @@ impl MaliciousAccumulator { /// When batch is validated, `r` is revealed and can never be /// used again. In fact, it gets out of scope after successful validation /// so no code can get access to it. -pub struct BatchValidator<'a, F: ExtendableField> { - batches_ref: Arc>, - protocol_ctx: MaliciousContext<'a>, +pub struct BatchValidator<'a, F: ExtendableField, B: ShardBinding> { + batches_ref: Arc>, + protocol_ctx: MaliciousContext<'a, B>, } -impl<'a, F: ExtendableField> BatchValidator<'a, F> { +impl<'a, F: ExtendableField, B: ShardBinding> BatchValidator<'a, F, B> { /// Create a new validator for malicious context. /// /// ## Panics /// If total records is not set. #[must_use] - pub fn new(ctx: MaliciousContext<'a>) -> Self { + pub fn new(ctx: MaliciousContext<'a, B>) -> Self { let TotalRecords::Specified(total_records) = ctx.total_records() else { panic!("Total records must be specified before creating the validator"); }; @@ -230,14 +230,14 @@ impl<'a, F: ExtendableField> BatchValidator<'a, F> { } } -pub struct Malicious<'a, F: ExtendableField> { +pub struct Malicious<'a, F: ExtendableField, B: ShardBinding> { r_share: Replicated, pub(super) accumulator: MaliciousAccumulator, - validate_ctx: Base<'a>, + validate_ctx: Base<'a, B>, offset: usize, } -impl Malicious<'_, F> { +impl Malicious<'_, F, B> { /// ## Errors /// If the two information theoretic MACs are not equal (after multiplying by `r`), this indicates that one of the parties /// must have launched an additive attack. At this point the honest parties should abort the protocol. 
This method throws an @@ -294,21 +294,21 @@ impl Malicious<'_, F> { } } -impl<'a, F> Validator for BatchValidator<'a, F> +impl<'a, F, B: ShardBinding> Validator for BatchValidator<'a, F, B> where F: ExtendableField, { - type Context = UpgradedMaliciousContext<'a, F>; + type Context = UpgradedMaliciousContext<'a, F, B>; fn context(&self) -> Self::Context { UpgradedMaliciousContext::new(&self.batches_ref, self.protocol_ctx.clone()) } } -impl<'a, F: ExtendableField> Malicious<'a, F> { +impl<'a, F: ExtendableField, B: ShardBinding> Malicious<'a, F, B> { #[must_use] #[allow(clippy::needless_pass_by_value)] - pub fn new(ctx: MaliciousContext<'a>, offset: usize) -> Self { + pub fn new(ctx: MaliciousContext<'a, B>, offset: usize) -> Self { // Each invocation requires 3 calls to PRSS to generate the state. // Validation occurs in batches and `offset` indicates which batch // we're in right now. @@ -386,7 +386,7 @@ impl<'a, F: ExtendableField> Malicious<'a, F> { } } -impl Debug for Malicious<'_, F> { +impl Debug for Malicious<'_, F, B> { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "MaliciousValidator<{:?}>", type_name::()) } diff --git a/ipa-core/src/test_fixture/world.rs b/ipa-core/src/test_fixture/world.rs index cd6919580..1c337b10e 100644 --- a/ipa-core/src/test_fixture/world.rs +++ b/ipa-core/src/test_fixture/world.rs @@ -428,7 +428,7 @@ pub trait Runner { I: IntoShares + Send + 'static, A: Send + 'static, O: Send + Debug, - H: Fn(DZKPUpgradedMaliciousContext<'a>, A) -> R + Send + Sync, + H: Fn(DZKPUpgradedMaliciousContext<'a, NotSharded>, A) -> R + Send + Sync, R: Future + Send; } @@ -531,7 +531,7 @@ impl Runner> I: IntoShares + Send + 'static, A: Send + 'static, O: Send + Debug, - H: Fn(DZKPUpgradedMaliciousContext<'a>, A) -> R + Send + Sync, + H: Fn(DZKPUpgradedMaliciousContext<'a, NotSharded>, A) -> R + Send + Sync, R: Future + Send, { unimplemented!() @@ -672,7 +672,7 @@ impl Runner for TestWorld { I: IntoShares + Send + 'static, A: Send + 'static, O: Send + Debug, - H: (Fn(DZKPUpgradedMaliciousContext<'a>, A) -> R) + Send + Sync, + H: (Fn(DZKPUpgradedMaliciousContext<'a, NotSharded>, A) -> R) + Send + Sync, R: Future + Send, { self.malicious(input, |ctx, share| async { From b7ee6dba8f990ca6e8f9fd4a338223e40e39def7 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Wed, 2 Oct 2024 09:20:41 -0700 Subject: [PATCH 108/191] Rewrite join3 helpers to reduce size of some futures --- .../src/protocol/context/dzkp_validator.rs | 4 +-- ipa-core/src/test_fixture/mod.rs | 28 +++++++++++-------- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/ipa-core/src/protocol/context/dzkp_validator.rs b/ipa-core/src/protocol/context/dzkp_validator.rs index 835a32e9d..f5586336f 100644 --- a/ipa-core/src/protocol/context/dzkp_validator.rs +++ b/ipa-core/src/protocol/context/dzkp_validator.rs @@ -916,7 +916,7 @@ mod tests { test_select_semi_honest::().await; test_select_semi_honest::().await; test_select_semi_honest::().await; - Box::pin(test_select_semi_honest::()).await; + test_select_semi_honest::().await; } async fn test_select_malicious() @@ -972,7 +972,7 @@ mod tests { test_select_malicious::().await; test_select_malicious::().await; test_select_malicious::().await; - Box::pin(test_select_malicious::()).await; + test_select_malicious::().await; } #[tokio::test] diff --git a/ipa-core/src/test_fixture/mod.rs b/ipa-core/src/test_fixture/mod.rs index 041208efd..9c6e7995f 100644 --- a/ipa-core/src/test_fixture/mod.rs +++ b/ipa-core/src/test_fixture/mod.rs @@ -19,12 +19,12 @@ 
pub mod metrics; #[cfg(feature = "in-memory-infra")] mod test_gate; -use std::fmt::Debug; +use std::{fmt::Debug, future::Future}; #[cfg(feature = "in-memory-infra")] pub use app::TestApp; pub use event_gen::{Config as EventGeneratorConfig, EventGenerator}; -use futures::TryFuture; +use futures::{FutureExt, TryFuture}; pub use hybrid_event_gen::{ Config as HybridGeneratorConfig, EventGenerator as HybridEventGenerator, }; @@ -106,30 +106,32 @@ pub fn permutation_valid(permutation: &[u32]) -> bool { /// Wrapper for joining three things into an array. /// # Errors /// If one of the futures returned an error. -pub async fn try_join3_array([f0, f1, f2]: [T; 3]) -> Result<[T::Ok; 3], T::Error> { - futures::future::try_join3(f0, f1, f2) - .await - .map(|(a, b, c)| [a, b, c]) +pub fn try_join3_array( + [f0, f1, f2]: [T; 3], +) -> impl Future> { + futures::future::try_join3(f0, f1, f2).map(|res| res.map(|(a, b, c)| [a, b, c])) } /// Wrapper for joining three things into an array. /// # Panics /// If the tasks return `Err`. -pub async fn join3(a: T, b: T, c: T) -> [T::Ok; 3] +pub fn join3(a: T, b: T, c: T) -> impl Future where T: TryFuture, T::Output: Debug, T::Ok: Debug, T::Error: Debug, { - let (a, b, c) = futures::future::try_join3(a, b, c).await.unwrap(); - [a, b, c] + futures::future::try_join3(a, b, c).map(|res| { + let (a, b, c) = res.unwrap(); + [a, b, c] + }) } /// Wrapper for joining three things from an iterator into an array. /// # Panics /// If the tasks return `Err` or if `a` is the wrong length. -pub async fn join3v(a: V) -> [T::Ok; 3] +pub fn join3v(a: V) -> impl Future where V: IntoIterator, T: TryFuture, @@ -138,9 +140,11 @@ where T::Error: Debug, { let mut it = a.into_iter(); - let res = join3(it.next().unwrap(), it.next().unwrap(), it.next().unwrap()).await; + let fut0 = it.next().unwrap(); + let fut1 = it.next().unwrap(); + let fut2 = it.next().unwrap(); assert!(it.next().is_none()); - res + join3(fut0, fut1, fut2) } /// Take a slice of bits in `{0,1} ⊆ F_p`, and reconstruct the integer in `Z` From c456668fc66f21ca381a85374d24e5f00a3f4306 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Thu, 3 Oct 2024 15:15:14 -0700 Subject: [PATCH 109/191] Support sharded malicious protocols inside test infrastructure This is the final PR in the series (#1333, #1315) that enables writing sharded malicious protocols. The support is added for `Runner::malicious` function, and it can be used to execute malicious circuits. Other methods like `upgraded_semi_honest` and `upgraded_malicious` haven't been touched yet - I am not convinced that these are necessary and we can update them as we go. 
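For illustration, driving a sharded malicious circuit through the new `Runner::malicious`
looks roughly like this (a minimal sketch distilled from the `zkp_malicious_sharded` test
added below; imports, assertions, and per-shard reconstruction are elided):

```rust
// Each shard closure receives a ShardedMaliciousContext and that shard's slice of input.
let world: TestWorld<WithShards<2>> = TestWorld::with_shards(TestWorldConfig::default());
let results = world
    .malicious(input.into_iter(), |ctx, shard_input| async move {
        let ctx = ctx.set_total_records(TotalRecords::ONE);
        let validator = ctx.dzkp_validator(TEST_DZKP_STEPS, 1);
        let m_ctx = validator.context();
        let product = shard_input[0]
            .multiply(&shard_input[0], m_ctx, RecordId::FIRST)
            .await
            .unwrap();
        validator.validate().await.unwrap();
        vec![product]
    })
    .await; // Vec<[O; 3]>: one array of three helper outputs per shard
```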
The core goal here is to unblock efforts to implement Hybrid protocol --- ipa-core/src/protocol/basics/mul/malicious.rs | 13 +- ipa-core/src/protocol/context/malicious.rs | 32 ++-- ipa-core/src/protocol/context/mod.rs | 1 + ipa-core/src/protocol/context/semi_honest.rs | 6 +- ipa-core/src/test_fixture/world.rs | 139 +++++++++++++++--- 5 files changed, 151 insertions(+), 40 deletions(-) diff --git a/ipa-core/src/protocol/basics/mul/malicious.rs b/ipa-core/src/protocol/basics/mul/malicious.rs index e55d855d6..92bb6bee7 100644 --- a/ipa-core/src/protocol/basics/mul/malicious.rs +++ b/ipa-core/src/protocol/basics/mul/malicious.rs @@ -16,6 +16,7 @@ use crate::{ malicious::{AdditiveShare as MaliciousReplicated, ExtendableFieldSimd}, semi_honest::AdditiveShare as Replicated, }, + sharding::ShardBinding, }; /// @@ -49,8 +50,8 @@ use crate::{ /// back via the error response /// ## Panics /// Panics if the mutex is found to be poisoned -pub async fn mac_multiply( - ctx: UpgradedMaliciousContext<'_, F>, +pub async fn mac_multiply( + ctx: UpgradedMaliciousContext<'_, F, B>, record_id: RecordId, a: &MaliciousReplicated, b: &MaliciousReplicated, @@ -108,19 +109,19 @@ where /// Implement secure multiplication for malicious contexts with replicated secret sharing. #[async_trait] -impl<'a, F: ExtendableFieldSimd, const N: usize> SecureMul> - for MaliciousReplicated +impl<'a, F: ExtendableFieldSimd, B: ShardBinding, const N: usize> + SecureMul> for MaliciousReplicated where Replicated: FromPrss, { async fn multiply<'fut>( &self, rhs: &Self, - ctx: UpgradedMaliciousContext<'a, F>, + ctx: UpgradedMaliciousContext<'a, F, B>, record_id: RecordId, ) -> Result where - UpgradedMaliciousContext<'a, F>: 'fut, + UpgradedMaliciousContext<'a, F, B>: 'fut, { mac_multiply(ctx, record_id, self, rhs).await } diff --git a/ipa-core/src/protocol/context/malicious.rs b/ipa-core/src/protocol/context/malicious.rs index def33e950..ac008a19c 100644 --- a/ipa-core/src/protocol/context/malicious.rs +++ b/ipa-core/src/protocol/context/malicious.rs @@ -323,8 +323,10 @@ impl<'a, F: ExtendableField, B: ShardBinding> SeqJoin for Upgraded<'a, F, B> { /// protocols should be generic over `SecretShare` trait and not requiring this cast and taking /// `ProtocolContext<'a, S: SecretShare, F: Field>` as the context. If that is not possible, /// this implementation makes it easier to reinterpret the context as semi-honest. -impl<'a, F: ExtendableField> SpecialAccessToUpgradedContext for Upgraded<'a, F, NotSharded> { - type Base = Base<'a>; +impl<'a, F: ExtendableField, B: ShardBinding> SpecialAccessToUpgradedContext + for Upgraded<'a, F, B> +{ + type Base = Base<'a, B>; fn base_context(self) -> Self::Base { self.base_ctx.inner @@ -340,7 +342,7 @@ impl Debug for Upgraded<'_, F, B> { /// Upgrading a semi-honest replicated share using malicious context produces /// a MAC-secured share with the same vectorization factor. 
#[async_trait] -impl<'a, V: ExtendableFieldSimd, const N: usize> Upgradable> +impl<'a, V: ExtendableFieldSimd, B: ShardBinding, const N: usize> Upgradable> for Replicated where Replicated<::ExtendedField, N>: FromPrss, @@ -349,7 +351,7 @@ where async fn upgrade( self, - ctx: Upgraded<'a, V, NotSharded>, + ctx: Upgraded<'a, V, B>, record_id: RecordId, ) -> Result { let ctx = ctx.narrow(&UpgradeStep); @@ -383,7 +385,7 @@ where #[cfg(all(test, descriptive_gate))] #[async_trait] -impl<'a, V: ExtendableFieldSimd, const N: usize> Upgradable> +impl<'a, V: ExtendableFieldSimd, B: ShardBinding, const N: usize> Upgradable> for (Replicated, Replicated) where Replicated<::ExtendedField, N>: FromPrss, @@ -392,7 +394,7 @@ where async fn upgrade( self, - ctx: Upgraded<'a, V, NotSharded>, + ctx: Upgraded<'a, V, B>, record_id: RecordId, ) -> Result { let (l, r) = self; @@ -404,12 +406,12 @@ where #[cfg(all(test, descriptive_gate))] #[async_trait] -impl<'a, V: ExtendableField> Upgradable> for () { +impl<'a, V: ExtendableField, B: ShardBinding> Upgradable> for () { type Output = (); async fn upgrade( self, - _context: Upgraded<'a, V, NotSharded>, + _context: Upgraded<'a, V, B>, _record_id: RecordId, ) -> Result { Ok(()) @@ -418,28 +420,30 @@ impl<'a, V: ExtendableField> Upgradable> for () { #[cfg(all(test, descriptive_gate))] #[async_trait] -impl<'a, V, U> Upgradable> for Vec +impl<'a, V, U, B> Upgradable> for Vec where V: ExtendableField, - U: Upgradable, Output: Send> + Send + 'a, + U: Upgradable, Output: Send> + Send + 'a, + B: ShardBinding, { type Output = Vec; async fn upgrade( self, - ctx: Upgraded<'a, V, NotSharded>, + ctx: Upgraded<'a, V, B>, record_id: RecordId, ) -> Result { /// Need a standalone function to avoid GAT issue that apparently can manifest /// even with `async_trait`. 
- fn upgrade_vec<'a, V, U>( - ctx: Upgraded<'a, V, NotSharded>, + fn upgrade_vec<'a, V, U, B>( + ctx: Upgraded<'a, V, B>, record_id: RecordId, input: Vec, ) -> impl std::future::Future, Error>> + 'a where V: ExtendableField, - U: Upgradable> + 'a, + U: Upgradable> + 'a, + B: ShardBinding, { let mut upgraded = Vec::with_capacity(input.len()); async move { diff --git a/ipa-core/src/protocol/context/mod.rs b/ipa-core/src/protocol/context/mod.rs index 0651b74a4..627ffc0df 100644 --- a/ipa-core/src/protocol/context/mod.rs +++ b/ipa-core/src/protocol/context/mod.rs @@ -26,6 +26,7 @@ pub type SemiHonestContext<'a, B = NotSharded> = semi_honest::Context<'a, B>; pub type ShardedSemiHonestContext<'a> = semi_honest::Context<'a, Sharded>; pub type MaliciousContext<'a, B = NotSharded> = malicious::Context<'a, B>; +pub type ShardedMaliciousContext<'a> = malicious::Context<'a, Sharded>; pub type UpgradedMaliciousContext<'a, F, B = NotSharded> = malicious::Upgraded<'a, F, B>; #[cfg(all(feature = "in-memory-infra", any(test, feature = "test-fixture")))] diff --git a/ipa-core/src/protocol/context/semi_honest.rs b/ipa-core/src/protocol/context/semi_honest.rs index 65f4e644e..bd8c2e260 100644 --- a/ipa-core/src/protocol/context/semi_honest.rs +++ b/ipa-core/src/protocol/context/semi_honest.rs @@ -302,14 +302,14 @@ impl Debug for Upgraded<'_, B, F> { } #[async_trait] -impl<'a, V: ExtendableField + Vectorizable, const N: usize> - Upgradable> for Replicated +impl<'a, V: ExtendableField + Vectorizable, B: ShardBinding, const N: usize> + Upgradable> for Replicated { type Output = Replicated; async fn upgrade( self, - _context: Upgraded<'a, NotSharded, V>, + _context: Upgraded<'a, B, V>, _record_id: RecordId, ) -> Result { Ok(self) diff --git a/ipa-core/src/test_fixture/world.rs b/ipa-core/src/test_fixture/world.rs index 1c337b10e..b54d76abb 100644 --- a/ipa-core/src/test_fixture/world.rs +++ b/ipa-core/src/test_fixture/world.rs @@ -23,8 +23,8 @@ use crate::{ context::{ dzkp_validator::DZKPValidator, upgrade::Upgradable, Context, DZKPUpgradedMaliciousContext, MaliciousContext, SemiHonestContext, - ShardedSemiHonestContext, UpgradableContext, UpgradedContext, UpgradedMaliciousContext, - UpgradedSemiHonestContext, Validator, TEST_DZKP_STEPS, + ShardedMaliciousContext, ShardedSemiHonestContext, UpgradableContext, UpgradedContext, + UpgradedMaliciousContext, UpgradedSemiHonestContext, Validator, TEST_DZKP_STEPS, }, prss::Endpoint as PrssEndpoint, Gate, QueryId, RecordId, @@ -369,6 +369,10 @@ where pub trait Runner { /// This could be also derived from [`S`], but maybe that's too much for that trait. type SemiHonestContext<'ctx>: Context; + /// The type of context used to run protocols that are secure against + /// active adversaries. It varies depending on whether sharding is used or not. + type MaliciousContext<'ctx>: Context; + /// Run with a context that can be upgraded, but is only good for semi-honest. async fn semi_honest<'a, I, A, O, H, R>( &'a self, @@ -396,12 +400,12 @@ pub trait Runner { R: Future + Send; /// Run with a context that can be upgraded to malicious. 
- async fn malicious<'a, I, A, O, H, R>(&'a self, input: I, helper_fn: H) -> [O; 3] + async fn malicious<'a, I, A, O, H, R>(&'a self, input: I, helper_fn: H) -> S::Container<[O; 3]> where - I: IntoShares + Send + 'static, + I: RunnerInput, A: Send, O: Send + Debug, - H: Fn(MaliciousContext<'a>, A) -> R + Send + Sync, + H: Fn(Self::MaliciousContext<'a>, S::Container) -> R + Send + Sync, R: Future + Send; /// Run with a context that has already been upgraded to malicious. @@ -444,6 +448,7 @@ impl Runner> for TestWorld> { type SemiHonestContext<'ctx> = ShardedSemiHonestContext<'ctx>; + type MaliciousContext<'ctx> = ShardedMaliciousContext<'ctx>; async fn semi_honest<'a, I, A, O, H, R>(&'a self, input: I, helper_fn: H) -> Vec<[O; 3]> where I: RunnerInput, A>, @@ -494,15 +499,39 @@ impl Runner> unimplemented!() } - async fn malicious<'a, I, A, O, H, R>(&'a self, _input: I, _helper_fn: H) -> [O; 3] + async fn malicious<'a, I, A, O, H, R>(&'a self, input: I, helper_fn: H) -> Vec<[O; 3]> where - I: IntoShares + Send + 'static, + I: RunnerInput, A>, A: Send, O: Send + Debug, - H: Fn(MaliciousContext<'a>, A) -> R + Send + Sync, + H: Fn( + Self::MaliciousContext<'a>, + as ShardingScheme>::Container, + ) -> R + + Send + + Sync, R: Future + Send, { - unimplemented!() + let shards = self.shards(); + let [h1, h2, h3]: [[Vec; SHARDS]; 3] = input.share().map(D::distribute); + let gate = self.next_gate(); + // todo!() + + // No clippy, you're wrong, it is not redundant, it allows shard_fn to be `Copy` + #[allow(clippy::redundant_closure)] + let shard_fn = |ctx, input| helper_fn(ctx, input); + zip(shards.into_iter(), zip(zip(h1, h2), h3)) + .map(|(shard, ((h1, h2), h3))| { + ShardWorld::::run_either( + shard.malicious_contexts(&gate), + self.metrics_handle.span(), + [h1, h2, h3], + shard_fn, + ) + }) + .collect::>() + .collect::>() + .await } async fn upgraded_malicious<'a, F, I, A, M, O, H, R, P>( @@ -541,6 +570,7 @@ impl Runner> #[async_trait] impl Runner for TestWorld { type SemiHonestContext<'ctx> = SemiHonestContext<'ctx>; + type MaliciousContext<'ctx> = MaliciousContext<'ctx>; async fn semi_honest<'a, I, A, O, H, R>(&'a self, input: I, helper_fn: H) -> [O; 3] where @@ -583,10 +613,10 @@ impl Runner for TestWorld { async fn malicious<'a, I, A, O, H, R>(&'a self, input: I, helper_fn: H) -> [O; 3] where - I: IntoShares + Send + 'static, + I: RunnerInput, A: Send, O: Send + Debug, - H: Fn(MaliciousContext<'a>, A) -> R + Send + Sync, + H: Fn(Self::MaliciousContext<'a>, A) -> R + Send + Sync, R: Future + Send, { ShardWorld::::run_either( @@ -778,9 +808,14 @@ impl ShardWorld { /// # Panics /// Panics if world has more or less than 3 gateways/participants #[must_use] - pub fn malicious_contexts(&self, gate: &Gate) -> [MaliciousContext<'_>; 3] { + pub fn malicious_contexts(&self, gate: &Gate) -> [MaliciousContext<'_, B>; 3] { zip3_ref(&self.participants, &self.gateways).map(|(participant, gateway)| { - MaliciousContext::new_with_gate(participant, gateway, gate.clone(), NotSharded) + MaliciousContext::new_with_gate( + participant, + gateway, + gate.clone(), + self.shard_info.clone(), + ) }) } } @@ -816,7 +851,8 @@ impl Distribute for Random { } } -#[cfg(all(test, unit_test))] +// #[cfg(all(test, unit_test))] +#[cfg(test)] mod tests { use std::{ collections::{HashMap, HashSet}, @@ -826,12 +862,20 @@ mod tests { use futures_util::future::try_join4; use crate::{ - ff::{boolean_array::BA3, Field, Fp31, U128Conversions}, + ff::{boolean::Boolean, boolean_array::BA3, Field, Fp31, U128Conversions}, helpers::{ 
in_memory_config::{MaliciousHelper, MaliciousHelperContext}, - Direction, Role, + Direction, Role, TotalRecords, + }, + protocol::{ + basics::SecureMul, + context::{ + dzkp_validator::DZKPValidator, upgrade::Upgradable, Context, UpgradableContext, + UpgradedContext, Validator, TEST_DZKP_STEPS, + }, + prss::SharedRandomness, + RecordId, }, - protocol::{context::Context, prss::SharedRandomness, RecordId}, secret_sharing::{ replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing}, SharedValue, @@ -961,4 +1005,65 @@ mod tests { assert_eq!(shares[1].right(), shares[2].left()); }); } + + #[test] + fn zkp_malicious_sharded() { + run(|| async { + let world: TestWorld> = + TestWorld::with_shards(TestWorldConfig::default()); + let input = vec![Boolean::truncate_from(0_u32), Boolean::truncate_from(1_u32)]; + let r = world + .malicious(input.clone().into_iter(), |ctx, input| async move { + assert_eq!(1, input.iter().len()); + let ctx = ctx.set_total_records(TotalRecords::ONE); + let validator = ctx.dzkp_validator(TEST_DZKP_STEPS, 1); + let ctx = validator.context(); + let r = input[0] + .multiply(&input[0], ctx, RecordId::FIRST) + .await + .unwrap(); + validator.validate().await.unwrap(); + + vec![r] + }) + .await + .into_iter() + .flat_map(|v| v.reconstruct()) + .collect::>(); + + assert_eq!(input, r); + }); + } + + #[test] + fn mac_malicious_sharded() { + run(|| async { + let world: TestWorld> = + TestWorld::with_shards(TestWorldConfig::default()); + let input = vec![Fp31::truncate_from(0_u32), Fp31::truncate_from(1_u32)]; + let r = world + .malicious(input.clone().into_iter(), |ctx, input| async move { + assert_eq!(1, input.iter().len()); + let validator = ctx.set_total_records(1).validator(); + let ctx = validator.context(); + let (a_upgraded, b_upgraded) = (input[0].clone(), input[0].clone()) + .upgrade(ctx.clone(), RecordId::FIRST) + .await + .unwrap(); + let _ = a_upgraded + .multiply(&b_upgraded, ctx.narrow("multiply"), RecordId::FIRST) + .await + .unwrap(); + ctx.validate_record(RecordId::FIRST).await.unwrap(); + + input + }) + .await + .into_iter() + .flat_map(|v| v.reconstruct()) + .collect::>(); + + assert_eq!(input, r); + }); + } } From 824ee6dff8989f9d0210c79c3643005c0110384a Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Thu, 3 Oct 2024 16:51:29 -0700 Subject: [PATCH 110/191] Fix compact gate tests --- ipa-core/src/test_fixture/world.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ipa-core/src/test_fixture/world.rs b/ipa-core/src/test_fixture/world.rs index b54d76abb..b2ea2759a 100644 --- a/ipa-core/src/test_fixture/world.rs +++ b/ipa-core/src/test_fixture/world.rs @@ -851,8 +851,7 @@ impl Distribute for Random { } } -// #[cfg(all(test, unit_test))] -#[cfg(test)] +#[cfg(all(test, unit_test))] mod tests { use std::{ collections::{HashMap, HashSet}, From 27de8074345dabfbb4d84980b0a0baf676b17c8d Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Thu, 3 Oct 2024 17:04:47 -0700 Subject: [PATCH 111/191] Skip the unverified multiplies check if we're already unwinding (#1334) --- .../src/protocol/context/dzkp_validator.rs | 51 ++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/ipa-core/src/protocol/context/dzkp_validator.rs b/ipa-core/src/protocol/context/dzkp_validator.rs index db34399b0..f40a7d805 100644 --- a/ipa-core/src/protocol/context/dzkp_validator.rs +++ b/ipa-core/src/protocol/context/dzkp_validator.rs @@ -810,7 +810,15 @@ impl<'a, B: ShardBinding> MaliciousDZKPValidator<'a, B> { impl<'a, B: ShardBinding> Drop for 
MaliciousDZKPValidator<'a, B> { fn drop(&mut self) { - if self.inner_ref.is_some() { + // If `validate` has not been called, and we are not unwinding, check that the + // validator is not holding unverified multiplies. + // * If `validate` has been called (i.e. the validator was used in the + // non-`validate_record` mode of operation), then `self.inner_ref` is `None`, + // because validation consumed the batcher via `self.inner_ref`. + // * Unwinding can happen at any time, so complaining about incomplete + // validation is likely just extra noise, and the additional panic + // during unwinding could be confusing. + if self.inner_ref.is_some() && !std::thread::panicking() { self.is_verified().unwrap(); } } @@ -1249,6 +1257,47 @@ mod tests { } } + #[tokio::test] + #[should_panic(expected = "ContextUnsafe(\"DZKPMaliciousContext\")")] + async fn missing_validate() { + let mut rng = thread_rng(); + + let a = rng.gen::(); + let b = rng.gen::(); + + TestWorld::default() + .malicious((a, b), |ctx, (a, b)| async move { + let v = ctx.dzkp_validator(TEST_DZKP_STEPS, 1); + let m_ctx = v.context().set_total_records(1); + + a.multiply(&b, m_ctx, RecordId::FIRST).await.unwrap() + + // `validate` should appear here. + }) + .await; + } + + #[tokio::test] + #[should_panic(expected = "panicking before validate")] + #[allow(unreachable_code)] + async fn missing_validate_panic() { + let mut rng = thread_rng(); + + let a = rng.gen::(); + let b = rng.gen::(); + + TestWorld::default() + .malicious((a, b), |ctx, (a, b)| async move { + let v = ctx.dzkp_validator(TEST_DZKP_STEPS, 1); + let m_ctx = v.context().set_total_records(1); + + let _result = a.multiply(&b, m_ctx, RecordId::FIRST).await.unwrap(); + + panic!("panicking before validate"); + }) + .await; + } + #[test] fn batch_allocation_small() { const SIZE: usize = 1; From 52fd07704335fe6f3ca4eb01f5022df39b235024 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Fri, 4 Oct 2024 13:50:36 -0700 Subject: [PATCH 112/191] Fix shuttle tests --- ipa-core/src/helpers/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ipa-core/src/helpers/mod.rs b/ipa-core/src/helpers/mod.rs index e33a2ec99..2c43ccd53 100644 --- a/ipa-core/src/helpers/mod.rs +++ b/ipa-core/src/helpers/mod.rs @@ -817,7 +817,7 @@ mod concurrency_tests { let input = (0u32..11).map(TestField::truncate_from).collect::>(); let config = TestWorldConfig { gateway_config: GatewayConfig { - active: input.len().try_into().unwrap(), + active: input.len().next_power_of_two().try_into().unwrap(), ..Default::default() }, ..Default::default() @@ -875,7 +875,7 @@ mod concurrency_tests { let input = (0u32..11).map(TestField::truncate_from).collect::>(); let config = TestWorldConfig { gateway_config: GatewayConfig { - active: input.len().try_into().unwrap(), + active: input.len().next_power_of_two().try_into().unwrap(), ..Default::default() }, ..Default::default() From 279c3063dfe17e7436055fde9bc749eb78a0567e Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Fri, 4 Oct 2024 13:52:45 -0700 Subject: [PATCH 113/191] Fix flaky `send_recv_randomized` test When using more than one thread, this test was failing if futures were scheduled out of order, because `Notify` couldn't wake up futures scheduled after `notify_all` call. 
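A minimal sketch of the hazard, outside this patch (`notify_all` above refers to Tokio's
`Notify::notify_waiters`, which wakes only tasks already parked in `notified().await` and
stores no permit for tasks that arrive later):

```rust
use std::sync::Arc;
use tokio::sync::Notify;

#[tokio::main(flavor = "multi_thread")]
async fn main() {
    let notify = Arc::new(Notify::new());
    let waiter = tokio::spawn({
        let notify = Arc::clone(&notify);
        async move {
            // If this task is only scheduled after notify_waiters() below has
            // already run, no wakeup is pending and it waits forever.
            notify.notified().await;
        }
    });
    notify.notify_waiters();
    waiter.await.unwrap(); // can hang, depending on scheduling order
}
```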
Using barriers solves the issue --- ipa-core/src/helpers/gateway/mod.rs | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/ipa-core/src/helpers/gateway/mod.rs b/ipa-core/src/helpers/gateway/mod.rs index 15d2580d2..42982abc0 100644 --- a/ipa-core/src/helpers/gateway/mod.rs +++ b/ipa-core/src/helpers/gateway/mod.rs @@ -318,6 +318,7 @@ mod tests { stream::StreamExt, }; use proptest::proptest; + use tokio::sync::Barrier; use crate::{ ff::{ @@ -751,7 +752,11 @@ mod tests { ) where M: MpcMessage + Clone + PartialEq, { - let send_notify = Arc::new(tokio::sync::Notify::new()); + let last_batch_size = total_records % active_work; + let last_batch = total_records / active_work; + + let barrier = Arc::new(Barrier::new(active_work)); + let last_batch_barrier = Arc::new(Barrier::new(last_batch_size)); // perform "multiplication-like" operation (send + subsequent receive) // and "validate": block the future until we have at least `active_work` @@ -762,7 +767,8 @@ mod tests { |(record_id, msg)| { let send_channel = &send_channel; let recv_channel = &recv_channel; - let send_notify = Arc::clone(&send_notify); + let barrier = Arc::clone(&barrier); + let last_batch_barrier = Arc::clone(&last_batch_barrier); async move { send_channel .send(record_id.into(), msg.clone()) @@ -771,13 +777,12 @@ mod tests { let r = recv_channel.receive(record_id.into()).await.unwrap(); // this simulates validate_record API by forcing futures to wait // until the entire batch is validated by the last future in that batch - if record_id % active_work == active_work - 1 - || record_id == total_records - 1 - { - send_notify.notify_waiters(); + if record_id >= last_batch * active_work { + last_batch_barrier.wait().await; } else { - send_notify.notified().await; + barrier.wait().await; } + assert_eq!(msg, r); } }, From 5e228244f151492f0f03c5b167ad146eaf680a11 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Fri, 4 Oct 2024 14:54:27 -0700 Subject: [PATCH 114/191] Change how we compute the previous power of two. Using bitshift turns out to be much easier to understand --- ipa-core/src/helpers/gateway/send.rs | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/ipa-core/src/helpers/gateway/send.rs b/ipa-core/src/helpers/gateway/send.rs index b1dc11155..9e0753ab7 100644 --- a/ipa-core/src/helpers/gateway/send.rs +++ b/ipa-core/src/helpers/gateway/send.rs @@ -259,17 +259,12 @@ impl SendChannelConfig { let total_capacity = gateway_config.active.get() * record_size; // define read size as a multiplier of record size. The multiplier must be - // a power of two to align perfectly with total capacity. - let read_size_multiplier = { - // next_power_of_two returns a value that is greater than or equal to. - // That is not what we want here: if read_size / record_size is a power - // of two, then subsequent division will get us further away from desired target. - // For example: if read_size / record_size = 4, then prev_power_of_two = 2. - // In such cases, we want to stay where we are, so we add +1 for that. - let prev_power_of_two = - (gateway_config.read_size.get() / record_size + 1).next_power_of_two() / 2; - std::cmp::max(1, prev_power_of_two) - }; + // a power of two to align perfectly with total capacity. We don't want to exceed + // the target read size, so we pick a power of two <= read_size. + let read_size_multiplier = + // this computes the highest power of 2 less than or equal to read_size / record_size. 
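+            // For example, read_size = 4096 and record_size = 14 give 4096 / 14 = 292;
+            // the greatest power of two at most 292 is 2^8 = 256, so each chunk
+            // carries 256 * 14 = 3584 bytes.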
+            // Note that if record_size is greater than read_size, we round it to 1
+            1 << (std::cmp::max(1, usize::BITS - (gateway_config.read_size.get() / record_size).leading_zeros()) - 1);

         let this = Self {
             total_capacity: total_capacity.try_into().unwrap(),

From 4aeb686c3ba1df34da30b4e77b16df4eb16b7070 Mon Sep 17 00:00:00 2001
From: Alex Koshelev
Date: Fri, 4 Oct 2024 14:54:52 -0700
Subject: [PATCH 115/191] Change comment on `read_size`

---
 ipa-core/src/helpers/gateway/mod.rs | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/ipa-core/src/helpers/gateway/mod.rs b/ipa-core/src/helpers/gateway/mod.rs
index 42982abc0..2d48380db 100644
--- a/ipa-core/src/helpers/gateway/mod.rs
+++ b/ipa-core/src/helpers/gateway/mod.rs
@@ -82,13 +82,20 @@ pub struct GatewayConfig {
     /// A rule of thumb is that this should get as close to network packet size as possible.
     ///
     /// This will be set for all channels and because they send records of different size, the actual
-    /// payload may not be exactly this, but it will be the closest multiple of record size to this
-    /// number. For instance, having 14 bytes records and batch size of 4096 will result in
-    /// 4088 bytes being sent in a batch.
+    /// payload may not be exactly this, but it will be the closest multiple of record size smaller than
+    /// or equal to this number. For alignment reasons, this multiple will be a power of two, otherwise
+    /// a deadlock is possible. See ipa/#1300 for details on how it can happen.
     ///
-    /// The actual size for read chunks may be bigger or smaller, depending on the record size
-    /// sent through each channel. Read size will be aligned with [`Self::active_work`] value to
-    /// prevent deadlocks.
+    /// For instance, having 14 bytes records and batch size of 4096 will result in
+    /// 3584 bytes being sent in a batch (`2^8 * 14 < 4096, 2^9 * 14 > 4096`).
+    ///
+    /// The consequence is that HTTP buffer size may not be perfectly aligned with the target.
+    /// As long as we use TCP it does not matter, but if we want to switch to UDP and have
+    /// precise control over the size of each chunk sent, we should tune the buffer size at the
+    /// HTTP layer instead (using Hyper/H3 API or something like that). If we do this, then
+    /// read size becomes obsolete and should be removed in favor of flushing entire
+    /// buffer chunks from the application layer down to HTTP, letting the network figure out
+    /// the best way to slice this data before sending it to a peer.
     pub read_size: NonZeroUsize,

     /// Time to wait before checking gateway progress. If no progress has been made between

From ac31cbbfa612e3dbeab23cba8d7093ee46d76ac5 Mon Sep 17 00:00:00 2001
From: Alex Koshelev
Date: Fri, 4 Oct 2024 14:55:21 -0700
Subject: [PATCH 116/191] Don't run send config tests under Shuttle

There is no reason to do that.
---
 ipa-core/src/helpers/gateway/send.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ipa-core/src/helpers/gateway/send.rs b/ipa-core/src/helpers/gateway/send.rs
index 9e0753ab7..3a0fc1ad8 100644
--- a/ipa-core/src/helpers/gateway/send.rs
+++ b/ipa-core/src/helpers/gateway/send.rs
@@ -286,7 +286,7 @@ impl SendChannelConfig {
     }
 }

-#[cfg(test)]
+#[cfg(all(test, unit_test))]
 mod test {
     use std::num::NonZeroUsize;

From 34838feb2873400f01c65e1d90c464b7581e5147 Mon Sep 17 00:00:00 2001
From: Alex Koshelev
Date: Fri, 4 Oct 2024 15:11:04 -0700
Subject: [PATCH 117/191] Don't run non zero power of two tests under Shuttle

There is no reason to do that.
---
 ipa-core/src/utils/power_of_two.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ipa-core/src/utils/power_of_two.rs b/ipa-core/src/utils/power_of_two.rs
index abce8055e..a84455c92 100644
--- a/ipa-core/src/utils/power_of_two.rs
+++ b/ipa-core/src/utils/power_of_two.rs
@@ -68,7 +68,7 @@ impl NonZeroU32PowerOfTwo {
     }
 }

-#[cfg(test)]
+#[cfg(all(test, unit_test))]
 mod tests {
     use super::{ConvertError, NonZeroU32PowerOfTwo};

From 230920f30a1366da754564edceb5f93fe6a35ac8 Mon Sep 17 00:00:00 2001
From: Alex Koshelev
Date: Sat, 5 Oct 2024 00:15:48 -0700
Subject: [PATCH 118/191] Introduce IpaRuntime and plumb it all the way down to
 executor

The time has come to take precise control over which runtime is used to run
queries. The reason is that we hit another occurrence of an HTTP keep-alive
timeout that coincides with OPRF computation.

```
2024-10-05T03:57:35.203964 - 2024-10-05T03:57:35.203922Z INFO oprf_ipa_query{sz=50000000}:compute_prf_for_inputs: ipa_core::protocol::context::batcher: batch 30 is ready for validation
2024-10-05T03:57:35.206943 - 2024-10-05T03:57:35.206901Z INFO oprf_ipa_query{sz=50000000}:compute_prf_for_inputs: ipa_core::protocol::context::batcher: batch 31 is ready for validation
2024-10-05T03:57:35.209813 - 2024-10-05T03:57:35.209777Z INFO oprf_ipa_query{sz=50000000}:compute_prf_for_inputs: ipa_core::protocol::context::batcher: batch 32 is ready for validation
2024-10-05T03:58:28.720077 - 2024-10-05T03:58:28.715520Z ERROR ipa_core::error: ThreadId(9) "tokio-runtime-worker" panicked at ipa-core/src/helpers/gateway/send.rs:222:30:
2024-10-05T03:58:28.720653 - {channel_id:?} receiving end should be accepted by transport: SendToRoleError(H1, ConnectError { dest: "helper1.ipa-helper.dev", inner: hyper_util::client::legacy::Error(SendRequest,>
2024-10-05T03:58:28.720957 - stack trace:
2024-10-05T03:58:28.721250 - 0: ipa_core::error::set_global_panic_hook::{{closure}}
2024-10-05T03:58:28.721524 - 1: std::panicking::rust_panic_with_hook
2024-10-05T03:58:28.721776 - 2: std::panicking::begin_panic_handler::{{closure}}
2024-10-05T03:58:28.722034 - 3: std::sys_common::backtrace::__rust_end_short_backtrace
2024-10-05T03:58:28.722281 - 4: rust_begin_unwind
2024-10-05T03:58:28.722522 - 5: core::panicking::panic_fmt
2024-10-05T03:58:28.722764 - 6: core::result::unwrap_failed
2024-10-05T03:58:28.723006 - 7: ipa_core::helpers::gateway::send::GatewaySenders::get::{{closure}}
2024-10-05T03:58:28.723246 - 8: tokio::runtime::task::core::Core::poll
2024-10-05T03:58:28.723490 - 9: tokio::runtime::task::harness::Harness::poll
2024-10-05T03:58:28.723749 - 10: tokio::runtime::scheduler::multi_thread::worker::Context::run_task
2024-10-05T03:58:28.723986 - 11: tokio::runtime::scheduler::multi_thread::worker::Context::run
2024-10-05T03:58:28.724235 - 12: tokio::runtime::context::runtime::enter_runtime
2024-10-05T03:58:28.724484 - 13: tokio::runtime::scheduler::multi_thread::worker::run
2024-10-05T03:58:28.724732 - 14: tokio::runtime::task::core::Core::poll
2024-10-05T03:58:28.724974 - 15: tokio::runtime::task::harness::Harness::poll
2024-10-05T03:58:28.725221 - 16: tokio::runtime::blocking::pool::Inner::run
2024-10-05T03:58:28.725462 - 17: std::sys_common::backtrace::__rust_begin_short_backtrace
2024-10-05T03:58:28.725705 - 18: core::ops::function::FnOnce::call_once{{vtable.shim}}
2024-10-05T03:58:28.725946 - 19: std::sys::pal::unix::thread::Thread::new::thread_start
2024-10-05T03:58:28.726193 - 20: start_thread
2024-10-05T03:58:28.726422 - 21: __clone3
```

The
root cause for this is PRF computation blocking scheduler for too long, so it does not schedule Hyper task to respond to `status` requests from RC, or to accept data from another peer. While it deserves to be fixed (I believe @danielmasny was looking into why we trash CPU so badly in PRF), it is not OK to crash if that happens. This change just does the plumbing to allow dedicated runtime to be provided for query executors. --- ipa-core/src/app.rs | 10 +++- ipa-core/src/lib.rs | 82 +++++++++++++++++++++++++++++++++ ipa-core/src/query/executor.rs | 69 ++++++++++++++++++--------- ipa-core/src/query/processor.rs | 21 ++++++--- ipa-core/src/query/state.rs | 4 +- 5 files changed, 154 insertions(+), 32 deletions(-) diff --git a/ipa-core/src/app.rs b/ipa-core/src/app.rs index da56e67e3..8375c0cbd 100644 --- a/ipa-core/src/app.rs +++ b/ipa-core/src/app.rs @@ -3,6 +3,7 @@ use std::{num::NonZeroUsize, sync::Weak}; use async_trait::async_trait; use crate::{ + executor::IpaRuntime, helpers::{ query::{PrepareQuery, QueryConfig, QueryInput}, routing::{Addr, RouteId}, @@ -19,6 +20,7 @@ use crate::{ pub struct AppConfig { active_work: Option, key_registry: Option>, + runtime: IpaRuntime, } impl AppConfig { @@ -33,6 +35,12 @@ impl AppConfig { self.key_registry = Some(key_registry); self } + + #[must_use] + pub fn with_runtime(mut self, runtime: IpaRuntime) -> Self { + self.runtime = runtime; + self + } } pub struct Setup { @@ -60,7 +68,7 @@ impl Setup { #[must_use] pub fn new(config: AppConfig) -> (Self, HandlerRef) { let key_registry = config.key_registry.unwrap_or_else(KeyRegistry::empty); - let query_processor = QueryProcessor::new(key_registry, config.active_work); + let query_processor = QueryProcessor::new(key_registry, config.active_work, config.runtime); let handler = HandlerBox::empty(); let this = Self { query_processor, diff --git a/ipa-core/src/lib.rs b/ipa-core/src/lib.rs index 59cae0106..35bd1272f 100644 --- a/ipa-core/src/lib.rs +++ b/ipa-core/src/lib.rs @@ -94,9 +94,91 @@ pub(crate) mod shim { #[cfg(not(all(feature = "shuttle", test)))] pub(crate) mod task { + #[allow(unused_imports)] pub use tokio::task::{JoinError, JoinHandle}; } +#[cfg(not(feature = "shuttle"))] +pub mod executor { + use std::future::Future; + + use tokio::{runtime::Handle, task::JoinHandle}; + + /// In prod we use Tokio scheduler, so this struct just wraps + /// its runtime handle and mimics the standard executor API. 
+ /// The name was chosen to avoid clashes with tokio runtime + /// when importing it + #[derive(Clone)] + pub struct IpaRuntime(Handle); + + /// Wrapper around Tokio's [`JoinHandle`] + pub struct IpaJoinHandle(JoinHandle); + + impl Default for IpaRuntime { + fn default() -> Self { + Self::current() + } + } + + impl IpaRuntime { + #[must_use] + pub fn current() -> Self { + Self(Handle::current()) + } + + #[must_use] + pub fn spawn(&self, future: F) -> IpaJoinHandle + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + IpaJoinHandle(self.0.spawn(future)) + } + } + + impl IpaJoinHandle { + pub fn abort(self) { + self.0.abort(); + } + } +} + +#[cfg(feature = "shuttle")] +pub(crate) mod executor { + use std::future::Future; + + use shuttle_crate::future::{spawn, JoinHandle}; + + /// Shuttle does not support more than one runtime + /// so we always use its default + #[derive(Clone, Default)] + pub struct IpaRuntime; + pub struct IpaJoinHandle(JoinHandle); + + impl IpaRuntime { + #[must_use] + pub fn current() -> Self { + Self + } + + #[must_use] + #[allow(clippy::unused_self)] // to conform with runtime API + pub fn spawn(&self, future: F) -> IpaJoinHandle + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + IpaJoinHandle(spawn(future)) + } + } + + impl IpaJoinHandle { + pub fn abort(self) { + self.0.abort(); + } + } +} + #[cfg(all(feature = "shuttle", test))] pub(crate) mod test_executor { use std::future::Future; diff --git a/ipa-core/src/query/executor.rs b/ipa-core/src/query/executor.rs index b3e197e4d..a6b1dae74 100644 --- a/ipa-core/src/query/executor.rs +++ b/ipa-core/src/query/executor.rs @@ -15,8 +15,6 @@ use generic_array::GenericArray; use ipa_step::StepNarrow; use rand::rngs::StdRng; use rand_core::SeedableRng; -#[cfg(all(feature = "shuttle", test))] -use shuttle::future as tokio; use typenum::Unsigned; #[cfg(any( @@ -26,11 +24,8 @@ use typenum::Unsigned; feature = "weak-field" ))] use crate::ff::FieldType; -#[cfg(any(test, feature = "cli", feature = "test-fixture"))] -use crate::{ - ff::Fp32BitPrime, query::runner::execute_test_multiply, query::runner::test_add_in_prime_field, -}; use crate::{ + executor::IpaRuntime, ff::{boolean_array::BA32, Serializable}, helpers::{ negotiate_prss, @@ -49,6 +44,10 @@ use crate::{ }, sync::Arc, }; +#[cfg(any(test, feature = "cli", feature = "test-fixture"))] +use crate::{ + ff::Fp32BitPrime, query::runner::execute_test_multiply, query::runner::test_add_in_prime_field, +}; pub trait Result: Send + Debug { fn to_bytes(&self) -> Vec; @@ -74,6 +73,7 @@ where /// Needless pass by value because IPA v3 does not make use of key registry yet. 
#[allow(clippy::too_many_lines, clippy::needless_pass_by_value)] pub fn execute( + runtime: &IpaRuntime, config: QueryConfig, key_registry: Arc, gateway: Gateway, @@ -81,45 +81,63 @@ pub fn execute( ) -> RunningQuery { match (config.query_type, config.field_type) { #[cfg(any(test, feature = "weak-field"))] - (QueryType::TestMultiply, FieldType::Fp31) => { - do_query(config, gateway, input, |prss, gateway, _config, input| { + (QueryType::TestMultiply, FieldType::Fp31) => do_query( + runtime, + config, + gateway, + input, + |prss, gateway, _config, input| { Box::pin(execute_test_multiply::( prss, gateway, input, )) - }) - } + }, + ), #[cfg(any(test, feature = "cli", feature = "test-fixture"))] - (QueryType::TestMultiply, FieldType::Fp32BitPrime) => { - do_query(config, gateway, input, |prss, gateway, _config, input| { + (QueryType::TestMultiply, FieldType::Fp32BitPrime) => do_query( + runtime, + config, + gateway, + input, + |prss, gateway, _config, input| { Box::pin(execute_test_multiply::(prss, gateway, input)) - }) - } + }, + ), #[cfg(any(test, feature = "cli", feature = "test-fixture"))] (QueryType::TestShardedShuffle, _) => do_query( + runtime, config, gateway, input, |_prss, _gateway, _config, _input| unimplemented!(), ), #[cfg(any(test, feature = "weak-field"))] - (QueryType::TestAddInPrimeField, FieldType::Fp31) => { - do_query(config, gateway, input, |prss, gateway, _config, input| { + (QueryType::TestAddInPrimeField, FieldType::Fp31) => do_query( + runtime, + config, + gateway, + input, + |prss, gateway, _config, input| { Box::pin(test_add_in_prime_field::( prss, gateway, input, )) - }) - } + }, + ), #[cfg(any(test, feature = "cli", feature = "test-fixture"))] - (QueryType::TestAddInPrimeField, FieldType::Fp32BitPrime) => { - do_query(config, gateway, input, |prss, gateway, _config, input| { + (QueryType::TestAddInPrimeField, FieldType::Fp32BitPrime) => do_query( + runtime, + config, + gateway, + input, + |prss, gateway, _config, input| { Box::pin(test_add_in_prime_field::( prss, gateway, input, )) - }) - } + }, + ), // TODO(953): This is really using BA32, not Fp32bitPrime. The `FieldType` mechanism needs // to be reworked. 
(QueryType::SemiHonestOprfIpa(ipa_config), _) => do_query( + runtime, config, gateway, input, @@ -133,6 +151,7 @@ pub fn execute( }, ), (QueryType::MaliciousOprfIpa(ipa_config), _) => do_query( + runtime, config, gateway, input, @@ -146,6 +165,7 @@ pub fn execute( }, ), (QueryType::SemiHonestHybrid(query_params), _) => do_query( + runtime, config, gateway, input, @@ -162,6 +182,7 @@ pub fn execute( } pub fn do_query( + executor_handle: &IpaRuntime, config: QueryConfig, gateway: B, input_stream: BodyStream, @@ -180,7 +201,7 @@ where { let (tx, rx) = oneshot::channel(); - let join_handle = tokio::spawn(async move { + let join_handle = executor_handle.spawn(async move { let gateway = gateway.borrow(); // TODO: make it a generic argument for this function let mut rng = StdRng::from_entropy(); @@ -232,6 +253,7 @@ mod tests { use tokio::sync::Barrier; use crate::{ + executor::IpaRuntime, ff::{FieldType, Fp31, U128Conversions}, helpers::{ query::{QueryConfig, QueryType}, @@ -352,6 +374,7 @@ mod tests { Fut: Future + Send, { do_query( + &IpaRuntime::current(), QueryConfig { size: 1.try_into().unwrap(), field_type: FieldType::Fp31, diff --git a/ipa-core/src/query/processor.rs b/ipa-core/src/query/processor.rs index a8694012e..ffde7699f 100644 --- a/ipa-core/src/query/processor.rs +++ b/ipa-core/src/query/processor.rs @@ -9,6 +9,7 @@ use serde::Serialize; use crate::{ error::Error as ProtocolError, + executor::IpaRuntime, helpers::{ query::{PrepareQuery, QueryConfig, QueryInput}, Gateway, GatewayConfig, MpcTransportError, MpcTransportImpl, Role, RoleAssignment, @@ -45,6 +46,7 @@ pub struct Processor { queries: RunningQueries, key_registry: Arc>, active_work: Option, + runtime: IpaRuntime, } impl Default for Processor { @@ -53,6 +55,7 @@ impl Default for Processor { queries: RunningQueries::default(), key_registry: Arc::new(KeyRegistry::::empty()), active_work: None, + runtime: IpaRuntime::current(), } } } @@ -119,11 +122,13 @@ impl Processor { pub fn new( key_registry: KeyRegistry, active_work: Option, + runtime: IpaRuntime, ) -> Self { Self { queries: RunningQueries::default(), key_registry: Arc::new(key_registry), active_work, + runtime, } } @@ -249,6 +254,7 @@ impl Processor { queries.insert( input.query_id, QueryState::Running(executor::execute( + &self.runtime, config, Arc::clone(&self.key_registry), gateway, @@ -584,6 +590,7 @@ mod tests { use std::sync::Arc; use crate::{ + executor::IpaRuntime, ff::FieldType, helpers::{ query::{ @@ -603,11 +610,13 @@ mod tests { #[test] fn non_existent_query() { - let processor = Processor::default(); - assert!(matches!( - processor.kill(QueryId), - Err(QueryKillStatus::NoSuchQuery(QueryId)) - )); + run(|| async { + let processor = Processor::default(); + assert!(matches!( + processor.kill(QueryId), + Err(QueryKillStatus::NoSuchQuery(QueryId)) + )); + }); } #[test] @@ -650,7 +659,7 @@ mod tests { let processor = Processor::default(); let (_tx, rx) = tokio::sync::oneshot::channel(); let counter = Arc::new(1); - let task = tokio::spawn({ + let task = IpaRuntime::current().spawn({ let counter = Arc::clone(&counter); async move { loop { diff --git a/ipa-core/src/query/state.rs b/ipa-core/src/query/state.rs index 3c4359ca9..9d42a0439 100644 --- a/ipa-core/src/query/state.rs +++ b/ipa-core/src/query/state.rs @@ -10,11 +10,11 @@ use futures::{ready, FutureExt}; use serde::{Deserialize, Serialize}; use crate::{ + executor::IpaJoinHandle, helpers::{query::QueryConfig, RoleAssignment}, protocol::QueryId, query::runner::QueryResult, sync::Mutex, - task::JoinHandle, }; 
/// The status of query processing @@ -87,7 +87,7 @@ pub struct RunningQuery { /// /// We could return the result via the `JoinHandle`, except that we want to check the status /// of the task, and shuttle doesn't implement `JoinHandle::is_finished`. - pub join_handle: JoinHandle<()>, + pub join_handle: IpaJoinHandle<()>, } impl RunningQuery { From bdc5eb5e8d6a67ea4c8b4c73ff3c37dc04df12f7 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Mon, 7 Oct 2024 10:48:33 -0700 Subject: [PATCH 119/191] Follow up on feedback from #1332 --- ipa-core/src/helpers/gateway/send.rs | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/ipa-core/src/helpers/gateway/send.rs b/ipa-core/src/helpers/gateway/send.rs index 3a0fc1ad8..089876243 100644 --- a/ipa-core/src/helpers/gateway/send.rs +++ b/ipa-core/src/helpers/gateway/send.rs @@ -255,16 +255,27 @@ impl SendChannelConfig { total_records: TotalRecords, record_size: usize, ) -> Self { + // this computes the greatest positive power of 2 that is + // less than or equal to target. + fn non_zero_prev_power_of_two(target: usize) -> usize { + let bits = usize::BITS - target.leading_zeros(); + + 1 << (std::cmp::max(1, bits) - 1) + } + assert!(record_size > 0, "Message size cannot be 0"); let total_capacity = gateway_config.active.get() * record_size; // define read size as a multiplier of record size. The multiplier must be // a power of two to align perfectly with total capacity. We don't want to exceed - // the target read size, so we pick a power of two <= read_size. - let read_size_multiplier = - // this computes the highest power of 2 less than or equal to read_size / record_size. - // Note, that if record_size is greater than read_size, we round it to 1 - 1 << (std::cmp::max(1, usize::BITS - (gateway_config.read_size.get() / record_size).leading_zeros()) - 1); + // the target read size, so multiplier * record_size <= read_size. We want to get + // as close as possible to read_size. + let read_size_multiplier = { + let target = gateway_config.read_size.get() / record_size; + // If record_size is greater than read_size, we set the multiplier to 1 + // as read size cannot be 0. + non_zero_prev_power_of_two(target) + }; let this = Self { total_capacity: total_capacity.try_into().unwrap(), @@ -279,7 +290,11 @@ impl SendChannelConfig { total_records, }; + // If capacity can't fit all active work items, the protocol deadlocks on + // inserts above the total capacity. 
assert!(this.total_capacity.get() >= record_size * gateway_config.active.get()); + // if capacity is not aligned with read size, we can get a deadlock + // described in ipa/1300 assert_eq!(0, this.total_capacity.get() % this.read_size.get()); this From 357ecde0252d8041d393c4238052bc2f19264762 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Tue, 8 Oct 2024 18:03:13 -0700 Subject: [PATCH 120/191] Exclude all packages from sanitizers except ipa_core --- .github/workflows/check.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index 4dfdbdf2d..9a7735e9e 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -201,7 +201,7 @@ jobs: - name: Add Rust sources run: rustup component add rust-src - name: Run tests with sanitizer - run: RUSTFLAGS="-Z sanitizer=${{ matrix.sanitizer }} -Z sanitizer-memory-track-origins" cargo test -Z build-std --target $TARGET --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate ${{ matrix.features }}" + run: RUSTFLAGS="-Z sanitizer=${{ matrix.sanitizer }} -Z sanitizer-memory-track-origins" cargo test -Z build-std -p ipa-core --target $TARGET --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate ${{ matrix.features }}" miri: runs-on: ubuntu-latest From 0194eb89ba3e58ae24dd4c70a45ef6a6b4096ad2 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Wed, 9 Oct 2024 13:55:53 -0700 Subject: [PATCH 121/191] Make reshard work with streams too Internally, reshard used streams already, so it is only a matter of changing the API and connecting things together --- ipa-core/src/protocol/context/mod.rs | 65 ++++++++++++++++--- .../src/protocol/ipa_prf/shuffle/sharded.rs | 4 +- 2 files changed, 57 insertions(+), 12 deletions(-) diff --git a/ipa-core/src/protocol/context/mod.rs b/ipa-core/src/protocol/context/mod.rs index 3463e85db..4e9c38e39 100644 --- a/ipa-core/src/protocol/context/mod.rs +++ b/ipa-core/src/protocol/context/mod.rs @@ -35,8 +35,8 @@ pub(crate) use malicious::TEST_DZKP_STEPS; use crate::{ error::Error, helpers::{ - ChannelId, Direction, Gateway, Message, MpcMessage, MpcReceivingEnd, Role, SendingEnd, - ShardReceivingEnd, TotalRecords, + stream::ExactSizeStream, ChannelId, Direction, Gateway, Message, MpcMessage, + MpcReceivingEnd, Role, SendingEnd, ShardReceivingEnd, TotalRecords, }, protocol::{ context::dzkp_validator::DZKPValidator, @@ -374,24 +374,44 @@ impl<'a> Inner<'a> { /// /// [`calculations`]: https://docs.google.com/document/d/1vej6tYgNV3GWcldD4tl7a4Z9EeZwda3F5u7roPGArlU/ /// +/// ## Stream size +/// Note that it currently works for streams where size is known in advance. Mainly because +/// we want to set up send buffer sizes and avoid sending records one-by-one to each shard. +/// Other than that, there are no technical limitation here, and it could be possible to make it +/// work with regular streams if the batching problem is somehow addressed. 
+/// +/// +/// ```compile_fail +/// use futures::stream::{self, StreamExt}; +/// use ipa_core::protocol::context::reshard_stream; +/// use ipa_core::ff::boolean::Boolean; +/// use ipa_core::secret_sharing::SharedValue; +/// async { +/// let a = [Boolean::ZERO]; +/// let mut s = stream::iter(a.into_iter()).cycle(); +/// // this should fail to compile: +/// // the trait bound `futures::stream::Cycle<...>: ExactSizeStream` is not satisfied +/// reshard_stream(todo!(), s, todo!()).await; +/// }; +/// ``` +/// /// ## Panics /// When `shard_picker` returns an out-of-bounds index. /// /// ## Errors /// If cross-shard communication fails -pub async fn reshard( +/// +pub async fn reshard_stream( ctx: C, input: L, shard_picker: S, ) -> Result, crate::error::Error> where - L: IntoIterator, - L::IntoIter: ExactSizeIterator, + L: ExactSizeStream, S: Fn(C, RecordId, &K) -> ShardIndex, K: Message + Clone, C: ShardedContext, { - let input = input.into_iter(); let input_len = input.len(); // We set channels capacity to be at least 1 to be able to open send channels to all peers. @@ -426,6 +446,8 @@ where }) .fuse(); + let input = pin!(input); + // This produces a stream of outcomes of send requests. // In order to make it compatible with receive stream, it also returns records that must // stay on this shard, according to `shard_picker`'s decision. @@ -439,13 +461,15 @@ where // tracking per shard to work correctly. If tasks complete out of order, this will cause share // misplacement on the recipient side. ( - input.enumerate().zip(iter::repeat(ctx.clone())), + input + .enumerate() + .zip(stream::iter(iter::repeat(ctx.clone()))), &mut send_channels, ), |(mut input, send_channels)| async { // Process more data as it comes in, or close the sending channels, if there is nothing // left. - if let Some(((i, val), ctx)) = input.next() { + if let Some(((i, val), ctx)) = input.next().await { let dest_shard = shard_picker(ctx, RecordId::from(i), &val); if dest_shard == my_shard { Some(((my_shard, Ok(Some(val.clone()))), (input, send_channels))) @@ -504,6 +528,27 @@ where Ok(r.into_iter().flatten().collect()) } +/// Same as [`reshard_stream`] but takes an iterator with the known size +/// as input. 
+/// +/// # Errors +/// +/// # Panics +pub async fn reshard_iter( + ctx: C, + input: L, + shard_picker: S, +) -> Result, crate::error::Error> +where + L: IntoIterator, + L::IntoIter: ExactSizeIterator, + S: Fn(C, RecordId, &K) -> ShardIndex, + K: Message + Clone, + C: ShardedContext, +{ + reshard_stream(ctx, stream::iter(input.into_iter()), shard_picker).await +} + /// trait for contexts that allow MPC multiplications that are protected against a malicious helper by using a DZKP #[async_trait] pub trait DZKPContext: Context { @@ -543,7 +588,7 @@ mod tests { protocol::{ basics::ShareKnownValue, context::{ - reshard, step::MaliciousProtocolStep::MaliciousProtocol, upgrade::Upgradable, + reshard_iter, step::MaliciousProtocolStep::MaliciousProtocol, upgrade::Upgradable, Context, ShardedContext, UpgradableContext, Validator, }, prss::SharedRandomness, @@ -830,7 +875,7 @@ mod tests { let input: Vec<_> = (0..SHARDS).map(BA8::truncate_from).collect(); let r = world .semi_honest(input.clone().into_iter(), |ctx, shard_input| async move { - reshard(ctx, shard_input, |_, record_id, _| { + reshard_iter(ctx, shard_input, |_, record_id, _| { ShardIndex::from(u32::from(record_id) % SHARDS) }) .await diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/sharded.rs b/ipa-core/src/protocol/ipa_prf/shuffle/sharded.rs index 0a7f94d76..7bc766917 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/sharded.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/sharded.rs @@ -18,7 +18,7 @@ use crate::{ ff::{boolean_array::BA64, U128Conversions}, helpers::{Direction, Error, Role, TotalRecords}, protocol::{ - context::{reshard, ShardedContext}, + context::{reshard_iter, ShardedContext}, prss::{FromRandom, FromRandomU128, SharedRandomness}, RecordId, }, @@ -88,7 +88,7 @@ trait ShuffleContext: ShardedContext { let data = data.into_iter(); async move { let masking_ctx = self.narrow(&ShuffleStep::Mask); - let mut resharded = assert_send(reshard( + let mut resharded = assert_send(reshard_iter( self.clone(), data.enumerate().map(|(i, item)| { // FIXME(1029): update PRSS trait to compute only left or right part From 51a8a282fb4df02fffbaf356f46ff9baeb22145f Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Wed, 9 Oct 2024 16:08:25 -0700 Subject: [PATCH 122/191] add test for reshard-stream --- ipa-core/src/protocol/context/mod.rs | 35 ++++++++++++++++++++++++---- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/ipa-core/src/protocol/context/mod.rs b/ipa-core/src/protocol/context/mod.rs index 4e9c38e39..5b1de3505 100644 --- a/ipa-core/src/protocol/context/mod.rs +++ b/ipa-core/src/protocol/context/mod.rs @@ -571,7 +571,7 @@ pub trait DZKPContext: Context { mod tests { use std::{iter, iter::repeat}; - use futures::{future::join_all, stream::StreamExt, try_join}; + use futures::{future::join_all, stream, stream::StreamExt, try_join}; use ipa_step::StepNarrow; use rand::{ distributions::{Distribution, Standard}, @@ -588,8 +588,8 @@ mod tests { protocol::{ basics::ShareKnownValue, context::{ - reshard_iter, step::MaliciousProtocolStep::MaliciousProtocol, upgrade::Upgradable, - Context, ShardedContext, UpgradableContext, Validator, + reshard_iter, reshard_stream, step::MaliciousProtocolStep::MaliciousProtocol, + upgrade::Upgradable, Context, ShardedContext, UpgradableContext, Validator, }, prss::SharedRandomness, RecordId, @@ -867,7 +867,34 @@ mod tests { /// Ensure global record order across shards is consistent. 
#[test] - fn shard_picker() { + fn reshard_stream_test() { + run(|| async move { + const SHARDS: u32 = 5; + let world: TestWorld> = + TestWorld::with_shards(TestWorldConfig::default()); + + let input: Vec<_> = (0..SHARDS).map(BA8::truncate_from).collect(); + let r = world + .semi_honest(input.clone().into_iter(), |ctx, shard_input| async move { + let shard_input = stream::iter(shard_input); + reshard_stream(ctx, shard_input, |_, record_id, _| { + ShardIndex::from(u32::from(record_id) % SHARDS) + }) + .await + .unwrap() + }) + .await + .into_iter() + .flat_map(|v| v.reconstruct()) + .collect::>(); + + assert_eq!(input, r); + }); + } + + /// Ensure global record order across shards is consistent. + #[test] + fn reshard_iter_test() { run(|| async move { const SHARDS: u32 = 5; let world: TestWorld> = From f0672e5831e0e0b3b103c06cf9b33bd77404f3b8 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Wed, 9 Oct 2024 21:09:17 -0700 Subject: [PATCH 123/191] Make disable-metrics feature turn off global recorder It turns out we've been running with metrics overhead the whole time. I tested a run locally with `disable-metrics` feature and to my surprise I still saw the telemetry emitted from helper binary. Upon investigating, I came across our `Verbosity` struct that was installing the collector unconditionally. I changed that and tested that metrics are no longer emitted --- ipa-core/src/cli/metric_collector.rs | 1 + ipa-core/src/cli/verbosity.rs | 21 +++++++++++++-------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/ipa-core/src/cli/metric_collector.rs b/ipa-core/src/cli/metric_collector.rs index 17fd72705..881e775f5 100644 --- a/ipa-core/src/cli/metric_collector.rs +++ b/ipa-core/src/cli/metric_collector.rs @@ -31,6 +31,7 @@ pub fn install_collector() -> CollectorHandle { // register metrics crate::telemetry::metrics::register(); + tracing::info!("Metrics enabled"); CollectorHandle { snapshotter } } diff --git a/ipa-core/src/cli/verbosity.rs b/ipa-core/src/cli/verbosity.rs index 53a2bee39..068af04f5 100644 --- a/ipa-core/src/cli/verbosity.rs +++ b/ipa-core/src/cli/verbosity.rs @@ -32,24 +32,29 @@ impl Verbosity { #[must_use] pub fn setup_logging(&self) -> LoggingHandle { let filter_layer = self.log_filter(); + info!("Logging setup at level {}", filter_layer); + let fmt_layer = fmt::layer() .with_span_events(FmtSpan::NEW | FmtSpan::CLOSE) .with_ansi(std::io::stderr().is_terminal()) .with_writer(stderr); - tracing_subscriber::registry() - .with(self.log_filter()) - .with(fmt_layer) - .with(MetricsLayer::new()) - .init(); + let registry = tracing_subscriber::registry() + .with(filter_layer) + .with(fmt_layer); + + if cfg!(feature = "disable-metrics") { + registry.init(); + } else { + registry.with(MetricsLayer::new()).init(); + } let handle = LoggingHandle { - metrics_handle: (!self.quiet).then(install_collector), + metrics_handle: (!self.quiet && !cfg!(feature = "disable-metrics")) + .then(install_collector), }; set_global_panic_hook(); - info!("Logging setup at level {}", filter_layer); - handle } From 726a54936582d540795fbea30fddc6f74a5cc7c8 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Thu, 10 Oct 2024 22:54:25 -0700 Subject: [PATCH 124/191] Set up separate runtimes for HTTP and MPC queries (#1338) * Use separate runtimes for HTTP and MPC queries * Start MPC server inside HTTP runtime * Run HTTP client tasks in separate runtime * Fix compilation errors * Fix more compilation errors * Fix more compilation errors --- ipa-core/src/bin/helper.rs | 69 ++++++++++++++-- 
ipa-core/src/cli/playbook/mod.rs | 4 +- ipa-core/src/lib.rs | 132 ++++++++++++++++++++++++++++--- ipa-core/src/net/client/mod.rs | 69 ++++++++++------ ipa-core/src/net/error.rs | 2 +- ipa-core/src/net/server/mod.rs | 33 ++++---- ipa-core/src/net/test.rs | 22 ++++-- ipa-core/src/net/transport.rs | 45 ++++++++--- 8 files changed, 307 insertions(+), 69 deletions(-) diff --git a/ipa-core/src/bin/helper.rs b/ipa-core/src/bin/helper.rs index 884745180..41fcb88ef 100644 --- a/ipa-core/src/bin/helper.rs +++ b/ipa-core/src/bin/helper.rs @@ -15,10 +15,12 @@ use ipa_core::{ }, config::{hpke_registry, HpkeServerConfig, NetworkConfig, ServerConfig, TlsConfig}, error::BoxError, + executor::IpaRuntime, helpers::HelperIdentity, net::{ClientIdentity, HttpShardTransport, HttpTransport, MpcHelperClient}, AppConfig, AppSetup, NonZeroU32PowerOfTwo, }; +use tokio::runtime::Runtime; use tracing::{error, info}; #[cfg(all(not(target_env = "msvc"), not(target_os = "macos")))] @@ -133,9 +135,12 @@ async fn server(args: ServerArgs) -> Result<(), BoxError> { private_key_file: sk_path, }); + let query_runtime = new_query_runtime(); let app_config = AppConfig::default() .with_key_registry(hpke_registry(mk_encryption.as_ref()).await?) - .with_active_work(args.active_work); + .with_active_work(args.active_work) + .with_runtime(IpaRuntime::from_tokio_runtime(&query_runtime)); + let (setup, handler) = AppSetup::new(app_config); let server_config = ServerConfig { @@ -153,9 +158,14 @@ async fn server(args: ServerArgs) -> Result<(), BoxError> { let network_config_path = args.network.as_deref().unwrap(); let network_config = NetworkConfig::from_toml_str(&fs::read_to_string(network_config_path)?)? .override_scheme(&scheme); - let clients = MpcHelperClient::from_conf(&network_config, &identity); - + let http_runtime = new_http_runtime(); + let clients = MpcHelperClient::from_conf( + &IpaRuntime::from_tokio_runtime(&http_runtime), + &network_config, + &identity, + ); let (transport, server) = HttpTransport::new( + IpaRuntime::from_tokio_runtime(&http_runtime), my_identity, server_config, network_config, @@ -183,18 +193,67 @@ async fn server(args: ServerArgs) -> Result<(), BoxError> { let (_addr, server_handle) = server .start_on( + &IpaRuntime::from_tokio_runtime(&http_runtime), listener, // TODO, trace based on the content of the query. None as Option<()>, ) .await; - server_handle.await?; + server_handle.await; + [query_runtime, http_runtime].map(Runtime::shutdown_background); Ok(()) } -#[tokio::main] +/// Creates a new runtime for HTTP stack. It is useful to provide a dedicated +/// scheduler to HTTP tasks, to make sure IPA server can respond to requests, +/// if for some reason query runtime becomes overloaded. +/// When multi-threading feature is enabled it creates a runtime with thread-per-core, +/// otherwise a single-threaded runtime is created. +fn new_http_runtime() -> Runtime { + if cfg!(feature = "multi-threading") { + tokio::runtime::Builder::new_multi_thread() + .thread_name("http-worker") + .enable_all() + .build() + .unwrap() + } else { + tokio::runtime::Builder::new_multi_thread() + .worker_threads(1) + .thread_name("http-worker") + .enable_all() + .build() + .unwrap() + } +} + +/// This function creates a runtime suitable for executing MPC queries. +/// When multi-threading feature is enabled it creates a runtime with thread-per-core, +/// otherwise a single-threaded runtime is created. +fn new_query_runtime() -> Runtime { + // it is intentional that IO driver is not enabled here (enable_time() call only). 
+ // query runtime is supposed to use CPU/memory only, no writes to disk and all + // network communication is handled by HTTP runtime. + if cfg!(feature = "multi-threading") { + tokio::runtime::Builder::new_multi_thread() + .thread_name("query-executor") + .enable_time() + .build() + .unwrap() + } else { + tokio::runtime::Builder::new_multi_thread() + .worker_threads(1) + .thread_name("query-executor") + .enable_time() + .build() + .unwrap() + } +} + +/// A single thread is enough here, because server spawns additional +/// runtimes to use in MPC queries and HTTP. +#[tokio::main(flavor = "current_thread")] pub async fn main() { let args = Args::parse(); let _handle = args.logging.setup_logging(); diff --git a/ipa-core/src/cli/playbook/mod.rs b/ipa-core/src/cli/playbook/mod.rs index 135fa3117..7d0acb1c8 100644 --- a/ipa-core/src/cli/playbook/mod.rs +++ b/ipa-core/src/cli/playbook/mod.rs @@ -17,6 +17,7 @@ use tokio::time::sleep; pub use self::ipa::{playbook_oprf_ipa, run_query_and_validate}; use crate::{ config::{ClientConfig, NetworkConfig, PeerConfig}, + executor::IpaRuntime, ff::boolean_array::{BA20, BA3, BA8}, helpers::query::DpMechanism, net::{ClientIdentity, MpcHelperClient}, @@ -211,7 +212,8 @@ pub async fn make_clients( // Note: This closure is only called when the selected action uses clients. - let clients = MpcHelperClient::from_conf(&network, &ClientIdentity::None); + let clients = + MpcHelperClient::from_conf(&IpaRuntime::current(), &network, &ClientIdentity::None); while wait > 0 && !clients_ready(&clients).await { tracing::debug!("waiting for servers to come up"); sleep(Duration::from_secs(1)).await; diff --git a/ipa-core/src/lib.rs b/ipa-core/src/lib.rs index 70683fff2..1ce693e01 100644 --- a/ipa-core/src/lib.rs +++ b/ipa-core/src/lib.rs @@ -70,10 +70,10 @@ pub(crate) mod rand { #[cfg(all(feature = "shuttle", test))] pub(crate) mod task { - pub use shuttle::future::{JoinError, JoinHandle}; + pub use shuttle::future::JoinError; } -#[cfg(all(feature = "multi-threading", feature = "shuttle"))] +#[cfg(feature = "shuttle")] pub(crate) mod shim { use std::any::Any; @@ -100,9 +100,16 @@ pub(crate) mod task { #[cfg(not(feature = "shuttle"))] pub mod executor { - use std::future::Future; + use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, + }; - use tokio::{runtime::Handle, task::JoinHandle}; + use tokio::{ + runtime::{Handle, Runtime}, + task::JoinHandle, + }; /// In prod we use Tokio scheduler, so this struct just wraps /// its runtime handle and mimics the standard executor API. @@ -112,7 +119,8 @@ pub mod executor { pub struct IpaRuntime(Handle); /// Wrapper around Tokio's [`JoinHandle`] - pub struct IpaJoinHandle(JoinHandle); + #[pin_project::pin_project] + pub struct IpaJoinHandle(#[pin] JoinHandle); impl Default for IpaRuntime { fn default() -> Self { @@ -134,26 +142,82 @@ pub mod executor { { IpaJoinHandle(self.0.spawn(future)) } + + /// This is a convenience method to convert a Tokio runtime into + /// an IPA runtime. It does not assume ownership of the Tokio runtime. + /// The caller is responsible for ensuring the Tokio runtime is properly + /// shut down. 
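+    /// A usage sketch (illustrative only; the builder options are placeholders
+    /// for whatever the caller actually needs):
+    /// ```ignore
+    /// let rt = tokio::runtime::Builder::new_multi_thread()
+    ///     .enable_all()
+    ///     .build()
+    ///     .unwrap();
+    /// let ipa_rt = IpaRuntime::from_tokio_runtime(&rt);
+    /// // `rt` must stay alive for as long as tasks spawned via `ipa_rt` run
+    /// ```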
+ #[must_use] + pub fn from_tokio_runtime(rt: &Runtime) -> Self { + Self(rt.handle().clone()) + } + } + + /// allow using [`IpaRuntime`] as Hyper executor + #[cfg(feature = "web-app")] + impl hyper::rt::Executor for IpaRuntime + where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, + { + fn execute(&self, fut: Fut) { + // Dropping the handle does not terminate the task + // Clippy wants us to be explicit here. + drop(self.spawn(fut)); + } } impl IpaJoinHandle { - pub fn abort(self) { + pub fn abort(&self) { self.0.abort(); } } + + impl Future for IpaJoinHandle { + type Output = T; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.project().0.poll(cx) { + Poll::Ready(Ok(v)) => Poll::Ready(v), + Poll::Ready(Err(e)) => match e.try_into_panic() { + Ok(p) => std::panic::resume_unwind(p), + Err(e) => panic!("Task is cancelled: {e:?}"), + }, + Poll::Pending => Poll::Pending, + } + } + } } #[cfg(feature = "shuttle")] pub(crate) mod executor { - use std::future::Future; + use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, + }; use shuttle_crate::future::{spawn, JoinHandle}; + use crate::shim::Tokio; + /// Shuttle does not support more than one runtime /// so we always use its default #[derive(Clone, Default)] pub struct IpaRuntime; - pub struct IpaJoinHandle(JoinHandle); + #[pin_project::pin_project] + pub struct IpaJoinHandle(#[pin] JoinHandle); + + #[cfg(feature = "web-app")] + impl hyper::rt::Executor for IpaRuntime + where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, + { + fn execute(&self, fut: Fut) { + drop(self.spawn(fut)); + } + } impl IpaRuntime { #[must_use] @@ -173,10 +237,25 @@ pub(crate) mod executor { } impl IpaJoinHandle { - pub fn abort(self) { + pub fn abort(&self) { self.0.abort(); } } + + impl Future for IpaJoinHandle { + type Output = T; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.project().0.poll(cx) { + Poll::Ready(Ok(v)) => Poll::Ready(v), + Poll::Ready(Err(e)) => match e.try_into_panic() { + Ok(p) => std::panic::resume_unwind(p), + Err(e) => panic!("Task is cancelled: {e:?}"), + }, + Poll::Pending => Poll::Pending, + } + } + } } #[cfg(all(feature = "shuttle", test))] @@ -265,3 +344,38 @@ macro_rules! mutually_incompatible { } mutually_incompatible!("in-memory-infra", "real-world-infra"); + +#[cfg(test)] +mod tests { + /// Tests in this module ensure both Shuttle and Tokio runtimes conform to the same API + mod executor { + use crate::{executor::IpaRuntime, test_executor::run}; + + #[test] + #[should_panic(expected = "task panicked")] + fn handle_join_panicked() { + run(|| async move { + let rt = IpaRuntime::current(); + rt.spawn(async { panic!("task panicked") }).await; + }); + } + + #[test] + /// It is nearly impossible to intentionally hang a Shuttle task. Its executor + /// detects that immediately and panics with a deadlock error. 
We only want to test + /// the API, so it is not that important to panic with cancellation error + #[cfg_attr(not(feature = "shuttle"), should_panic(expected = "Task is cancelled"))] + fn handle_abort() { + run(|| async move { + let rt = IpaRuntime::current(); + let handle = rt.spawn(async { + #[cfg(not(feature = "shuttle"))] + futures::future::pending::<()>().await; + }); + + handle.abort(); + handle.await; + }); + } + } +} diff --git a/ipa-core/src/net/client/mod.rs b/ipa-core/src/net/client/mod.rs index 61fcfece0..37b5654f7 100644 --- a/ipa-core/src/net/client/mod.rs +++ b/ipa-core/src/net/client/mod.rs @@ -18,7 +18,7 @@ use hyper::{header::HeaderName, http::HeaderValue, Request, Response, StatusCode use hyper_rustls::{ConfigBuilderExt, HttpsConnector, HttpsConnectorBuilder}; use hyper_util::{ client::legacy::{connect::HttpConnector, Client}, - rt::{TokioExecutor, TokioTimer}, + rt::TokioTimer, }; use pin_project::pin_project; use rustls::RootCertStore; @@ -29,6 +29,7 @@ use crate::{ ClientConfig, HyperClientConfigurator, NetworkConfig, OwnedCertificate, OwnedPrivateKey, PeerConfig, }, + executor::IpaRuntime, helpers::{ query::{PrepareQuery, QueryConfig, QueryInput}, HelperIdentity, @@ -91,20 +92,22 @@ impl ClientIdentity { /// Wrapper around Hyper's [future](hyper::client::ResponseFuture) interface that keeps around /// request endpoint for nicer error messages if request fails. #[pin_project] -pub struct ResponseFuture<'a> { - authority: &'a uri::Authority, +pub struct ResponseFuture { + /// There used to be a reference here, but there is really no need for that, + /// because `uri::Authority` type uses `Bytes` internally. + authority: uri::Authority, #[pin] inner: hyper_util::client::legacy::ResponseFuture, } /// Similar to [fut](ResponseFuture), wraps the response and keeps the URI authority for better /// error messages that show where error is originated from -pub struct ResponseFromEndpoint<'a> { - authority: &'a uri::Authority, +pub struct ResponseFromEndpoint { + authority: uri::Authority, inner: Response, } -impl<'a> ResponseFromEndpoint<'a> { +impl ResponseFromEndpoint { pub fn endpoint(&self) -> String { self.authority.to_string() } @@ -117,13 +120,13 @@ impl<'a> ResponseFromEndpoint<'a> { self.inner.into_body() } - pub fn into_parts(self) -> (&'a uri::Authority, Body) { + pub fn into_parts(self) -> (uri::Authority, Body) { (self.authority, self.inner.into_body()) } } -impl<'a> Future for ResponseFuture<'a> { - type Output = Result, Error>; +impl Future for ResponseFuture { + type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); @@ -132,7 +135,7 @@ impl<'a> Future for ResponseFuture<'a> { let (http_parts, http_body) = resp.into_parts(); let axum_resp = Response::from_parts(http_parts, Body::new(http_body)); Poll::Ready(Ok(ResponseFromEndpoint { - authority: this.authority, + authority: this.authority.clone(), inner: axum_resp, })) } @@ -168,10 +171,19 @@ impl MpcHelperClient { /// Authentication is not required when calling the report collector APIs. 
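+    /// A usage sketch (illustrative only), mirroring the call sites elsewhere in
+    /// this patch:
+    /// ```ignore
+    /// let clients = MpcHelperClient::from_conf(
+    ///     &IpaRuntime::current(),
+    ///     &network_config,
+    ///     &ClientIdentity::None,
+    /// );
+    /// ```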
#[must_use] #[allow(clippy::missing_panics_doc)] - pub fn from_conf(conf: &NetworkConfig, identity: &ClientIdentity) -> [MpcHelperClient; 3] { - conf.peers() - .each_ref() - .map(|peer_conf| Self::new(&conf.client, peer_conf.clone(), identity.clone_with_key())) + pub fn from_conf( + runtime: &IpaRuntime, + conf: &NetworkConfig, + identity: &ClientIdentity, + ) -> [MpcHelperClient; 3] { + conf.peers().each_ref().map(|peer_conf| { + Self::new( + runtime.clone(), + &conf.client, + peer_conf.clone(), + identity.clone_with_key(), + ) + }) } /// Create a new client with the given configuration @@ -183,6 +195,7 @@ impl MpcHelperClient { /// If some aspect of the configuration is not valid. #[must_use] pub fn new( + runtime: IpaRuntime, client_config: &ClientConfig, peer_config: PeerConfig, identity: ClientIdentity, @@ -247,19 +260,27 @@ impl MpcHelperClient { None, ) }; - Self::new_internal(peer_config.url, connector, auth_header, client_config) + Self::new_internal( + runtime, + peer_config.url, + connector, + auth_header, + client_config, + ) } #[must_use] fn new_internal( + runtime: IpaRuntime, addr: Uri, connector: HttpsConnector, auth_header: Option<(HeaderName, HeaderValue)>, conf: &C, ) -> Self { - let mut builder = Client::builder(TokioExecutor::new()); + let mut builder = Client::builder(runtime); // the following timer is necessary for http2, in particular for any timeouts // and waits the clients will need to make + // TODO: implement IpaTimer to allow wrapping other than Tokio runtimes builder.timer(TokioTimer::new()); let client = conf.configure(&mut builder).build(connector); let Parts { @@ -278,12 +299,12 @@ impl MpcHelperClient { } } - pub fn request(&self, mut req: Request) -> ResponseFuture<'_> { + pub fn request(&self, mut req: Request) -> ResponseFuture { if let Some((k, v)) = self.auth_header.clone() { req.headers_mut().insert(k, v); } ResponseFuture { - authority: &self.authority, + authority: self.authority.clone(), inner: self.client.request(req), } } @@ -292,7 +313,7 @@ impl MpcHelperClient { /// /// # Errors /// If there was an error reading the response body or if the request itself failed. - pub async fn resp_ok(resp: ResponseFromEndpoint<'_>) -> Result<(), Error> { + pub async fn resp_ok(resp: ResponseFromEndpoint) -> Result<(), Error> { if resp.status().is_success() { Ok(()) } else { @@ -304,7 +325,7 @@ impl MpcHelperClient { /// /// # Errors /// If there was an error collecting the response stream. - async fn response_to_bytes(resp: ResponseFromEndpoint<'_>) -> Result { + async fn response_to_bytes(resp: ResponseFromEndpoint) -> Result { Ok(resp.into_body().collect().await?.to_bytes()) } @@ -487,8 +508,12 @@ pub(crate) mod tests { certificate: None, hpke_config: None, }; - let client = - MpcHelperClient::new(&ClientConfig::default(), peer_config, ClientIdentity::None); + let client = MpcHelperClient::new( + IpaRuntime::current(), + &ClientConfig::default(), + peer_config, + ClientIdentity::None, + ); // The server's self-signed test cert is not in the system truststore, and we didn't supply // it in the client config, so the connection should fail with a certificate error. 
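Dropping the borrowed `uri::Authority` makes `ResponseFuture` and
`ResponseFromEndpoint` fully owned and therefore `'static`, which is what allows
the transport (see the transport.rs hunk later in this patch) to move a response
future onto the dedicated HTTP runtime. A minimal sketch of that pattern,
illustrative only, with names taken from the transport.rs change:

    let fut = self.clients[dest].step(query_id, &step, data)?; // owns its endpoint info
    self.http_runtime.spawn(fut).await?; // no borrow ties the future to this task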
diff --git a/ipa-core/src/net/error.rs b/ipa-core/src/net/error.rs index e97551f6f..731df19de 100644 --- a/ipa-core/src/net/error.rs +++ b/ipa-core/src/net/error.rs @@ -73,7 +73,7 @@ impl Error { /// /// # Panics /// If the response is not a failure (4xx/5xx status) - pub async fn from_failed_resp(resp: ResponseFromEndpoint<'_>) -> Self { + pub async fn from_failed_resp(resp: ResponseFromEndpoint) -> Self { let status = resp.status(); assert!(status.is_client_error() || status.is_server_error()); // must be failure let (endpoint, body) = resp.into_parts(); diff --git a/ipa-core/src/net/server/mod.rs b/ipa-core/src/net/server/mod.rs index 87d7ee2cd..5de703e0b 100644 --- a/ipa-core/src/net/server/mod.rs +++ b/ipa-core/src/net/server/mod.rs @@ -28,14 +28,12 @@ use axum_server::{ use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _}; use futures::{ future::{ready, BoxFuture, Either, Ready}, - Future, FutureExt, + FutureExt, }; use hyper::{body::Incoming, header::HeaderName, Request}; use metrics::increment_counter; use rustls::{server::WebPkiClientVerifier, RootCertStore}; use rustls_pki_types::CertificateDer; -#[cfg(all(feature = "shuttle", test))] -use shuttle::future as tokio; use tokio_rustls::server::TlsStream; use tower::{layer::layer_fn, Service}; use tower_http::trace::TraceLayer; @@ -44,13 +42,13 @@ use tracing::{error, Span}; use crate::{ config::{NetworkConfig, OwnedCertificate, OwnedPrivateKey, ServerConfig, TlsConfig}, error::BoxError, + executor::{IpaJoinHandle, IpaRuntime}, helpers::HelperIdentity, net::{ parse_certificate_and_private_key_bytes, server::config::HttpServerConfig, Error, HttpTransport, CRYPTO_PROVIDER, }, sync::Arc, - task::JoinHandle, telemetry::metrics::{web::RequestProtocolVersion, REQUESTS_RECEIVED}, }; @@ -121,9 +119,10 @@ impl MpcHelperServer { /// configured, it must be valid.) pub async fn start_on( &self, + runtime: &IpaRuntime, listener: Option, tracing: T, - ) -> (SocketAddr, JoinHandle<()>) { + ) -> (SocketAddr, IpaJoinHandle<()>) { // This should probably come from the server config. // Note that listening on 0.0.0.0 requires accepting a MacOS security // warning on each test run. 
@@ -147,20 +146,27 @@ impl MpcHelperServer { let svc = svc .layer(layer_fn(SetClientIdentityFromHeader::new)) .into_make_service(); - spawn_server(axum_server::from_tcp(listener), handle.clone(), svc).await + spawn_server( + runtime, + axum_server::from_tcp(listener), + handle.clone(), + svc, + ) + .await } (true, None) => { let addr = SocketAddr::new(BIND_ADDRESS.into(), self.config.port.unwrap_or(0)); let svc = svc .layer(layer_fn(SetClientIdentityFromHeader::new)) .into_make_service(); - spawn_server(axum_server::bind(addr), handle.clone(), svc).await + spawn_server(runtime, axum_server::bind(addr), handle.clone(), svc).await } (false, Some(listener)) => { let rustls_config = rustls_config(&self.config, &self.network_config) .await .expect("invalid TLS configuration"); spawn_server( + runtime, axum_server::from_tcp_rustls(listener, rustls_config).map(|a| { ClientCertRecognizingAcceptor::new(a, self.network_config.clone()) }), @@ -175,6 +181,7 @@ impl MpcHelperServer { .await .expect("invalid TLS configuration"); spawn_server( + runtime, axum_server::bind_rustls(addr, rustls_config).map(|a| { ClientCertRecognizingAcceptor::new(a, self.network_config.clone()) }), @@ -201,30 +208,24 @@ impl MpcHelperServer { ); (bound_addr, task_handle) } - - pub fn start( - &self, - tracing: T, - ) -> impl Future)> + '_ { - self.start_on(None, tracing) - } } /// Spawns a new server with the given configuration. /// This function glues Tower, Axum, Hyper and Axum-Server together, hence the trait bounds. #[allow(clippy::unused_async)] async fn spawn_server( + runtime: &IpaRuntime, mut server: Server, handle: Handle, svc: IntoMakeService, -) -> JoinHandle<()> +) -> IpaJoinHandle<()> where A: Accept + Clone + Send + Sync + 'static, A::Stream: AsyncRead + AsyncWrite + Unpin + Send, A::Service: SendService> + Send + Service>, A::Future: Send, { - tokio::spawn({ + runtime.spawn({ async move { // Apply configuration HttpServerConfig::apply(&mut server.http_builder().http2()); diff --git a/ipa-core/src/net/test.rs b/ipa-core/src/net/test.rs index e6edcc0f6..85795940b 100644 --- a/ipa-core/src/net/test.rs +++ b/ipa-core/src/net/test.rs @@ -15,13 +15,13 @@ use std::{ use once_cell::sync::Lazy; use rustls_pki_types::CertificateDer; -use tokio::task::JoinHandle; use crate::{ config::{ ClientConfig, HpkeClientConfig, HpkeServerConfig, NetworkConfig, PeerConfig, ServerConfig, TlsConfig, }, + executor::{IpaJoinHandle, IpaRuntime}, helpers::{HandlerBox, HelperIdentity, RequestHandler}, hpke::{Deserializable as _, IpaPublicKey}, net::{ClientIdentity, HttpTransport, MpcHelperClient, MpcHelperServer}, @@ -201,7 +201,7 @@ impl TestConfigBuilder { pub struct TestServer { pub addr: SocketAddr, - pub handle: JoinHandle<()>, + pub handle: IpaJoinHandle<()>, pub transport: Arc, pub server: MpcHelperServer, pub client: MpcHelperClient, @@ -291,22 +291,34 @@ impl TestServerBuilder { else { panic!("TestConfig should have allocated ports"); }; - let clients = MpcHelperClient::from_conf(&network_config, &identity.clone_with_key()); + let clients = MpcHelperClient::from_conf( + &IpaRuntime::current(), + &network_config, + &identity.clone_with_key(), + ); let handler = self.handler.as_ref().map(HandlerBox::owning_ref); let (transport, server) = HttpTransport::new( + IpaRuntime::current(), HelperIdentity::ONE, server_config, network_config.clone(), clients, handler, ); - let (addr, handle) = server.start_on(Some(server_socket), self.metrics).await; + let (addr, handle) = server + .start_on(&IpaRuntime::current(), Some(server_socket), 
self.metrics) + .await; // Get the config for HelperIdentity::ONE let h1_peer_config = network_config.peers.into_iter().next().unwrap(); // At some point it might be appropriate to return two clients here -- the first being // another helper and the second being a report collector. For now we use the same client // for both types of calls. - let client = MpcHelperClient::new(&network_config.client, h1_peer_config, identity); + let client = MpcHelperClient::new( + IpaRuntime::current(), + &network_config.client, + h1_peer_config, + identity, + ); TestServer { addr, handle, diff --git a/ipa-core/src/net/transport.rs b/ipa-core/src/net/transport.rs index 508bfc8d5..c053b0b8a 100644 --- a/ipa-core/src/net/transport.rs +++ b/ipa-core/src/net/transport.rs @@ -11,6 +11,7 @@ use pin_project::{pin_project, pinned_drop}; use crate::{ config::{NetworkConfig, ServerConfig}, + executor::IpaRuntime, helpers::{ query::QueryConfig, routing::{Addr, RouteId}, @@ -27,6 +28,7 @@ use crate::{ /// HTTP transport for IPA helper service. /// TODO: rename to MPC pub struct HttpTransport { + http_runtime: IpaRuntime, identity: HelperIdentity, clients: [MpcHelperClient; 3], // TODO(615): supporting multiple queries likely require a hashmap here. It will be ok if we @@ -62,23 +64,26 @@ impl RouteParams for QueryConfig { impl HttpTransport { #[must_use] pub fn new( + runtime: IpaRuntime, identity: HelperIdentity, server_config: ServerConfig, network_config: NetworkConfig, clients: [MpcHelperClient; 3], handler: Option, ) -> (Arc, MpcHelperServer) { - let transport = Self::new_internal(identity, clients, handler); + let transport = Self::new_internal(runtime, identity, clients, handler); let server = MpcHelperServer::new(Arc::clone(&transport), server_config, network_config); (transport, server) } fn new_internal( + runtime: IpaRuntime, identity: HelperIdentity, clients: [MpcHelperClient; 3], handler: Option, ) -> Arc { Arc::new(Self { + http_runtime: runtime, identity, clients, handler, @@ -195,11 +200,16 @@ impl Transport for Arc { let step = >::from(route.gate()).expect("step required when sending records"); let resp_future = self.clients[dest].step(query_id, &step, data)?; - // we don't need to spawn a task here. Gateway's sender interface already does that - // so this can just poll this future. - resp_future - .map_err(Into::into) - .and_then(MpcHelperClient::resp_ok) + + // Use a dedicated HTTP runtime to poll this future for several reasons: + // - avoid blocking this task, if the current runtime is overloaded + // - use the runtime that enables IO (current runtime may not). 
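+        // (Awaiting the `IpaJoinHandle` returned by `spawn` resumes any panic from
+        // the spawned task and panics on cancellation (see the `Future` impl in
+        // lib.rs), so the extra hop does not swallow failures.)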
+ self.http_runtime + .spawn( + resp_future + .map_err(Into::into) + .and_then(MpcHelperClient::resp_ok), + ) .await?; Ok(()) } @@ -383,15 +393,22 @@ mod tests { get_test_identity(id) }; let (setup, handler) = AppSetup::new(AppConfig::default()); - let clients = MpcHelperClient::from_conf(network_config, &identity); + let clients = MpcHelperClient::from_conf( + &IpaRuntime::current(), + network_config, + &identity, + ); let (transport, server) = HttpTransport::new( + IpaRuntime::current(), id, server_config, network_config.clone(), clients, Some(handler), ); - server.start_on(Some(socket), ()).await; + server + .start_on(&IpaRuntime::current(), Some(socket), ()) + .await; setup.connect(transport, HttpShardTransport) }, @@ -404,7 +421,11 @@ mod tests { } async fn test_three_helpers(mut conf: TestConfig) { - let clients = MpcHelperClient::from_conf(&conf.network, &ClientIdentity::None); + let clients = MpcHelperClient::from_conf( + &IpaRuntime::current(), + &conf.network, + &ClientIdentity::None, + ); let _helpers = make_helpers( conf.sockets.take().unwrap(), conf.servers, @@ -419,7 +440,11 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn happy_case_twice() { let mut conf = TestConfigBuilder::with_open_ports().build(); - let clients = MpcHelperClient::from_conf(&conf.network, &ClientIdentity::None); + let clients = MpcHelperClient::from_conf( + &IpaRuntime::current(), + &conf.network, + &ClientIdentity::None, + ); let _helpers = make_helpers( conf.sockets.take().unwrap(), conf.servers, From 37635403fe583184b99673ed193c40c50ea1867f Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Fri, 11 Oct 2024 10:43:53 -0700 Subject: [PATCH 125/191] add test to show that IPA fails for cap 1, 2, 3, and 4 --- ipa-core/tests/compact_gate.rs | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/ipa-core/tests/compact_gate.rs b/ipa-core/tests/compact_gate.rs index 5cc83beaf..7e31c626e 100644 --- a/ipa-core/tests/compact_gate.rs +++ b/ipa-core/tests/compact_gate.rs @@ -31,6 +31,26 @@ fn compact_gate_cap_8_no_window_semi_honest_encryped_input() { test_compact_gate(IpaSecurityModel::SemiHonest, 8, 0, true); } +#[test] +fn compact_gate_cap_1_no_window_semi_honest_encryped_input() { + test_compact_gate(IpaSecurityModel::SemiHonest, 1, 0, true); +} + +#[test] +fn compact_gate_cap_2_no_window_semi_honest_encryped_input() { + test_compact_gate(IpaSecurityModel::SemiHonest, 2, 0, true); +} + +#[test] +fn compact_gate_cap_3_no_window_semi_honest_encryped_input() { + test_compact_gate(IpaSecurityModel::SemiHonest, 3, 0, true); +} + +#[test] +fn compact_gate_cap_4_no_window_semi_honest_encryped_input() { + test_compact_gate(IpaSecurityModel::SemiHonest, 4, 0, true); +} + #[test] fn compact_gate_cap_8_no_window_semi_honest_plaintext_input() { test_compact_gate(IpaSecurityModel::SemiHonest, 8, 0, false); From f7adc8689201ad06a32b545d7b29e58294c9963d Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Fri, 11 Oct 2024 12:09:23 -0700 Subject: [PATCH 126/191] Validate uniqueness (#1342) * split out method for building HybridReport from EncryptedOprfReport * tmp commit, working on runner * tmp commit with sharding by ciphertext with attached index * add uniqueness check without sharding * convert to HybridReports * raise Error for unsupported plaintext match keys with Hybrid * refactor to UniqueBytesValidator with a more generic trait bound * clean up trait bounds from EncryptedOprfReport * removed unused VerifyCiphertextUniqueness step * update Error to match generic used for duplicate 
checking * split out uniqueness and decryption * make error propogate rather than panic * add comment with link to justify using 16 bits of the ciphertext * update DuplicateBytes error to count the number of checks, not successful checks * remove type hint for HashSet * update to use the Tag as the unique bytes for collision detection * add comment about test that currently panics eventually checking for expected result * refactor tests to be less copy pasta --- ipa-core/build.rs | 1 + ipa-core/src/error.rs | 2 + ipa-core/src/protocol/hybrid/mod.rs | 1 + ipa-core/src/protocol/hybrid/step.rs | 4 + ipa-core/src/protocol/mod.rs | 1 + ipa-core/src/protocol/step.rs | 2 + ipa-core/src/query/runner/hybrid.rs | 301 ++++++++++++++++++++++++++- ipa-core/src/report/hybrid.rs | 196 +++++++++++++++-- 8 files changed, 478 insertions(+), 30 deletions(-) create mode 100644 ipa-core/src/protocol/hybrid/mod.rs create mode 100644 ipa-core/src/protocol/hybrid/step.rs diff --git a/ipa-core/build.rs b/ipa-core/build.rs index 26155f794..768dc5040 100644 --- a/ipa-core/build.rs +++ b/ipa-core/build.rs @@ -16,6 +16,7 @@ track_steps!( step, }, context::step, + hybrid::step, ipa_prf::{ boolean_ops::step, prf_sharding::step, diff --git a/ipa-core/src/error.rs b/ipa-core/src/error.rs index 771cb1c0a..168827c8e 100644 --- a/ipa-core/src/error.rs +++ b/ipa-core/src/error.rs @@ -104,6 +104,8 @@ pub enum Error { }, #[error("The verification of the shuffle failed: {0}")] ShuffleValidationFailed(String), + #[error("Duplicate bytes found after {0} checks")] + DuplicateBytes(usize), } impl Default for Error { diff --git a/ipa-core/src/protocol/hybrid/mod.rs b/ipa-core/src/protocol/hybrid/mod.rs new file mode 100644 index 000000000..71ff41f4c --- /dev/null +++ b/ipa-core/src/protocol/hybrid/mod.rs @@ -0,0 +1 @@ +pub(crate) mod step; diff --git a/ipa-core/src/protocol/hybrid/step.rs b/ipa-core/src/protocol/hybrid/step.rs new file mode 100644 index 000000000..6fd7406c6 --- /dev/null +++ b/ipa-core/src/protocol/hybrid/step.rs @@ -0,0 +1,4 @@ +use ipa_step_derive::CompactStep; + +#[derive(CompactStep)] +pub(crate) enum HybridStep {} diff --git a/ipa-core/src/protocol/mod.rs b/ipa-core/src/protocol/mod.rs index 9401cec8d..28abf9741 100644 --- a/ipa-core/src/protocol/mod.rs +++ b/ipa-core/src/protocol/mod.rs @@ -2,6 +2,7 @@ pub mod basics; pub mod boolean; pub mod context; pub mod dp; +pub mod hybrid; pub mod ipa_prf; pub mod prss; pub mod step; diff --git a/ipa-core/src/protocol/step.rs b/ipa-core/src/protocol/step.rs index 8346557d2..cf3658018 100644 --- a/ipa-core/src/protocol/step.rs +++ b/ipa-core/src/protocol/step.rs @@ -7,6 +7,8 @@ pub enum ProtocolStep { Prss, #[step(child = crate::protocol::ipa_prf::step::IpaPrfStep)] IpaPrf, + #[step(child = crate::protocol::hybrid::step::HybridStep)] + Hybrid, Multiply, PrimeFieldAddition, /// Steps used in unit tests are grouped under this one. 
Ideally it should be diff --git a/ipa-core/src/query/runner/hybrid.rs b/ipa-core/src/query/runner/hybrid.rs index 73abcaf75..6ff7ed7f2 100644 --- a/ipa-core/src/query/runner/hybrid.rs +++ b/ipa-core/src/query/runner/hybrid.rs @@ -1,26 +1,39 @@ use std::{marker::PhantomData, sync::Arc}; +use futures::{future, stream::iter, StreamExt, TryStreamExt}; + use crate::{ error::Error, + ff::boolean_array::{BA20, BA3, BA8}, helpers::{ query::{HybridQueryParams, QuerySize}, - BodyStream, + BodyStream, LengthDelimitedStream, }, hpke::PrivateKeyRegistry, + protocol::{context::UpgradableContext, ipa_prf::shuffle::Shuffle, step::ProtocolStep::Hybrid}, + report::hybrid::{EncryptedHybridReport, UniqueBytesValidator}, secret_sharing::{replicated::semi_honest::AdditiveShare as ReplicatedShare, SharedValue}, }; +pub type BreakdownKey = BA8; +pub type Value = BA3; +// TODO: remove this when encryption/decryption works for HybridReports +pub type Timestamp = BA20; + pub struct Query { - _config: HybridQueryParams, - _key_registry: Arc, + config: HybridQueryParams, + key_registry: Arc, phantom_data: PhantomData<(C, HV)>, } -impl Query { +impl Query +where + C: UpgradableContext + Shuffle, +{ pub fn new(query_params: HybridQueryParams, key_registry: Arc) -> Self { Self { - _config: query_params, - _key_registry: key_registry, + config: query_params, + key_registry, phantom_data: PhantomData, } } @@ -28,10 +41,280 @@ impl Query { #[tracing::instrument("hybrid_query", skip_all, fields(sz=%query_size))] pub async fn execute( self, - _ctx: C, + ctx: C, query_size: QuerySize, - _input_stream: BodyStream, + input_stream: BodyStream, ) -> Result>, Error> { - unimplemented!() + let Self { + config, + key_registry, + phantom_data: _, + } = self; + tracing::info!("New hybrid query: {config:?}"); + let _ctx = ctx.narrow(&Hybrid); + let sz = usize::from(query_size); + + let mut unique_encrypted_hybrid_reports = UniqueBytesValidator::new(sz); + + if config.plaintext_match_keys { + return Err(Error::Unsupported( + "Hybrid queries do not currently support plaintext match keys".to_string(), + )); + } + + let _input = LengthDelimitedStream::::new(input_stream) + .map_err(Into::::into) + .and_then(|enc_reports| { + future::ready( + unique_encrypted_hybrid_reports + .check_duplicates(&enc_reports) + .map(|()| enc_reports) + .map_err(Into::::into), + ) + }) + .map_ok(|enc_reports| { + iter(enc_reports.into_iter().map({ + |enc_report| { + enc_report + .decrypt::(key_registry.as_ref()) + .map_err(Into::::into) + } + })) + }) + .try_flatten() + .take(sz) + .try_collect::>() + .await?; + + unimplemented!("query::runnner::HybridQuery.execute is not fully implemented") + } +} + +#[cfg(all(test, unit_test))] +mod tests { + use std::{iter::zip, sync::Arc}; + + use rand::rngs::StdRng; + use rand_core::SeedableRng; + + use crate::{ + ff::{ + boolean_array::{BA16, BA20, BA3, BA8}, + U128Conversions, + }, + helpers::{ + query::{HybridQueryParams, QuerySize}, + BodyStream, + }, + hpke::{KeyPair, KeyRegistry}, + query::runner::HybridQuery, + report::{OprfReport, DEFAULT_KEY_ID}, + secret_sharing::IntoShares, + test_fixture::{ipa::TestRawDataRecord, join3v, Reconstruct, TestWorld}, + }; + + const EXPECTED: &[u128] = &[0, 8, 5]; + + fn build_records() -> Vec { + // TODO: When Encryption/Decryption exists for HybridReports + // update these to use that, rather than generating OprfReports + vec![ + TestRawDataRecord { + timestamp: 0, + user_id: 12345, + is_trigger_report: false, + breakdown_key: 2, + trigger_value: 0, + }, + TestRawDataRecord { + 
timestamp: 4, + user_id: 68362, + is_trigger_report: false, + breakdown_key: 1, + trigger_value: 0, + }, + TestRawDataRecord { + timestamp: 10, + user_id: 12345, + is_trigger_report: true, + breakdown_key: 0, + trigger_value: 5, + }, + TestRawDataRecord { + timestamp: 12, + user_id: 68362, + is_trigger_report: true, + breakdown_key: 0, + trigger_value: 2, + }, + TestRawDataRecord { + timestamp: 20, + user_id: 68362, + is_trigger_report: false, + breakdown_key: 1, + trigger_value: 0, + }, + TestRawDataRecord { + timestamp: 30, + user_id: 68362, + is_trigger_report: true, + breakdown_key: 1, + trigger_value: 7, + }, + ] + } + + struct BufferAndKeyRegistry { + buffers: [Vec; 3], + key_registry: Arc>, + } + + fn build_buffers_from_records(records: &[TestRawDataRecord]) -> BufferAndKeyRegistry { + let mut rng = StdRng::seed_from_u64(42); + let key_id = DEFAULT_KEY_ID; + let key_registry = Arc::new(KeyRegistry::::random(1, &mut rng)); + + let mut buffers: [_; 3] = std::array::from_fn(|_| Vec::new()); + + let shares: [Vec>; 3] = records.iter().cloned().share(); + for (buf, shares) in zip(&mut buffers, shares) { + for share in shares { + share + .delimited_encrypt_to(key_id, key_registry.as_ref(), &mut rng, buf) + .unwrap(); + } + } + BufferAndKeyRegistry { + buffers, + key_registry, + } + } + + #[tokio::test] + // placeholder until the protocol is complete. can be updated to make sure we + // get to the unimplemented() call + #[should_panic( + expected = "not implemented: query::runnner::HybridQuery.execute is not fully implemented" + )] + async fn encrypted_hybrid_reports() { + // While this test currently checks for an unimplemented panic it is + // designed to test for a correct result for a complete implementation. + + let records = build_records(); + let query_size = QuerySize::try_from(records.len()).unwrap(); + + let BufferAndKeyRegistry { + buffers, + key_registry, + } = build_buffers_from_records(&records); + + let world = TestWorld::default(); + let contexts = world.contexts(); + #[allow(clippy::large_futures)] + let results = join3v(buffers.into_iter().zip(contexts).map(|(buffer, ctx)| { + let query_params = HybridQueryParams { + per_user_credit_cap: 8, + max_breakdown_key: 3, + with_dp: 0, + epsilon: 5.0, + plaintext_match_keys: false, + }; + let input = BodyStream::from(buffer); + + HybridQuery::<_, BA16, KeyRegistry>::new( + query_params, + Arc::clone(&key_registry), + ) + .execute(ctx, query_size, input) + })) + .await; + + assert_eq!( + results.reconstruct()[0..3] + .iter() + .map(U128Conversions::as_u128) + .collect::>(), + EXPECTED + ); + } + + // cannot test for Err directly because join3v calls unwrap. This should be sufficient. 
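+    // (`join3v` joins the three helper futures and unwraps each `Result`, so the
+    // `Err` from duplicate detection surfaces as a panic that `#[should_panic]`
+    // can match on.)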
+ #[tokio::test] + #[should_panic(expected = "DuplicateBytes(3)")] + async fn duplicate_encrypted_hybrid_reports() { + let all_records = build_records(); + let records = &all_records[..2].to_vec(); + + let BufferAndKeyRegistry { + mut buffers, + key_registry, + } = build_buffers_from_records(records); + + // this is double, since we duplicate the data below + let query_size = QuerySize::try_from(records.len() * 2).unwrap(); + + // duplicate all the data + for buffer in &mut buffers { + let original = buffer.clone(); + buffer.extend(original); + } + + let world = TestWorld::default(); + let contexts = world.contexts(); + #[allow(clippy::large_futures)] + let _results = join3v(buffers.into_iter().zip(contexts).map(|(buffer, ctx)| { + let query_params = HybridQueryParams { + per_user_credit_cap: 8, + max_breakdown_key: 3, + with_dp: 0, + epsilon: 5.0, + plaintext_match_keys: false, + }; + let input = BodyStream::from(buffer); + + HybridQuery::<_, BA16, KeyRegistry>::new( + query_params, + Arc::clone(&key_registry), + ) + .execute(ctx, query_size, input) + })) + .await; + } + + // cannot test for Err directly because join3v calls unwrap. This should be sufficient. + #[tokio::test] + #[should_panic( + expected = "Unsupported(\"Hybrid queries do not currently support plaintext match keys\")" + )] + async fn unsupported_plaintext_match_keys_hybrid_query() { + let all_records = build_records(); + let records = &all_records[..2].to_vec(); + let query_size = QuerySize::try_from(records.len()).unwrap(); + + let BufferAndKeyRegistry { + buffers, + key_registry, + } = build_buffers_from_records(records); + + let world = TestWorld::default(); + let contexts = world.contexts(); + #[allow(clippy::large_futures)] + let _results = join3v(buffers.into_iter().zip(contexts).map(|(buffer, ctx)| { + let query_params = HybridQueryParams { + per_user_credit_cap: 8, + max_breakdown_key: 3, + with_dp: 0, + epsilon: 5.0, + plaintext_match_keys: true, + }; + let input = BodyStream::from(buffer); + + HybridQuery::<_, BA16, KeyRegistry>::new( + query_params, + Arc::clone(&key_registry), + ) + .execute(ctx, query_size, input) + })) + .await; } } diff --git a/ipa-core/src/report/hybrid.rs b/ipa-core/src/report/hybrid.rs index c5e4703a1..d322f92ae 100644 --- a/ipa-core/src/report/hybrid.rs +++ b/ipa-core/src/report/hybrid.rs @@ -1,11 +1,14 @@ -use std::ops::{Add, Deref}; +use std::{collections::HashSet, ops::Add}; +use bytes::Bytes; use generic_array::ArrayLength; -use typenum::{Sum, U16}; +use rand_core::{CryptoRng, RngCore}; +use typenum::{Sum, Unsigned, U16}; use crate::{ + error::Error, ff::{boolean_array::BA64, Serializable}, - hpke::PrivateKeyRegistry, + hpke::{EncapsulationSize, PrivateKeyRegistry, PublicKeyRegistry, TagSize}, report::{EncryptedOprfReport, EventType, InvalidReportError, KeyIdentifier}, secret_sharing::{replicated::semi_honest::AdditiveShare as Replicated, SharedValue}, }; @@ -65,13 +68,35 @@ where BK: SharedValue, V: SharedValue, { + /// # Errors + /// If there is a problem encrypting the report. + pub fn encrypt( + &self, + _key_id: KeyIdentifier, + _key_registry: &impl PublicKeyRegistry, + _rng: &mut R, + ) -> Result, InvalidReportError> { + unimplemented!() + } +} + +#[derive(Clone)] +pub struct EncryptedHybridReport { + bytes: Bytes, +} + +impl EncryptedHybridReport { /// ## Errors - /// If the report contents are invalid. 
- pub fn from_bytes(data: B, key_registry: &P) -> Result + /// If the report fails to decrypt + pub fn decrypt( + &self, + key_registry: &P, + ) -> Result, InvalidReportError> where P: PrivateKeyRegistry, - B: Deref, - TS: SharedValue, // this is only needed for the backport from EncryptedOprfReport + BK: SharedValue, + V: SharedValue, + TS: SharedValue, Replicated: Serializable, Replicated: Serializable, Replicated: Serializable, @@ -90,19 +115,95 @@ where U16, >: ArrayLength, { - let encrypted_oprf_report = EncryptedOprfReport::::from_bytes(data)?; + let encrypted_oprf_report = + EncryptedOprfReport::::try_from(self.bytes.clone())?; let oprf_report = encrypted_oprf_report.decrypt(key_registry)?; match oprf_report.event_type { - EventType::Source => Ok(Self::Impression(HybridImpressionReport { + EventType::Source => Ok(HybridReport::Impression(HybridImpressionReport { match_key: oprf_report.match_key, breakdown_key: oprf_report.breakdown_key, })), - EventType::Trigger => Ok(Self::Conversion(HybridConversionReport { + EventType::Trigger => Ok(HybridReport::Conversion(HybridConversionReport { match_key: oprf_report.match_key, value: oprf_report.trigger_value, })), } } + + /// TODO: update these when we produce a proper encapsulation of + /// `EncryptedHybridReport`, rather than pigggybacking on `EncryptedOprfReport` + pub fn mk_ciphertext(&self) -> &[u8] { + let encap_key_mk_offset: usize = 0; + let ciphertext_mk_offset: usize = encap_key_mk_offset + EncapsulationSize::USIZE; + let encap_key_btt_offset: usize = + ciphertext_mk_offset + TagSize::USIZE + as Serializable>::Size::USIZE; + + &self.bytes[ciphertext_mk_offset..encap_key_btt_offset] + } +} + +impl TryFrom for EncryptedHybridReport { + type Error = InvalidReportError; + + fn try_from(bytes: Bytes) -> Result { + Ok(EncryptedHybridReport { bytes }) + } +} + +pub trait UniqueBytes { + fn unique_bytes(&self) -> Vec; +} + +impl UniqueBytes for EncryptedHybridReport { + /// We use the `TagSize` (the first 16 bytes of the ciphertext) for collision-detection + /// See [analysis here for uniqueness](https://eprint.iacr.org/2019/624) + fn unique_bytes(&self) -> Vec { + self.mk_ciphertext()[0..TagSize::USIZE].to_vec() + } +} + +#[derive(Debug)] +pub struct UniqueBytesValidator { + hash_set: HashSet>, + check_counter: usize, +} + +impl UniqueBytesValidator { + #[must_use] + pub fn new(size: usize) -> Self { + UniqueBytesValidator { + hash_set: HashSet::with_capacity(size), + check_counter: 0, + } + } + + fn insert(&mut self, value: Vec) -> bool { + self.hash_set.insert(value) + } + + /// Checks that item is unique among all checked thus far + /// + /// ## Errors + /// if the item inserted is not unique among all checked thus far + pub fn check_duplicate(&mut self, item: &U) -> Result<(), Error> { + self.check_counter += 1; + if self.insert(item.unique_bytes()) { + Ok(()) + } else { + Err(Error::DuplicateBytes(self.check_counter)) + } + } + + /// Checks that an iter of items is unique among the iter and any other items checked thus far + /// + /// ## Errors + /// if the and item inserted is not unique among all in this batch and checked previously + pub fn check_duplicates(&mut self, items: &[U]) -> Result<(), Error> { + items + .iter() + .try_for_each(|item| self.check_duplicate(item))?; + Ok(()) + } } #[cfg(test)] @@ -110,8 +211,12 @@ mod test { use rand::{distributions::Alphanumeric, rngs::ThreadRng, thread_rng, Rng}; - use super::{HybridConversionReport, HybridImpressionReport, HybridReport}; + use super::{ + EncryptedHybridReport, 
HybridConversionReport, HybridImpressionReport, HybridReport, + UniqueBytes, UniqueBytesValidator, + }; use crate::{ + error::Error, ff::boolean_array::{BA20, BA3, BA8}, hpke::{KeyPair, KeyRegistry}, report::{EventType, OprfReport}, @@ -134,6 +239,13 @@ mod test { } } + fn generate_random_bytes(size: usize) -> Vec { + let mut rng = thread_rng(); + let mut bytes = vec![0u8; size]; + rng.fill(&mut bytes[..]); + bytes + } + #[test] fn convert_to_hybrid_impression_report() { let mut rng = thread_rng(); @@ -152,11 +264,13 @@ mod test { let enc_report_bytes = oprf_report .encrypt(key_id, &key_registry, &mut rng) .unwrap(); - let hybrid_report2 = HybridReport::::from_bytes::<_, _, BA20>( - enc_report_bytes.as_slice(), - &key_registry, - ) - .unwrap(); + let enc_report = EncryptedHybridReport { + bytes: enc_report_bytes.into(), + }; + + let hybrid_report2 = enc_report + .decrypt::<_, BA8, BA3, BA20>(&key_registry) + .unwrap(); assert_eq!(hybrid_report, hybrid_report2); } @@ -179,12 +293,52 @@ mod test { let enc_report_bytes = oprf_report .encrypt(key_id, &key_registry, &mut rng) .unwrap(); - let hybrid_report2 = HybridReport::::from_bytes::<_, _, BA20>( - enc_report_bytes.as_slice(), - &key_registry, - ) - .unwrap(); + let enc_report = EncryptedHybridReport { + bytes: enc_report_bytes.into(), + }; + let hybrid_report2 = enc_report + .decrypt::<_, BA8, BA3, BA20>(&key_registry) + .unwrap(); assert_eq!(hybrid_report, hybrid_report2); } + + #[test] + fn unique_encrypted_hybrid_reports() { + #[derive(Clone)] + pub struct UniqueByteHolder { + bytes: Vec, + } + + impl UniqueByteHolder { + pub fn new(size: usize) -> Self { + let bytes = generate_random_bytes(size); + UniqueByteHolder { bytes } + } + } + + impl UniqueBytes for UniqueByteHolder { + fn unique_bytes(&self) -> Vec { + self.bytes.clone() + } + } + + let bytes1 = UniqueByteHolder::new(4); + let bytes2 = UniqueByteHolder::new(4); + let bytes3 = UniqueByteHolder::new(4); + let bytes4 = UniqueByteHolder::new(4); + + let mut unique_bytes = UniqueBytesValidator::new(4); + + unique_bytes.check_duplicate(&bytes1).unwrap(); + + unique_bytes + .check_duplicates(&[bytes2.clone(), bytes3.clone()]) + .unwrap(); + let expected_err = unique_bytes.check_duplicate(&bytes2); + assert!(matches!(expected_err, Err(Error::DuplicateBytes(4)))); + + let expected_err = unique_bytes.check_duplicates(&[bytes4, bytes3]); + assert!(matches!(expected_err, Err(Error::DuplicateBytes(6)))); + } } From 41b057c0c77bf2b2de0b875556f207c1d4a5f708 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Fri, 11 Oct 2024 13:05:23 -0700 Subject: [PATCH 127/191] Proof batching for breakdown reveal aggregation (#1323) * Unlimited batch size for reveal aggregation But avoid excessive memory allocations * Batching for breakdown reveal aggregation * Reduce TARGET_PROOF_SIZE for tests * Add a test for growing `pending_records` * More large batch tests * Compact gate fixes * Fixing step count blow-up * More test fixes * Fix a bug and adjust the test to catch it. 
* Keep semi-honest for shuttle, don't do shuttle for malicious * Optimize vec_chunks --------- Co-authored-by: Alex Koshelev --- ipa-core/src/protocol/basics/reveal.rs | 48 +++-- ipa-core/src/protocol/context/batcher.rs | 78 +++++++- .../src/protocol/context/dzkp_malicious.rs | 2 +- .../src/protocol/context/dzkp_validator.rs | 85 +++++++- ipa-core/src/protocol/dp/mod.rs | 2 +- .../ipa_prf/aggregation/breakdown_reveal.rs | 189 ++++++++++++++---- .../src/protocol/ipa_prf/aggregation/mod.rs | 42 ++-- .../src/protocol/ipa_prf/aggregation/step.rs | 13 +- ipa-core/src/protocol/ipa_prf/mod.rs | 5 +- .../prf_sharding/feature_label_dot_product.rs | 2 +- .../src/protocol/ipa_prf/prf_sharding/mod.rs | 23 +-- .../src/protocol/ipa_prf/prf_sharding/step.rs | 2 - ipa-core/src/query/runner/oprf_ipa.rs | 3 +- ipa-core/src/utils/mod.rs | 1 + ipa-core/src/utils/vec_chunks.rs | 58 ++++++ 15 files changed, 434 insertions(+), 119 deletions(-) create mode 100644 ipa-core/src/utils/vec_chunks.rs diff --git a/ipa-core/src/protocol/basics/reveal.rs b/ipa-core/src/protocol/basics/reveal.rs index 19363e1af..2344896a8 100644 --- a/ipa-core/src/protocol/basics/reveal.rs +++ b/ipa-core/src/protocol/basics/reveal.rs @@ -1,3 +1,27 @@ +// Several of the reveal impls use distinct type parameters for the value being revealed +// and the context-assiciated field. +// +// For MAC, this takes the form of distinct `V` and `CtxF` type parameters. For DZKP, +// this takes the form of a `V` type parameter different from the implicit `Boolean` +// used by the context. +// +// This decoupling is needed to support: +// +// 1. The PRF evaluation protocol, which uses `Fp25519` for the malicious context, but +// needs to reveal `RP25519` values. +// 2. The breakdown reveal aggregation protocol, which uses `Boolean` for the malicious +// context, but needs to reveal `BK` values. +// +// The malicious reveal protocol must check the shares being revealed for consistency, +// but doesn't care that they are in the same field as is used for the malicious +// context. Contrast with multiplication, which can only be supported in the malicious +// context's field. +// +// It also doesn't matter that `V` and `CtxF` support the same vectorization dimension +// `N`, but the compiler would not be able to infer the value of a decoupled +// vectorization dimension for `CtxF` from context, so it's easier to make them the same +// absent a need for them to be different. + use std::{ future::Future, iter::{repeat, zip}, @@ -8,7 +32,6 @@ use futures::{FutureExt, TryFutureExt}; use crate::{ error::Error, - ff::boolean::Boolean, helpers::{Direction, MaybeFuture, Role}, protocol::{ boolean::step::TwoHundredFiftySixBitOpStep, @@ -170,8 +193,6 @@ where } } -// Like the impl for `UpgradedMaliciousContext`, this impl uses distinct `V` and `CtxF` type -// parameters. See the comment on that impl for more details. impl<'a, B, V, CtxF, const N: usize> Reveal> for Replicated where @@ -194,12 +215,12 @@ where } } -impl<'a, B, const N: usize> Reveal> for Replicated +impl<'a, V, B, const N: usize> Reveal> for Replicated where B: ShardBinding, - Boolean: Vectorizable, + V: SharedValue + Vectorizable, { - type Output = >::Array; + type Output = >::Array; async fn generic_reveal<'fut>( &'fut self, @@ -270,15 +291,6 @@ where } } -// This impl uses distinct `V` and `CtxF` type parameters to support the PRF evaluation protocol, -// which uses `Fp25519` for the malicious context, but needs to reveal `RP25519` values. 
The -// malicious reveal protocol must check the shares being revealed for consistency, but doesn't care -// that they are in the same field as is used for the malicious context. Contrast with -// multiplication, which can only be supported in the malicious context's field. -// -// It also doesn't matter that `V` and `CtxF` support the same vectorization dimension `N`, but the -// compiler would not be able to infer the value of a decoupled vectorization dimension for `CtxF` -// from context, so it's easier to make them the same absent a need for them to be different. impl<'a, V, const N: usize, CtxF> Reveal> for Replicated where CtxF: ExtendableField, @@ -321,12 +333,12 @@ where } } -impl<'a, B, const N: usize> Reveal> for Replicated +impl<'a, V, B, const N: usize> Reveal> for Replicated where B: ShardBinding, - Boolean: Vectorizable, + V: SharedValue + Vectorizable, { - type Output = >::Array; + type Output = >::Array; async fn generic_reveal<'fut>( &'fut self, diff --git a/ipa-core/src/protocol/context/batcher.rs b/ipa-core/src/protocol/context/batcher.rs index cdbfddbce..974968bcf 100644 --- a/ipa-core/src/protocol/context/batcher.rs +++ b/ipa-core/src/protocol/context/batcher.rs @@ -3,7 +3,12 @@ use std::{cmp::min, collections::VecDeque, future::Future}; use bitvec::{bitvec, prelude::BitVec}; use tokio::sync::watch; -use crate::{error::Error, helpers::TotalRecords, protocol::RecordId, sync::Mutex}; +use crate::{ + error::Error, + helpers::TotalRecords, + protocol::{context::dzkp_validator::TARGET_PROOF_SIZE, RecordId}, + sync::Mutex, +}; /// Manages validation of batches of records for malicious protocols. /// @@ -111,13 +116,14 @@ impl<'a, B> Batcher<'a, B> { fn get_batch_by_offset(&mut self, batch_offset: usize) -> &mut BatchState { if self.batches.len() <= batch_offset { self.batches.reserve(batch_offset - self.batches.len() + 1); + let pending_records_capacity = self.records_per_batch.min(TARGET_PROOF_SIZE); while self.batches.len() <= batch_offset { let (validation_result, _) = watch::channel::(false); let state = BatchState { batch: (self.batch_constructor)(self.first_batch + self.batches.len()), validation_result, pending_count: 0, - pending_records: bitvec![0; self.records_per_batch], + pending_records: bitvec![0; pending_records_capacity], }; self.batches.push_back(Some(state)); } @@ -152,10 +158,16 @@ impl<'a, B> Batcher<'a, B> { let total_count = min(self.records_per_batch, remaining_records); let record_offset_in_batch = usize::from(record_id) - first_record_in_batch; let batch = self.get_batch_by_offset(batch_offset); - assert!( - !batch.pending_records[record_offset_in_batch], - "validate_record called twice for record {record_id}", - ); + if batch.pending_records.len() <= record_offset_in_batch { + batch + .pending_records + .resize(record_offset_in_batch + 1, false); + } else { + assert!( + !batch.pending_records[record_offset_in_batch], + "validate_record called twice for record {record_id}", + ); + } // This assertion is stricter than the bounds check in `BitVec::set` when the // batch size is not a multiple of 8, or for a partial final batch. 
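+            // (After the dynamic resize above, `BitVec::set` can no longer trip
+            // its own bounds check for this index, so the assertion below is the
+            // only effective range guard.)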
assert!( @@ -273,7 +285,10 @@ impl<'a, B> Batcher<'a, B> { mod tests { use std::{future::ready, pin::pin}; - use futures::future::{poll_immediate, try_join, try_join3, try_join4}; + use futures::{ + future::{join_all, poll_immediate, try_join, try_join3, try_join4}, + FutureExt, + }; use super::*; @@ -553,6 +568,55 @@ mod tests { )); } + #[tokio::test] + async fn large_batch() { + // This test exercises the case where the preallocated size of `pending_records` + // was limited to `TARGET_PROOF_SIZE`, and we need to grow it alter. + let batcher = Batcher::new( + TARGET_PROOF_SIZE + 1, + TotalRecords::specified(TARGET_PROOF_SIZE + 1).unwrap(), + Box::new(|_| Vec::new()), + ); + + let mut futs = (0..TARGET_PROOF_SIZE) + .map(|i| { + batcher + .lock() + .unwrap() + .get_batch(RecordId::from(i)) + .batch + .push(i); + batcher + .lock() + .unwrap() + .validate_record(RecordId::from(i), |_i, _b| async { unreachable!() }) + .map(Result::unwrap) + .boxed() + }) + .collect::>(); + + batcher + .lock() + .unwrap() + .get_batch(RecordId::from(TARGET_PROOF_SIZE)) + .batch + .push(TARGET_PROOF_SIZE); + futs.push( + batcher + .lock() + .unwrap() + .validate_record(RecordId::from(TARGET_PROOF_SIZE), |i, b| { + assert!(i == 0 && b.as_slice() == (0..=TARGET_PROOF_SIZE).collect::>()); + ready(Ok(())) + }) + .map(Result::unwrap) + .boxed(), + ); + join_all(futs).await; + + assert!(batcher.lock().unwrap().is_empty()); + } + #[test] fn into_single_batch() { let batcher = Batcher::new(2, TotalRecords::Unspecified, Box::new(|_| Vec::new())); diff --git a/ipa-core/src/protocol/context/dzkp_malicious.rs b/ipa-core/src/protocol/context/dzkp_malicious.rs index 596338e83..6511663d7 100644 --- a/ipa-core/src/protocol/context/dzkp_malicious.rs +++ b/ipa-core/src/protocol/context/dzkp_malicious.rs @@ -38,7 +38,7 @@ impl<'a, B: ShardBinding> DZKPUpgraded<'a, B> { base_ctx: MaliciousContext<'a, B>, ) -> Self { let records_per_batch = validator_inner.batcher.lock().unwrap().records_per_batch(); - let active_work = if records_per_batch == 1 { + let active_work = if records_per_batch == 1 || records_per_batch == usize::MAX { // If records_per_batch is 1, let active_work be anything. This only happens // in tests; there shouldn't be a risk of deadlocks with one record per // batch; and UnorderedReceiver capacity (which is set from active_work) diff --git a/ipa-core/src/protocol/context/dzkp_validator.rs b/ipa-core/src/protocol/context/dzkp_validator.rs index 1354f9447..70b963fd3 100644 --- a/ipa-core/src/protocol/context/dzkp_validator.rs +++ b/ipa-core/src/protocol/context/dzkp_validator.rs @@ -34,6 +34,22 @@ const BIT_ARRAY_LEN: usize = 256; const BIT_ARRAY_MASK: usize = BIT_ARRAY_LEN - 1; const BIT_ARRAY_SHIFT: usize = BIT_ARRAY_LEN.ilog2() as usize; +// The target size of a zero-knowledge proof, in GF(2) multiplies. Seven intermediate +// values are stored for each multiply, so the amount memory required is 7 times this +// value. +// +// To enable computing a read size for `OrdereringSender` that achieves good network +// utilization, the number of records in a proof must be a power of two. Protocols +// typically compute the size of a proof batch by dividing TARGET_PROOF_SIZE by +// an approximate number of multiplies per record, and then rounding up to a power +// of two. Thus, it is not necessary for TARGET_PROOF_SIZE to be a power of two. +// +// A smaller value is used for tests, to enable covering some corner cases with a +// reasonable runtime. 
Some of these tests use TARGET_PROOF_SIZE directly, so for tests +// it does need to be a power of two. +#[cfg(test)] +pub const TARGET_PROOF_SIZE: usize = 8192; +#[cfg(not(test))] pub const TARGET_PROOF_SIZE: usize = 50_000_000; /// `MultiplicationInputsBlock` is a block of fixed size of intermediate values @@ -257,7 +273,7 @@ impl MultiplicationInputsBatch { // records. let capacity_bits = usize::min( TARGET_PROOF_SIZE, - max_multiplications * multiplication_bit_size, + max_multiplications.saturating_mul(multiplication_bit_size), ); Self { first_record, @@ -295,7 +311,7 @@ impl MultiplicationInputsBatch { // panics when record_id is out of bounds assert!(record_id >= self.first_record); assert!( - record_id < RecordId::from(self.max_multiplications + usize::from(self.first_record)), + usize::from(record_id) < self.max_multiplications + usize::from(self.first_record), "record_id out of range in insert_segment. record {record_id} is beyond \ segment of length {} starting at {}", self.max_multiplications, @@ -326,9 +342,7 @@ impl MultiplicationInputsBatch { // panics when record_id is out of bounds assert!(record_id >= self.first_record); - assert!( - record_id < RecordId::from(self.max_multiplications + usize::from(self.first_record)) - ); + assert!(usize::from(record_id) < self.max_multiplications + usize::from(self.first_record)); // panics when record_id is less than first_record let id_within_batch = usize::from(record_id) - usize::from(self.first_record); @@ -377,9 +391,7 @@ impl MultiplicationInputsBatch { // panics when record_id is out of bounds assert!(record_id >= self.first_record); - assert!( - record_id < RecordId::from(self.max_multiplications + usize::from(self.first_record)) - ); + assert!(usize::from(record_id) < self.max_multiplications + usize::from(self.first_record)); let id_within_batch = usize::from(record_id) - usize::from(self.first_record); let block_id = (segment.len() * id_within_batch) >> BIT_ARRAY_SHIFT; @@ -866,7 +878,7 @@ mod tests { replicated::semi_honest::AdditiveShare as Replicated, IntoShares, SharedValue, Vectorizable, }, - seq_join::seq_join, + seq_join::{seq_join, SeqJoin}, sharding::NotSharded, test_fixture::{join3v, Reconstruct, Runner, TestWorld}, }; @@ -1254,6 +1266,61 @@ mod tests { } } + #[tokio::test] + async fn large_batch() { + multi_select_malicious::(2 * TARGET_PROOF_SIZE, 2 * TARGET_PROOF_SIZE).await; + } + + // Similar to multi_select_malicious, but instead of using `validated_seq_join`, passes + // `usize::MAX` as the batch size and does a single `v.validate()`. 
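+    // With `usize::MAX` as the batch size, every record lands in a single batch,
+    // so no proof is generated until the explicit `v.validate()` at the end.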
+ #[tokio::test] + async fn large_single_batch() { + let count: usize = TARGET_PROOF_SIZE + 1; + let mut rng = thread_rng(); + + let bit: Vec = repeat_with(|| rng.gen::()).take(count).collect(); + let a: Vec = repeat_with(|| rng.gen()).take(count).collect(); + let b: Vec = repeat_with(|| rng.gen()).take(count).collect(); + + let [ab0, ab1, ab2]: [Vec>; 3] = TestWorld::default() + .malicious( + zip(bit.clone(), zip(a.clone(), b.clone())), + |ctx, inputs| async move { + let v = ctx + .set_total_records(count) + .dzkp_validator(TEST_DZKP_STEPS, usize::MAX); + let m_ctx = v.context(); + + let result = seq_join( + m_ctx.active_work(), + stream::iter(inputs).enumerate().map( + |(i, (bit_share, (a_share, b_share)))| { + let m_ctx = m_ctx.clone(); + async move { + select(m_ctx, RecordId::from(i), &bit_share, &a_share, &b_share) + .await + } + }, + ), + ) + .try_collect() + .await + .unwrap(); + + v.validate().await.unwrap(); + + result + }, + ) + .await; + + let ab: Vec = [ab0, ab1, ab2].reconstruct(); + + for i in 0..count { + assert_eq!(ab[i], if bit[i].into() { a[i] } else { b[i] }); + } + } + #[tokio::test] #[should_panic(expected = "ContextUnsafe(\"DZKPMaliciousContext\")")] async fn missing_validate() { diff --git a/ipa-core/src/protocol/dp/mod.rs b/ipa-core/src/protocol/dp/mod.rs index bb847eabd..b2d8b55e1 100644 --- a/ipa-core/src/protocol/dp/mod.rs +++ b/ipa-core/src/protocol/dp/mod.rs @@ -171,7 +171,7 @@ where let aggregation_input = Box::pin(stream::iter(vector_input_to_agg.into_iter()).map(Ok)); // Step 3: Call `aggregate_values` to sum up Bernoulli noise. let noise_vector: Result>, Error> = - aggregate_values::<_, OV, B>(ctx, aggregation_input, num_bernoulli).await; + aggregate_values::<_, OV, B>(ctx, aggregation_input, num_bernoulli, None).await; noise_vector } /// `apply_dp_noise` takes the noise distribution parameters (`num_bernoulli` and in the future `quantization_scale`) diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs b/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs index a643d397d..6ff1b19c9 100644 --- a/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs +++ b/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs @@ -1,12 +1,9 @@ -use std::{ - convert::Infallible, - pin::{pin, Pin}, -}; +use std::{convert::Infallible, mem, pin::pin}; -use futures::{stream, Stream}; +use futures::stream; use futures_util::{StreamExt, TryStreamExt}; -use super::{aggregate_values, AggResult}; +use super::aggregate_values; use crate::{ error::{Error, UnwrapInfallible}, ff::{ @@ -16,10 +13,15 @@ use crate::{ }, helpers::TotalRecords, protocol::{ - basics::semi_honest_reveal, - context::Context, + basics::{reveal, Reveal}, + context::{ + dzkp_validator::DZKPValidator, Context, DZKPUpgraded, MaliciousProtocolSteps, + UpgradableContext, + }, ipa_prf::{ - aggregation::step::AggregationStep, + aggregation::{ + aggregate_values_proof_chunk, step::AggregationStep as Step, AGGREGATE_DEPTH, + }, oprf_padding::{apply_dp_padding, PaddingParameters}, prf_sharding::{AttributionOutputs, SecretSharedAttributionOutputs}, shuffle::shuffle_attribution_outputs, @@ -29,9 +31,10 @@ use crate::{ }, secret_sharing::{ replicated::semi_honest::AdditiveShare as Replicated, BitDecomposed, FieldSimd, - TransposeFrom, + TransposeFrom, Vectorizable, }, seq_join::seq_join, + utils::vec_chunks::vec_chunks, }; /// Improved Aggregation a.k.a Aggregation revealing breakdown. 
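The reduction loop below chunks intermediate results with the new `vec_chunks`
utility (added in ipa-core/src/utils/vec_chunks.rs; its body is not shown in this
diff). A minimal sketch of what such a helper could look like; hypothetical, and
possibly differing from the real implementation:

    /// Splits an owned `Vec` into owned chunks of at most `chunk_size` items.
    fn vec_chunks<T>(mut vec: Vec<T>, chunk_size: usize) -> impl Iterator<Item = Vec<T>> {
        assert!(chunk_size > 0);
        std::iter::from_fn(move || {
            if vec.is_empty() {
                None
            } else if vec.len() <= chunk_size {
                Some(std::mem::take(&mut vec))
            } else {
                // keep the tail in place, hand back the first `chunk_size` elements
                let tail = vec.split_off(chunk_size);
                Some(std::mem::replace(&mut vec, tail))
            }
        })
    }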
@@ -55,10 +58,11 @@ pub async fn breakdown_reveal_aggregation( padding_params: &PaddingParameters, ) -> Result>, Error> where - C: Context, + C: UpgradableContext, Boolean: FieldSimd, - Replicated: BooleanProtocols, + Replicated: BooleanProtocols, B>, BK: BreakdownKey, + Replicated: Reveal, Output = >::Array>, TV: BooleanArray + U128Conversions, HV: BooleanArray + U128Conversions, BitDecomposed>: @@ -67,17 +71,65 @@ where // Apply DP padding for Breakdown Reveal Aggregation let attributed_values_padded = apply_dp_padding::<_, AttributionOutputs, Replicated>, B>( - ctx.narrow(&AggregationStep::PaddingDp), + ctx.narrow(&Step::PaddingDp), attributed_values, padding_params, ) .await?; - let attributions = shuffle_attributions(&ctx, attributed_values_padded).await?; - let grouped_tvs = reveal_breakdowns(&ctx, attributions).await?; - let num_rows = grouped_tvs.max_len; - let ctx = ctx.narrow(&AggregationStep::SumContributions); - aggregate_values::<_, HV, B>(ctx, grouped_tvs.into_stream(), num_rows).await + let attributions = shuffle_attributions::<_, BK, TV, B>(&ctx, attributed_values_padded).await?; + // Revealing the breakdowns doesn't do any multiplies, so won't make it as far as + // doing a proof, but we need the validator to obtain an upgraded malicious context. + let validator = ctx.clone().dzkp_validator( + MaliciousProtocolSteps { + protocol: &Step::Reveal, + validate: &Step::RevealValidate, + }, + usize::MAX, + ); + let grouped_tvs = reveal_breakdowns(&validator.context(), attributions).await?; + validator.validate().await?; + let mut intermediate_results: Vec>> = grouped_tvs.into(); + + // Any real-world aggregation should be able to complete in two layers (two + // iterations of the `while` loop below). Tests with small `TARGET_PROOF_SIZE` + // may exceed that. + let mut chunk_counter = 0; + let mut depth = 0; + let agg_proof_chunk = aggregate_values_proof_chunk(B, usize::try_from(TV::BITS).unwrap()); + + while intermediate_results.len() > 1 { + let mut record_ids = [RecordId::FIRST; AGGREGATE_DEPTH]; + for chunk in vec_chunks(mem::take(&mut intermediate_results), agg_proof_chunk) { + let chunk_len = chunk.len(); + let validator = ctx.clone().dzkp_validator( + MaliciousProtocolSteps { + protocol: &Step::aggregate(depth), + validate: &Step::aggregate_validate(chunk_counter), + }, + // We have to specify usize::MAX here because the procession through + // record IDs is different at each step of the reduction. The batch + // size is limited by `vec_chunks`, above. + usize::MAX, + ); + let result = aggregate_values::<_, HV, B>( + validator.context(), + stream::iter(chunk).map(Ok).boxed(), + chunk_len, + Some(&mut record_ids), + ) + .await?; + validator.validate().await?; + chunk_counter += 1; + intermediate_results.push(result); + } + depth += 1; + } + + Ok(intermediate_results + .into_iter() + .next() + .expect("aggregation input must not be empty")) } /// Shuffles attribution Breakdown key and Trigger Value secret shares. 
Input @@ -94,7 +146,7 @@ where BK: BreakdownKey, TV: BooleanArray + U128Conversions, { - let shuffle_ctx = parent_ctx.narrow(&AggregationStep::Shuffle); + let shuffle_ctx = parent_ctx.narrow(&Step::Shuffle); shuffle_attribution_outputs::<_, BK, TV, BA64>(shuffle_ctx, contribs).await } @@ -116,25 +168,17 @@ where Replicated: BooleanProtocols, Boolean: FieldSimd, BK: BreakdownKey, + Replicated: Reveal>::Array>, TV: BooleanArray + U128Conversions, { - let reveal_ctx = parent_ctx - .narrow(&AggregationStep::RevealStep) - .set_total_records(TotalRecords::specified(attributions.len())?); + let reveal_ctx = parent_ctx.set_total_records(TotalRecords::specified(attributions.len())?); let reveal_work = stream::iter(attributions).enumerate().map(|(i, ao)| { let record_id = RecordId::from(i); let reveal_ctx = reveal_ctx.clone(); async move { - let revealed_bk = semi_honest_reveal( - reveal_ctx, - record_id, - None, - &ao.attributed_breakdown_key_bits, - ) - .await? - // Full reveal is used, meaning it is not possible to return None here - .unwrap(); + let revealed_bk = + reveal(reveal_ctx, record_id, &ao.attributed_breakdown_key_bits).await?; let revealed_bk = BK::from_array(&revealed_bk); let Ok(bk) = usize::try_from(revealed_bk.as_u128()) else { return Err(Error::Internal); @@ -173,22 +217,27 @@ impl GroupedTriggerValues { self.max_len = self.tvs[bk].len(); } } +} - fn into_stream<'fut>(mut self) -> Pin> + Send + 'fut>> - where - Boolean: FieldSimd, - BitDecomposed>: - for<'a> TransposeFrom<&'a [Replicated; B], Error = Infallible>, - { - let iter = (0..self.max_len).map(move |_| { - let slice: [Replicated; B] = self +impl From> + for Vec>> +where + Boolean: FieldSimd, + BitDecomposed>: + for<'a> TransposeFrom<&'a [Replicated; B], Error = Infallible>, +{ + fn from( + mut grouped_tvs: GroupedTriggerValues, + ) -> Vec>> { + let iter = (0..grouped_tvs.max_len).map(move |_| { + let slice: [Replicated; B] = grouped_tvs .tvs .each_mut() .map(|tv| tv.pop().unwrap_or(Replicated::ZERO)); - Ok(BitDecomposed::transposed_from(&slice).unwrap_infallible()) + BitDecomposed::transposed_from(&slice).unwrap_infallible() }); - Box::pin(stream::iter(iter)) + iter.collect() } } @@ -197,6 +246,8 @@ pub mod tests { use futures::TryFutureExt; use rand::{seq::SliceRandom, Rng}; + #[cfg(not(feature = "shuttle"))] + use crate::{ff::boolean_array::BA16, test_executor::run}; use crate::{ ff::{ boolean::Boolean, @@ -249,7 +300,7 @@ pub mod tests { } inputs.shuffle(&mut rng); let result: Vec<_> = world - .upgraded_semi_honest(inputs.into_iter(), |ctx, input_rows| async move { + .semi_honest(inputs.into_iter(), |ctx, input_rows| async move { let aos = input_rows .into_iter() .map(|ti| SecretSharedAttributionOutputs { @@ -277,4 +328,58 @@ pub mod tests { assert_eq!(result, expectation); }); } + + #[test] + #[cfg(not(feature = "shuttle"))] // too slow + fn malicious_happy_path() { + type HV = BA16; + run(|| async { + let world = TestWorld::default(); + let mut rng = rand::thread_rng(); + let mut expectation = Vec::new(); + for _ in 0..32 { + expectation.push(rng.gen_range(0u128..512)); + } + // The size of input needed here to get complete coverage (more precisely, + // the size of input to the final aggregation using `aggregate_values`) + // depends on `TARGET_PROOF_SIZE`. 
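+        // A sketch of that arithmetic under the test value of TARGET_PROOF_SIZE
+        // (8192): aggregate_values_proof_chunk(32, 3) =
+        // max(2, 8192 / 32 / (3 + 1)).next_power_of_two() = 64, so a breakdown
+        // needs more than 64 rows before a second aggregation layer is used.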
+ let expectation = expectation; // no more mutability for safety + let mut inputs = Vec::new(); + for (bk, expected_hv) in expectation.iter().enumerate() { + let mut remainder = *expected_hv; + while remainder > 7 { + let tv = rng.gen_range(0u128..8); + remainder -= tv; + inputs.push(input_row(bk, tv)); + } + inputs.push(input_row(bk, remainder)); + } + inputs.shuffle(&mut rng); + let result: Vec<_> = world + .malicious(inputs.into_iter(), |ctx, input_rows| async move { + let aos = input_rows + .into_iter() + .map(|ti| SecretSharedAttributionOutputs { + attributed_breakdown_key_bits: ti.0, + capped_attributed_trigger_value: ti.1, + }) + .collect(); + breakdown_reveal_aggregation::<_, BA5, BA3, HV, 32>( + ctx, + aos, + &PaddingParameters::relaxed(), + ) + .map_ok(|d: BitDecomposed>| { + Vec::transposed_from(&d).unwrap() + }) + .await + .unwrap() + }) + .await + .reconstruct(); + let result = result.iter().map(|v: &HV| v.as_u128()).collect::>(); + assert_eq!(32, result.len()); + assert_eq!(result, expectation); + }); + } } diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs b/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs index 41f91e243..a5ca281e6 100644 --- a/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs @@ -1,4 +1,4 @@ -use std::{any::type_name, iter, pin::Pin}; +use std::{any::type_name, cmp::max, iter, pin::Pin}; use futures::{Stream, StreamExt, TryStreamExt}; use tracing::Instrument; @@ -97,9 +97,13 @@ pub type AggResult = Result /// /// $\sum_{i = 1}^k 2^{k - i} (b + i - 1) \approx 2^k (b + 1) = N (b + 1)$ pub fn aggregate_values_proof_chunk(input_width: usize, input_item_bits: usize) -> usize { - TARGET_PROOF_SIZE / input_width / (input_item_bits + 1) + max(2, TARGET_PROOF_SIZE / input_width / (input_item_bits + 1)).next_power_of_two() } +// This is the step count for AggregateChunkStep. We need it to size RecordId arrays. +// This value must be at least the log of the aggregation chunk size. +pub const AGGREGATE_DEPTH: usize = 24; + /// Aggregate output contributions /// /// In the case of attribution, each item in `aggregated_stream` is a vector of values to be added @@ -121,6 +125,7 @@ pub async fn aggregate_values<'ctx, 'fut, C, OV, const B: usize>( ctx: C, mut aggregated_stream: Pin> + Send + 'fut>>, mut num_rows: usize, + record_ids: Option<&mut [RecordId; AGGREGATE_DEPTH]>, ) -> Result>, Error> where 'ctx: 'fut, @@ -138,24 +143,31 @@ where OV::BITS, ); + let mut record_id_store = None; + let record_ids = + record_ids.unwrap_or_else(|| record_id_store.insert([RecordId::FIRST; AGGREGATE_DEPTH])); + let mut depth = 0; while num_rows > 1 { // Indeterminate TotalRecords is currently required because aggregation does not poll // futures in parallel (thus cannot reach a batch of records). // // We reduce pairwise, passing through the odd record at the end if there is one, so the - // number of outputs (`next_num_rows`) gets rounded up. If calculating an explicit total - // records, that would get rounded down. + // number of outputs (`next_num_rows`) gets rounded up. The number of addition operations + // (number of used record IDs) gets rounded down. 
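+        // For example, reducing num_rows = 5 produces next_num_rows = 3 (the odd
+        // row passes through) but only 5 / 2 = 2 additions, so record_ids[depth]
+        // advances by two.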
let par_agg_ctx = ctx .narrow(&AggregateChunkStep::from(depth)) .set_total_records(TotalRecords::Indeterminate); let next_num_rows = (num_rows + 1) / 2; + let base_record_id = record_ids[depth]; + record_ids[depth] += num_rows / 2; aggregated_stream = Box::pin( FixedLength::new(aggregated_stream, num_rows) .try_chunks(2) .enumerate() .then(move |(i, chunk_res)| { let ctx = par_agg_ctx.clone(); + let record_id = base_record_id + i; async move { match chunk_res { Err(e) => { @@ -170,7 +182,6 @@ where assert_eq!(chunk_pair.len(), 2); let b = chunk_pair.pop().unwrap(); let a = chunk_pair.pop().unwrap(); - let record_id = RecordId::from(i); if a.len() < usize::try_from(OV::BITS).unwrap() { // If we have enough output bits, add and keep the carry. let (mut sum, carry) = integer_add::<_, AdditionStep, B>( @@ -198,7 +209,7 @@ where "reduce", depth = depth, rows = num_rows, - record = i + record = u32::from(record_id), )) }), ); @@ -268,7 +279,7 @@ pub mod tests { let result: BitDecomposed = TestWorld::default() .upgraded_semi_honest(inputs.into_iter(), |ctx, inputs| { let num_rows = inputs.len(); - aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows) + aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows, None) }) .await .map(Result::unwrap) @@ -291,7 +302,7 @@ pub mod tests { let result = TestWorld::default() .upgraded_semi_honest(inputs.into_iter(), |ctx, inputs| { let num_rows = inputs.len(); - aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows) + aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows, None) }) .await .map(Result::unwrap) @@ -317,7 +328,7 @@ pub mod tests { let result = TestWorld::default() .upgraded_semi_honest(inputs.into_iter(), |ctx, inputs| { let num_rows = inputs.len(); - aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows) + aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows, None) }) .await .map(Result::unwrap) @@ -345,7 +356,7 @@ pub mod tests { let result = TestWorld::default() .upgraded_semi_honest(inputs.into_iter(), |ctx, inputs| { let num_rows = inputs.len(); - aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows) + aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows, None) }) .await .map(Result::unwrap) @@ -372,7 +383,7 @@ pub mod tests { let result = TestWorld::default() .upgraded_semi_honest(inputs.into_iter(), |ctx, inputs| { let num_rows = inputs.len(); - aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows) + aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows, None) }) .await .map(Result::unwrap) @@ -391,7 +402,7 @@ pub mod tests { run(|| async move { let result = TestWorld::default() .upgraded_semi_honest((), |ctx, ()| { - aggregate_values::<_, BA8, 8>(ctx, stream::empty().boxed(), 0) + aggregate_values::<_, BA8, 8>(ctx, stream::empty().boxed(), 0, None) }) .await .map(Result::unwrap) @@ -412,7 +423,7 @@ pub mod tests { let result = TestWorld::default() .upgraded_semi_honest(inputs.into_iter(), |ctx, inputs| { let num_rows = inputs.len(); - aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows) + aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows, None) }) .await; @@ -432,7 +443,7 @@ pub mod tests { let _ = TestWorld::default() .upgraded_semi_honest(inputs.into_iter(), |ctx, inputs| { let num_rows = inputs.len() + 1; - aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows) + 
aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows, None) }) .await .map(Result::unwrap) @@ -454,7 +465,7 @@ pub mod tests { let _ = TestWorld::default() .upgraded_semi_honest(inputs.into_iter(), |ctx, inputs| { let num_rows = inputs.len() - 1; - aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows) + aggregate_values::<_, BA8, 8>(ctx, stream::iter(inputs).boxed(), num_rows, None) }) .await .map(Result::unwrap) @@ -540,6 +551,7 @@ pub mod tests { ctx, stream::iter(inputs).boxed(), num_rows, + None, ) }) .await diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/step.rs b/ipa-core/src/protocol/ipa_prf/aggregation/step.rs index 3c7b5da95..0995a8e54 100644 --- a/ipa-core/src/protocol/ipa_prf/aggregation/step.rs +++ b/ipa-core/src/protocol/ipa_prf/aggregation/step.rs @@ -9,13 +9,18 @@ pub(crate) enum AggregationStep { PaddingDp, #[step(child = crate::protocol::ipa_prf::shuffle::step::OPRFShuffleStep)] Shuffle, - RevealStep, - #[step(child = AggregateChunkStep)] - SumContributions, + Reveal, + #[step(child = crate::protocol::context::step::DzkpSingleBatchStep)] + RevealValidate, // only partly used -- see code + #[step(count = 4, child = AggregateChunkStep)] + Aggregate(usize), + #[step(count = 600, child = crate::protocol::context::step::DzkpSingleBatchStep)] + AggregateValidate(usize), } +// The step count here is duplicated as the AGGREGATE_DEPTH constant in the code. #[derive(CompactStep)] -#[step(count = 32, child = AggregateValuesStep, name = "depth")] +#[step(count = 24, child = AggregateValuesStep, name = "depth")] pub(crate) struct AggregateChunkStep(usize); #[derive(CompactStep)] diff --git a/ipa-core/src/protocol/ipa_prf/mod.rs b/ipa-core/src/protocol/ipa_prf/mod.rs index 2a34c030b..9626731e4 100644 --- a/ipa-core/src/protocol/ipa_prf/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/mod.rs @@ -243,7 +243,8 @@ where PrfSharing, PRF_CHUNK, Field = Fp25519> + FromPrss, Replicated: Reveal, Output = >::Array>, - Replicated: BooleanArrayMul>, + Replicated: BooleanArrayMul> + + Reveal, Output = >::Array>, Replicated: BooleanArrayMul>, Replicated: BooleanArrayMul>, BitDecomposed>: @@ -754,7 +755,7 @@ mod compact_gate_tests { fn step_count_limit() { // This is an arbitrary limit intended to catch changes that unintentionally // blow up the step count. It can be increased, within reason. 
- const STEP_COUNT_LIMIT: u32 = 24_000; + const STEP_COUNT_LIMIT: u32 = 35_000; assert!( ProtocolStep::STEP_COUNT < STEP_COUNT_LIMIT, "Step count of {actual} exceeds limit of {STEP_COUNT_LIMIT}.", diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/feature_label_dot_product.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/feature_label_dot_product.rs index 19d6caac0..e2ae77c96 100644 --- a/ipa-core/src/protocol/ipa_prf/prf_sharding/feature_label_dot_product.rs +++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/feature_label_dot_product.rs @@ -258,7 +258,7 @@ where seq_join(sh_ctx.active_work(), stream::iter(chunked_user_results)).try_flatten_iters(), ); let aggregated_result: BitDecomposed> = - aggregate_values::<_, HV, B>(binary_m_ctx, flattened_stream, num_outputs).await?; + aggregate_values::<_, HV, B>(binary_m_ctx, flattened_stream, num_outputs, None).await?; let transposed_aggregated_result: Vec> = Vec::transposed_from(&aggregated_result)?; diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs index 703826fd6..0490a7f87 100644 --- a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs @@ -22,7 +22,7 @@ use crate::{ }, helpers::{repeat_n, stream::TryFlattenItersExt, TotalRecords}, protocol::{ - basics::{select, BooleanArrayMul, BooleanProtocols, SecureMul, ShareKnownValue}, + basics::{select, BooleanArrayMul, BooleanProtocols, Reveal, SecureMul, ShareKnownValue}, boolean::{ or::or, step::{EightBitStep, ThirtyTwoBitStep}, @@ -33,7 +33,6 @@ use crate::{ Context, DZKPContext, DZKPUpgraded, MaliciousProtocolSteps, UpgradableContext, }, ipa_prf::{ - aggregation::aggregate_values_proof_chunk, boolean_ops::{ addition_sequential::integer_add, comparison_and_subtraction_sequential::{compare_gt, integer_sub}, @@ -51,7 +50,7 @@ use crate::{ }, secret_sharing::{ replicated::{semi_honest::AdditiveShare as Replicated, ReplicatedSecretSharing}, - BitDecomposed, FieldSimd, SharedValue, TransposeFrom, + BitDecomposed, FieldSimd, SharedValue, TransposeFrom, Vectorizable, }, }; @@ -482,7 +481,8 @@ where Replicated: BooleanProtocols>, Replicated: BooleanProtocols, B>, Replicated: BooleanProtocols, AGG_CHUNK>, - Replicated: BooleanArrayMul>, + Replicated: BooleanArrayMul> + + Reveal, Output = >::Array>, Replicated: BooleanArrayMul>, Replicated: BooleanArrayMul>, BitDecomposed>: @@ -538,22 +538,13 @@ where attribution_window_seconds, ); - let validator = sh_ctx.dzkp_validator( - MaliciousProtocolSteps { - protocol: &Step::Aggregate, - validate: &Step::AggregateValidate, - }, - aggregate_values_proof_chunk(B, usize::try_from(TV::BITS).unwrap()).next_power_of_two(), - ); let user_contributions = flattened_user_results.try_collect::>().await?; - let result = breakdown_reveal_aggregation::<_, _, _, HV, B>( - validator.context(), + breakdown_reveal_aggregation::<_, BK, TV, HV, B>( + sh_ctx.narrow(&Step::Aggregate), user_contributions, padding_parameters, ) - .await; - validator.validate().await?; - result + .await } #[tracing::instrument(name = "attribute_cap", skip_all, fields(unique_match_keys = input.len()))] diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/step.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/step.rs index 03255d342..710b0a7e3 100644 --- a/ipa-core/src/protocol/ipa_prf/prf_sharding/step.rs +++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/step.rs @@ -12,8 +12,6 @@ pub(crate) enum AttributionStep { AttributeValidate, #[step(child = 
crate::protocol::ipa_prf::aggregation::step::AggregationStep)] Aggregate, - #[step(child = crate::protocol::context::step::DzkpSingleBatchStep)] - AggregateValidate, } #[derive(CompactStep)] diff --git a/ipa-core/src/query/runner/oprf_ipa.rs b/ipa-core/src/query/runner/oprf_ipa.rs index fa8b787d8..20a112265 100644 --- a/ipa-core/src/query/runner/oprf_ipa.rs +++ b/ipa-core/src/query/runner/oprf_ipa.rs @@ -68,7 +68,8 @@ where PrfSharing, PRF_CHUNK, Field = Fp25519> + FromPrss, Replicated: Reveal, Output = >::Array>, - Replicated: BooleanArrayMul>, + Replicated: BooleanArrayMul> + + Reveal, Output = >::Array>, Replicated: BooleanArrayMul>, Replicated: BooleanArrayMul>, Vec>: diff --git a/ipa-core/src/utils/mod.rs b/ipa-core/src/utils/mod.rs index e8dfd95ae..5bbfd8c87 100644 --- a/ipa-core/src/utils/mod.rs +++ b/ipa-core/src/utils/mod.rs @@ -2,6 +2,7 @@ pub mod array; pub mod arraychunks; #[cfg(target_pointer_width = "64")] mod power_of_two; +pub mod vec_chunks; #[cfg(target_pointer_width = "64")] pub use power_of_two::NonZeroU32PowerOfTwo; diff --git a/ipa-core/src/utils/vec_chunks.rs b/ipa-core/src/utils/vec_chunks.rs new file mode 100644 index 000000000..9732bf184 --- /dev/null +++ b/ipa-core/src/utils/vec_chunks.rs @@ -0,0 +1,58 @@ +use std::cmp::min; + +pub struct VecChunks { + vec: Vec, + pos: usize, + chunk_size: usize, +} + +impl Iterator for VecChunks { + type Item = Vec; + + fn next(&mut self) -> Option { + let start = self.pos; + let len = min(self.vec.len() - start, self.chunk_size); + (len != 0).then(|| { + self.pos += len; + self.vec[start..start + len].to_vec() + }) + } +} + +pub fn vec_chunks(vec: Vec, chunk_size: usize) -> impl Iterator> { + assert!(chunk_size != 0); + VecChunks { + vec, + pos: 0, + chunk_size, + } +} + +#[cfg(all(test, unit_test))] +mod tests { + use super::vec_chunks; + use crate::ff::{Field, Fp61BitPrime}; + + #[test] + fn vec_chunk_iter() { + let elements = vec![Fp61BitPrime::ONE; 4]; + + let mut vec_chunk_iterator = vec_chunks(elements, 3); + + assert_eq!( + vec_chunk_iterator.next().unwrap(), + vec![Fp61BitPrime::ONE; 3] + ); + assert_eq!( + vec_chunk_iterator.next().unwrap(), + vec![Fp61BitPrime::ONE; 1] + ); + assert!(vec_chunk_iterator.next().is_none()); + } + + #[test] + fn vec_chunk_empty() { + let vec = Vec::::new(); + assert!(vec_chunks(vec, 1).next().is_none()); + } +} From 6283343c6bd7aa0890bd3d65b21ba0da23ade658 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Fri, 11 Oct 2024 12:04:19 -0700 Subject: [PATCH 128/191] Make it work for cap 1, 2 and 4. It is still gross that we use 3 bit trigger values for everything, but oh well... 
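Each supported cap now maps to a const-generic bit width for the per-user
saturating sum (the first of the two const parameters in the calls below),
chosen so that 2^bits >= cap. A sketch of the relation the match arms encode
(a hypothetical helper, for illustration only):

    fn sum_bits(cap: u32) -> Option<u32> {
        match cap {
            1 => Some(1),
            2 | 4 => Some(2),
            8 => Some(3),
            16 => Some(4),
            32 => Some(5),
            64 => Some(6),
            128 => Some(7),
            _ => None,
        }
    }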
--- ipa-core/src/query/runner/oprf_ipa.rs | 4 +++- ipa-core/tests/compact_gate.rs | 5 ----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/ipa-core/src/query/runner/oprf_ipa.rs b/ipa-core/src/query/runner/oprf_ipa.rs index fa8b787d8..27d2c7938 100644 --- a/ipa-core/src/query/runner/oprf_ipa.rs +++ b/ipa-core/src/query/runner/oprf_ipa.rs @@ -147,13 +147,15 @@ where #[cfg(not(feature = "relaxed-dp"))] let padding_params = PaddingParameters::default(); match config.per_user_credit_cap { + 1 => oprf_ipa::<_, BA8, BA3, HV, BA20, 1, 256>(ctx, input, aws, dp_params, padding_params).await, + 2 | 4 => oprf_ipa::<_, BA8, BA3, HV, BA20, 2, 256>(ctx, input, aws, dp_params, padding_params).await, 8 => oprf_ipa::<_, BA8, BA3, HV, BA20, 3, 256>(ctx, input, aws, dp_params, padding_params).await, 16 => oprf_ipa::<_, BA8, BA3, HV, BA20, 4, 256>(ctx, input, aws, dp_params, padding_params).await, 32 => oprf_ipa::<_, BA8, BA3, HV, BA20, 5, 256>(ctx, input, aws, dp_params, padding_params).await, 64 => oprf_ipa::<_, BA8, BA3, HV, BA20, 6, 256>(ctx, input, aws, dp_params, padding_params).await, 128 => oprf_ipa::<_, BA8, BA3, HV, BA20, 7, 256>(ctx, input, aws, dp_params, padding_params).await, _ => panic!( - "Invalid value specified for per-user cap: {:?}. Must be one of 8, 16, 32, 64, or 128.", + "Invalid value specified for per-user cap: {:?}. Must be one of 1, 2, 4, 8, 16, 32, 64, or 128.", config.per_user_credit_cap ), } diff --git a/ipa-core/tests/compact_gate.rs b/ipa-core/tests/compact_gate.rs index 7e31c626e..354ad438c 100644 --- a/ipa-core/tests/compact_gate.rs +++ b/ipa-core/tests/compact_gate.rs @@ -41,11 +41,6 @@ fn compact_gate_cap_2_no_window_semi_honest_encryped_input() { test_compact_gate(IpaSecurityModel::SemiHonest, 2, 0, true); } -#[test] -fn compact_gate_cap_3_no_window_semi_honest_encryped_input() { - test_compact_gate(IpaSecurityModel::SemiHonest, 3, 0, true); -} - #[test] fn compact_gate_cap_4_no_window_semi_honest_encryped_input() { test_compact_gate(IpaSecurityModel::SemiHonest, 4, 0, true); From b86bf90cf83a741233821d6426e7e13f4f065880 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Mon, 14 Oct 2024 13:19:08 -0700 Subject: [PATCH 129/191] Review follow-up from #1323 --- .../src/protocol/context/dzkp_malicious.rs | 4 ++ .../ipa_prf/aggregation/breakdown_reveal.rs | 13 +++-- ipa-core/src/utils/mod.rs | 1 - ipa-core/src/utils/vec_chunks.rs | 58 ------------------- 4 files changed, 11 insertions(+), 65 deletions(-) delete mode 100644 ipa-core/src/utils/vec_chunks.rs diff --git a/ipa-core/src/protocol/context/dzkp_malicious.rs b/ipa-core/src/protocol/context/dzkp_malicious.rs index 6511663d7..671dfa08d 100644 --- a/ipa-core/src/protocol/context/dzkp_malicious.rs +++ b/ipa-core/src/protocol/context/dzkp_malicious.rs @@ -43,6 +43,10 @@ impl<'a, B: ShardBinding> DZKPUpgraded<'a, B> { // in tests; there shouldn't be a risk of deadlocks with one record per // batch; and UnorderedReceiver capacity (which is set from active_work) // must be at least two. + // + // Also rely on the protocol to ensure an appropriate active_work if + // records_per_batch is `usize::MAX` (unlimited batch size). Allocating + // storage for `usize::MAX` active records won't work. base_ctx.active_work() } else { // Adjust active_work to match records_per_batch. 
If it is less, we will diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs b/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs index 6ff1b19c9..65ec97230 100644 --- a/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs +++ b/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs @@ -1,4 +1,4 @@ -use std::{convert::Infallible, mem, pin::pin}; +use std::{convert::Infallible, pin::pin}; use futures::stream; use futures_util::{StreamExt, TryStreamExt}; @@ -34,7 +34,6 @@ use crate::{ TransposeFrom, Vectorizable, }, seq_join::seq_join, - utils::vec_chunks::vec_chunks, }; /// Improved Aggregation a.k.a Aggregation revealing breakdown. @@ -100,7 +99,8 @@ where while intermediate_results.len() > 1 { let mut record_ids = [RecordId::FIRST; AGGREGATE_DEPTH]; - for chunk in vec_chunks(mem::take(&mut intermediate_results), agg_proof_chunk) { + let mut next_intermediate_results = Vec::new(); + for chunk in intermediate_results.chunks(agg_proof_chunk) { let chunk_len = chunk.len(); let validator = ctx.clone().dzkp_validator( MaliciousProtocolSteps { @@ -109,21 +109,22 @@ where }, // We have to specify usize::MAX here because the procession through // record IDs is different at each step of the reduction. The batch - // size is limited by `vec_chunks`, above. + // size is limited by `intermediate_results.chunks()`, above. usize::MAX, ); let result = aggregate_values::<_, HV, B>( validator.context(), - stream::iter(chunk).map(Ok).boxed(), + stream::iter(chunk).map(|v| Ok(v.clone())).boxed(), chunk_len, Some(&mut record_ids), ) .await?; validator.validate().await?; chunk_counter += 1; - intermediate_results.push(result); + next_intermediate_results.push(result); } depth += 1; + intermediate_results = next_intermediate_results; } Ok(intermediate_results diff --git a/ipa-core/src/utils/mod.rs b/ipa-core/src/utils/mod.rs index 5bbfd8c87..e8dfd95ae 100644 --- a/ipa-core/src/utils/mod.rs +++ b/ipa-core/src/utils/mod.rs @@ -2,7 +2,6 @@ pub mod array; pub mod arraychunks; #[cfg(target_pointer_width = "64")] mod power_of_two; -pub mod vec_chunks; #[cfg(target_pointer_width = "64")] pub use power_of_two::NonZeroU32PowerOfTwo; diff --git a/ipa-core/src/utils/vec_chunks.rs b/ipa-core/src/utils/vec_chunks.rs deleted file mode 100644 index 9732bf184..000000000 --- a/ipa-core/src/utils/vec_chunks.rs +++ /dev/null @@ -1,58 +0,0 @@ -use std::cmp::min; - -pub struct VecChunks { - vec: Vec, - pos: usize, - chunk_size: usize, -} - -impl Iterator for VecChunks { - type Item = Vec; - - fn next(&mut self) -> Option { - let start = self.pos; - let len = min(self.vec.len() - start, self.chunk_size); - (len != 0).then(|| { - self.pos += len; - self.vec[start..start + len].to_vec() - }) - } -} - -pub fn vec_chunks(vec: Vec, chunk_size: usize) -> impl Iterator> { - assert!(chunk_size != 0); - VecChunks { - vec, - pos: 0, - chunk_size, - } -} - -#[cfg(all(test, unit_test))] -mod tests { - use super::vec_chunks; - use crate::ff::{Field, Fp61BitPrime}; - - #[test] - fn vec_chunk_iter() { - let elements = vec![Fp61BitPrime::ONE; 4]; - - let mut vec_chunk_iterator = vec_chunks(elements, 3); - - assert_eq!( - vec_chunk_iterator.next().unwrap(), - vec![Fp61BitPrime::ONE; 3] - ); - assert_eq!( - vec_chunk_iterator.next().unwrap(), - vec![Fp61BitPrime::ONE; 1] - ); - assert!(vec_chunk_iterator.next().is_none()); - } - - #[test] - fn vec_chunk_empty() { - let vec = Vec::::new(); - assert!(vec_chunks(vec, 1).next().is_none()); - } -} From a6b44756b62ef444e4174c4cc2ddb79fe5f7e5ae Mon Sep 17 
00:00:00 2001
From: Christian Berkhoff
Date: Mon, 14 Oct 2024 12:14:32 -0700
Subject: [PATCH 130/191] TransportRestriction and Client

---
 ipa-core/src/helpers/mod.rs           |   4 +
 ipa-core/src/helpers/transport/mod.rs |  92 ++++++++++++++-
 ipa-core/src/net/client/mod.rs        | 161 ++++++++++++++------------
 ipa-core/src/net/mod.rs               |   5 +
 ipa-core/src/net/server/mod.rs        |  15 ++-
 ipa-core/src/net/test.rs              |  14 +--
 ipa-core/src/net/transport.rs         |   4 +-
 ipa-core/src/sharding.rs              |  48 +++++++-
 8 files changed, 246 insertions(+), 97 deletions(-)

diff --git a/ipa-core/src/helpers/mod.rs b/ipa-core/src/helpers/mod.rs
index 2c43ccd53..5f86f4305 100644
--- a/ipa-core/src/helpers/mod.rs
+++ b/ipa-core/src/helpers/mod.rs
@@ -181,6 +181,10 @@ impl HelperIdentity {
     pub const TWO: Self = Self { id: 2 };
     pub const THREE: Self = Self { id: 3 };
 
+    pub const ONE_STR: &'static str = "A";
+    pub const TWO_STR: &'static str = "B";
+    pub const THREE_STR: &'static str = "C";
+
     /// Given a helper identity, return an array of the identities of the other two helpers.
     // The order that helpers are returned here is not intended to be meaningful, however,
     // it is currently used directly to determine the assignment of roles in
diff --git a/ipa-core/src/helpers/transport/mod.rs b/ipa-core/src/helpers/transport/mod.rs
index c3bb307d8..de20a7d90 100644
--- a/ipa-core/src/helpers/transport/mod.rs
+++ b/ipa-core/src/helpers/transport/mod.rs
@@ -44,22 +44,47 @@ pub trait Identity:
     Copy + Clone + Debug + PartialEq + Eq + PartialOrd + Ord + Hash + Send + Sync + 'static
 {
     fn as_str(&self) -> Cow<'static, str>;
+
+    /// Parses the string representation of this identity.
+    ///
+    /// # Errors
+    /// If there were any problems parsing the identity.
+    fn from_str(s: &str) -> Result<Self, crate::error::Error>;
 }
 
 impl Identity for ShardIndex {
     fn as_str(&self) -> Cow<'static, str> {
         Cow::Owned(self.to_string())
     }
+
+    fn from_str(s: &str) -> Result<Self, crate::error::Error> {
+        s.parse::<u32>()
+            .map_err(|_e| {
+                crate::error::Error::InvalidId(format!("The string {s} is an invalid Shard Index"))
+            })
+            .map(ShardIndex::from)
+    }
 }
 
 impl Identity for HelperIdentity {
     fn as_str(&self) -> Cow<'static, str> {
         Cow::Borrowed(match *self {
-            Self::ONE => "A",
-            Self::TWO => "B",
-            Self::THREE => "C",
+            Self::ONE => Self::ONE_STR,
+            Self::TWO => Self::TWO_STR,
+            Self::THREE => Self::THREE_STR,
             _ => unreachable!(),
         })
     }
+
+    fn from_str(s: &str) -> Result<Self, crate::error::Error> {
+        match s {
+            Self::ONE_STR => Ok(Self::ONE),
+            Self::TWO_STR => Ok(Self::TWO),
+            Self::THREE_STR => Ok(Self::THREE),
+            _ => Err(crate::error::Error::InvalidId(format!(
+                "The string {s} is an invalid Helper Identity"
+            ))),
+        }
+    }
 }
 
 /// Role is an identifier of helper peer, only valid within a given query.
For every query, there @@ -68,6 +93,17 @@ impl Identity for Role { fn as_str(&self) -> Cow<'static, str> { Cow::Borrowed(Role::as_static_str(self)) } + + fn from_str(s: &str) -> Result { + match s { + Self::H1_STR => Ok(Self::H1), + Self::H2_STR => Ok(Self::H2), + Self::H3_STR => Ok(Self::H3), + _ => Err(crate::error::Error::InvalidId(format!( + "The string {s} is an invalid Role" + ))), + } + } } pub trait ResourceIdentifier: Sized {} @@ -229,3 +265,53 @@ pub trait Transport: Clone + Send + Sync + 'static { ::clone(self) } } + +#[cfg(all(test, unit_test))] +mod tests { + use crate::{ + helpers::{HelperIdentity, Role, TransportIdentity}, + sharding::ShardIndex, + }; + + #[test] + fn helper_from_str() { + assert_eq!(HelperIdentity::from_str("A").unwrap(), HelperIdentity::ONE); + assert_eq!(HelperIdentity::from_str("B").unwrap(), HelperIdentity::TWO); + assert_eq!( + HelperIdentity::from_str("C").unwrap(), + HelperIdentity::THREE + ); + } + + #[test] + #[should_panic(expected = "The string H1 is an invalid Helper Identity")] + fn invalid_helper_from_str() { + assert_eq!(HelperIdentity::from_str("H1").unwrap(), HelperIdentity::ONE); + } + + #[test] + fn shard_from_str() { + assert_eq!(ShardIndex::from_str("42").unwrap(), ShardIndex::from(42)); + assert_eq!(ShardIndex::from_str("9").unwrap(), ShardIndex::from(9)); + assert_eq!(ShardIndex::from_str("0").unwrap(), ShardIndex::from(0)); + } + + #[test] + #[should_panic(expected = "The string -1 is an invalid Shard Index")] + fn invalid_shard_from_str() { + assert_eq!(ShardIndex::from_str("-1").unwrap(), ShardIndex::from(0)); + } + + #[test] + fn role_from_str() { + assert_eq!(Role::from_str("H1").unwrap(), Role::H1); + assert_eq!(Role::from_str("H2").unwrap(), Role::H2); + assert_eq!(Role::from_str("H3").unwrap(), Role::H3); + } + + #[test] + #[should_panic(expected = "The string A is an invalid Role")] + fn invalid_role_from_str() { + assert_eq!(Role::from_str("A").unwrap(), Role::H1); + } +} diff --git a/ipa-core/src/net/client/mod.rs b/ipa-core/src/net/client/mod.rs index 37b5654f7..5b331520d 100644 --- a/ipa-core/src/net/client/mod.rs +++ b/ipa-core/src/net/client/mod.rs @@ -2,6 +2,7 @@ use std::{ collections::HashMap, future::Future, io::{self, BufRead}, + marker::PhantomData, pin::Pin, sync::Arc, task::{ready, Context, Poll}, @@ -32,18 +33,19 @@ use crate::{ executor::IpaRuntime, helpers::{ query::{PrepareQuery, QueryConfig, QueryInput}, - HelperIdentity, + TransportIdentity, }, - net::{http_serde, server::HTTP_CLIENT_ID_HEADER, Error, CRYPTO_PROVIDER}, + net::{http_serde, Error, CRYPTO_PROVIDER, HTTP_CLIENT_ID_HEADER}, protocol::{Gate, QueryId}, + sharding::{Ring, TransportRestriction}, }; -#[derive(Default)] -pub enum ClientIdentity { +#[derive(Default, Debug)] +pub enum ClientIdentity { /// Claim the specified helper identity without any additional authentication. /// /// This is only supported for HTTP clients. - Helper(HelperIdentity), + Header(R::Identity), /// Authenticate with an X.509 certificate or a certificate chain. /// @@ -55,7 +57,7 @@ pub enum ClientIdentity { None, } -impl ClientIdentity { +impl ClientIdentity { /// Authenticates clients with an X.509 certificate using the provided certificate and private /// key. Certificate must be in PEM format, private key encoding must be [`PKCS8`]. /// @@ -80,10 +82,10 @@ impl ClientIdentity { /// to own a private key, and we need to create 3 with the same config, we provide Clone /// capabilities via this method to `ClientIdentity`. 
#[must_use] - pub fn clone_with_key(&self) -> ClientIdentity { + pub fn clone_with_key(&self) -> ClientIdentity { match self { Self::Certificate((c, pk)) => Self::Certificate((c.clone(), pk.clone_key())), - Self::Helper(h) => Self::Helper(*h), + Self::Header(h) => Self::Header(*h), Self::None => Self::None, } } @@ -153,39 +155,15 @@ impl Future for ResponseFuture { /// TODO: It probably isn't necessary to always use `[MpcHelperClient; 3]`. Instead, a single /// client can be configured to talk to all three helpers. #[derive(Debug, Clone)] -pub struct MpcHelperClient { +pub struct MpcHelperClient { client: Client, Body>, scheme: uri::Scheme, authority: uri::Authority, auth_header: Option<(HeaderName, HeaderValue)>, + _restriction: PhantomData, } -impl MpcHelperClient { - /// Create a set of clients for the MPC helpers in the supplied helper network configuration. - /// - /// This function returns a set of three clients, which may be used to talk to each of the - /// helpers. - /// - /// `identity` configures whether and how the client will authenticate to the server. It is for - /// the helper making the calls, so the same one is used for all three of the clients. - /// Authentication is not required when calling the report collector APIs. - #[must_use] - #[allow(clippy::missing_panics_doc)] - pub fn from_conf( - runtime: &IpaRuntime, - conf: &NetworkConfig, - identity: &ClientIdentity, - ) -> [MpcHelperClient; 3] { - conf.peers().each_ref().map(|peer_conf| { - Self::new( - runtime.clone(), - &conf.client, - peer_conf.clone(), - identity.clone_with_key(), - ) - }) - } - +impl MpcHelperClient { /// Create a new client with the given configuration /// /// `identity`, if present, configures whether and how the client will authenticate to the server @@ -198,7 +176,8 @@ impl MpcHelperClient { runtime: IpaRuntime, client_config: &ClientConfig, peer_config: PeerConfig, - identity: ClientIdentity, + identity: ClientIdentity, + header_name: &'static HeaderName, ) -> Self { let (connector, auth_header) = if peer_config.url.scheme() == Some(&Scheme::HTTP) { // This connector works for both http and https. A regular HttpConnector would suffice, @@ -208,7 +187,10 @@ impl MpcHelperClient { error!("certificate identity ignored for HTTP client"); None } - ClientIdentity::Helper(id) => Some((HTTP_CLIENT_ID_HEADER.clone(), id.into())), + ClientIdentity::Header(id) => Some(( + header_name.clone(), + HeaderValue::from_str(id.as_str().as_ref()).unwrap(), + )), ClientIdentity::None => None, }; ( @@ -238,7 +220,7 @@ impl MpcHelperClient { ClientIdentity::Certificate((cert_chain, pk)) => builder .with_client_auth_cert(cert_chain, pk) .expect("Can setup client authentication with certificate"), - ClientIdentity::Helper(_) => { + ClientIdentity::Header(_) => { error!("header-passed identity ignored for HTTPS client"); builder.with_no_client_auth() } @@ -296,6 +278,7 @@ impl MpcHelperClient { scheme, authority, auth_header, + _restriction: PhantomData, } } @@ -356,6 +339,66 @@ impl MpcHelperClient { } } + /// Sends a batch of messages associated with a query's step to another helper. Messages are a + /// contiguous block of records. Also includes [`crate::protocol::RecordId`] information and + /// [`crate::helpers::network::ChannelId`]. 
+ /// # Errors + /// If the request has illegal arguments, or fails to deliver to helper + /// # Panics + /// If messages size > max u32 (unlikely) + pub fn step> + Send + 'static>( + &self, + query_id: QueryId, + gate: &Gate, + data: S, + ) -> Result { + let data = data.map(|v| Ok::(Bytes::from(v))); + let body = axum::body::Body::from_stream(data); + let req = http_serde::query::step::Request::new(query_id, gate.clone(), body); + let req = req.try_into_http_request(self.scheme.clone(), self.authority.clone())?; + Ok(self.request(req)) + } + + /// Used to communicate from one helper to another. Specifically, the helper that receives a + /// "create query" from an external party must communicate the intent to start a query to the + /// other helpers, which this prepare query does. + /// # Errors + /// If the request has illegal arguments, or fails to deliver to helper + pub async fn prepare_query(&self, data: PrepareQuery) -> Result<(), Error> { + let req = http_serde::query::prepare::Request::new(data); + let req = req.try_into_http_request(self.scheme.clone(), self.authority.clone())?; + let resp = self.request(req).await?; + Self::resp_ok(resp).await + } +} + +impl MpcHelperClient { + /// Create a set of clients for the MPC helpers in the supplied helper network configuration. + /// + /// This function returns a set of three clients, which may be used to talk to each of the + /// helpers. + /// + /// `identity` configures whether and how the client will authenticate to the server. It is for + /// the helper making the calls, so the same one is used for all three of the clients. + /// Authentication is not required when calling the report collector APIs. + #[must_use] + #[allow(clippy::missing_panics_doc)] + pub fn from_conf( + runtime: &IpaRuntime, + conf: &NetworkConfig, + identity: &ClientIdentity, + ) -> [Self; 3] { + conf.peers().each_ref().map(|peer_conf| { + Self::new( + runtime.clone(), + &conf.client, + peer_conf.clone(), + identity.clone_with_key(), + &HTTP_CLIENT_ID_HEADER, + ) + }) + } + /// Intended to be called externally, by the report collector. Informs the MPC ring that /// the external party wants to start a new query. /// # Errors @@ -374,18 +417,6 @@ impl MpcHelperClient { } } - /// Used to communicate from one helper to another. Specifically, the helper that receives a - /// "create query" from an external party must communicate the intent to start a query to the - /// other helpers, which this prepare query does. - /// # Errors - /// If the request has illegal arguments, or fails to deliver to helper - pub async fn prepare_query(&self, data: PrepareQuery) -> Result<(), Error> { - let req = http_serde::query::prepare::Request::new(data); - let req = req.try_into_http_request(self.scheme.clone(), self.authority.clone())?; - let resp = self.request(req).await?; - Self::resp_ok(resp).await - } - /// Intended to be called externally, e.g. by the report collector. After the report collector /// calls "create query", it must then send the data for the query to each of the clients. This /// query input contains the data intended for a helper. @@ -398,26 +429,6 @@ impl MpcHelperClient { Self::resp_ok(resp).await } - /// Sends a batch of messages associated with a query's step to another helper. Messages are a - /// contiguous block of records. Also includes [`crate::protocol::RecordId`] information and - /// [`crate::helpers::network::ChannelId`]. 
- /// # Errors - /// If the request has illegal arguments, or fails to deliver to helper - /// # Panics - /// If messages size > max u32 (unlikely) - pub fn step> + Send + 'static>( - &self, - query_id: QueryId, - gate: &Gate, - data: S, - ) -> Result { - let data = data.map(|v| Ok::(Bytes::from(v))); - let body = axum::body::Body::from_stream(data); - let req = http_serde::query::step::Request::new(query_id, gate.clone(), body); - let req = req.try_into_http_request(self.scheme.clone(), self.authority.clone())?; - Ok(self.request(req)) - } - /// Retrieve the status of a query. /// /// ## Errors @@ -485,8 +496,8 @@ pub(crate) mod tests { use crate::{ ff::{FieldType, Fp31}, helpers::{ - make_owned_handler, query::QueryType::TestMultiply, BytesStream, HelperResponse, - RequestHandler, RoleAssignment, Transport, MESSAGE_PAYLOAD_SIZE_BYTES, + make_owned_handler, query::QueryType::TestMultiply, BytesStream, HelperIdentity, + HelperResponse, RequestHandler, RoleAssignment, Transport, MESSAGE_PAYLOAD_SIZE_BYTES, }, net::test::TestServer, protocol::step::TestExecutionStep, @@ -509,10 +520,10 @@ pub(crate) mod tests { hpke_config: None, }; let client = MpcHelperClient::new( - IpaRuntime::current(), &ClientConfig::default(), peer_config, - ClientIdentity::None, + ClientIdentity::::None, + &HTTP_CLIENT_ID_HEADER, ); // The server's self-signed test cert is not in the system truststore, and we didn't supply @@ -680,7 +691,7 @@ pub(crate) mod tests { .await .unwrap(); - MpcHelperClient::resp_ok(resp).await.unwrap(); + MpcHelperClient::::resp_ok(resp).await.unwrap(); let mut stream = Arc::clone(&transport) .receive(HelperIdentity::ONE, (QueryId, expected_step.clone())) diff --git a/ipa-core/src/net/mod.rs b/ipa-core/src/net/mod.rs index cb1373c7c..332af7f49 100644 --- a/ipa-core/src/net/mod.rs +++ b/ipa-core/src/net/mod.rs @@ -3,6 +3,7 @@ use std::{ sync::Arc, }; +use axum::http::HeaderName; use once_cell::sync::Lazy; use rustls::crypto::CryptoProvider; use rustls_pki_types::CertificateDer; @@ -24,6 +25,10 @@ pub use transport::{HttpShardTransport, HttpTransport}; pub const APPLICATION_JSON: &str = "application/json"; pub const APPLICATION_OCTET_STREAM: &str = "application/octet-stream"; +pub static HTTP_CLIENT_ID_HEADER: HeaderName = + HeaderName::from_static("x-unverified-client-identity"); +pub static HTTP_SHARD_INDEX_HEADER: HeaderName = + HeaderName::from_static("x-unverified-shard-index"); /// This has the same meaning as const defined in h2 crate, but we don't import it directly. /// According to the [`spec`] it cannot exceed 2^31 - 1. 
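The two identity headers added to net/mod.rs above carry a `TransportIdentity`
in its string form, so a value round-trips through the `as_str`/`from_str`
pair added to the trait in this patch; a minimal sketch:

    use axum::http::HeaderValue;
    use crate::helpers::{HelperIdentity, TransportIdentity};

    // Serialize the identity into the unverified header value...
    let value = HeaderValue::from_str(HelperIdentity::TWO.as_str().as_ref()).unwrap();
    // ...and parse it back on the receiving side.
    assert_eq!(
        HelperIdentity::from_str(value.to_str().unwrap()).unwrap(),
        HelperIdentity::TWO,
    );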
diff --git a/ipa-core/src/net/server/mod.rs b/ipa-core/src/net/server/mod.rs index 5de703e0b..1f7d9b6ec 100644 --- a/ipa-core/src/net/server/mod.rs +++ b/ipa-core/src/net/server/mod.rs @@ -15,6 +15,7 @@ use ::tokio::{ net::TcpStream, }; use axum::{ + http::HeaderValue, response::{IntoResponse, Response}, routing::IntoMakeService, Router, @@ -43,7 +44,7 @@ use crate::{ config::{NetworkConfig, OwnedCertificate, OwnedPrivateKey, ServerConfig, TlsConfig}, error::BoxError, executor::{IpaJoinHandle, IpaRuntime}, - helpers::HelperIdentity, + helpers::{HelperIdentity, TransportIdentity}, net::{ parse_certificate_and_private_key_bytes, server::config::HttpServerConfig, Error, HttpTransport, CRYPTO_PROVIDER, @@ -443,6 +444,13 @@ impl SetClientIdentityFromHeader { fn new(inner: S) -> Self { Self { inner } } + + fn parse_client_id(header_value: &HeaderValue) -> Result { + let header_str = header_value.to_str()?; + HelperIdentity::from_str(header_str) + .map_err(|e| Error::InvalidHeader(Box::new(e))) + .map(ClientIdentity) + } } impl, Response = Response>> Service> @@ -459,10 +467,10 @@ impl, Response = Response>> Service> fn call(&mut self, mut req: Request) -> Self::Future { if let Some(header_value) = req.headers().get(&HTTP_CLIENT_ID_HEADER) { - let id_result = serde_json::from_slice(header_value.as_ref()) + let id_result = Self::parse_client_id(header_value) .map_err(|e| Error::InvalidHeader(format!("{HTTP_CLIENT_ID_HEADER}: {e}").into())); match id_result { - Ok(id) => req.extensions_mut().insert(ClientIdentity(id)), + Ok(id) => req.extensions_mut().insert(id), Err(err) => return ready(Ok(err.into_response())).right_future(), }; } @@ -722,6 +730,7 @@ mod e2e_tests { let expected = expected_req(addr.to_string()); let req = http_req(&expected, uri::Scheme::HTTP, addr.to_string()); let response = client.request(req).await.unwrap(); + println!("{}", response.status()); assert_eq!(response.status(), StatusCode::OK); assert_eq!( diff --git a/ipa-core/src/net/test.rs b/ipa-core/src/net/test.rs index 85795940b..d56dbf387 100644 --- a/ipa-core/src/net/test.rs +++ b/ipa-core/src/net/test.rs @@ -273,7 +273,7 @@ impl TestServerBuilder { pub async fn build(self) -> TestServer { let identity = if self.disable_https { - ClientIdentity::Helper(HelperIdentity::ONE) + ClientIdentity::Header(HelperIdentity::ONE) } else { get_test_identity(HelperIdentity::ONE) }; @@ -297,6 +297,7 @@ impl TestServerBuilder { &identity.clone_with_key(), ); let handler = self.handler.as_ref().map(HandlerBox::owning_ref); + let client = clients[0].clone(); let (transport, server) = HttpTransport::new( IpaRuntime::current(), HelperIdentity::ONE, @@ -308,17 +309,6 @@ impl TestServerBuilder { let (addr, handle) = server .start_on(&IpaRuntime::current(), Some(server_socket), self.metrics) .await; - // Get the config for HelperIdentity::ONE - let h1_peer_config = network_config.peers.into_iter().next().unwrap(); - // At some point it might be appropriate to return two clients here -- the first being - // another helper and the second being a report collector. For now we use the same client - // for both types of calls. 
-        let client = MpcHelperClient::new(
-            IpaRuntime::current(),
-            &network_config.client,
-            h1_peer_config,
-            identity,
-        );
         TestServer {
             addr,
             handle,
diff --git a/ipa-core/src/net/transport.rs b/ipa-core/src/net/transport.rs
index c053b0b8a..990527b1d 100644
--- a/ipa-core/src/net/transport.rs
+++ b/ipa-core/src/net/transport.rs
@@ -21,7 +21,7 @@ use crate::{
     },
     net::{client::MpcHelperClient, error::Error, MpcHelperServer},
     protocol::{Gate, QueryId},
-    sharding::ShardIndex,
+    sharding::{Ring, ShardIndex},
     sync::Arc,
 };
@@ -208,7 +208,7 @@ impl Transport for Arc<HttpTransport> {
             .spawn(
                 resp_future
                     .map_err(Into::into)
-                    .and_then(MpcHelperClient::resp_ok),
+                    .and_then(MpcHelperClient::<Ring>::resp_ok),
             )
             .await?;
         Ok(())
diff --git a/ipa-core/src/sharding.rs b/ipa-core/src/sharding.rs
index e4f9475b7..671a1675a 100644
--- a/ipa-core/src/sharding.rs
+++ b/ipa-core/src/sharding.rs
@@ -3,9 +3,53 @@ use std::{
     num::TryFromIntError,
 };
 
+use serde::{Deserialize, Serialize};
+
+use crate::helpers::{HelperIdentity, TransportIdentity};
+
+/// This simple trait makes a type aware of which transport dimension it is running on. Structs like
+/// [`crate::net::client::MpcHelperClient`] use it to know whether they are talking to other
+/// servers as Shards inside a Helper or as a Helper talking to another Helper in a Ring. This
+/// trait can be used to limit the functions exposed by a struct impl depending on the context in
+/// which it's being used. Continuing the previous example, the functions a
+/// [`crate::net::client::MpcHelperClient`] provides are dependent on whether it's communicating
+/// with another Shard or another Helper.
+///
+/// This trait is a safety restriction so that structs or traits only expose an API that's
+/// meaningful for their specific context. When used as a generic bound, it also spreads through
+/// the types, making it harder to misuse or to combine incompatible types, e.g. using a
+/// [`ShardIndex`] with a [`Ring`].
+pub trait TransportRestriction: Debug + Send + Sync + Clone + 'static {
+    /// The meaningful identity used in this transport dimension.
+    type Identity: TransportIdentity;
+}
+
+/// This marker is used to restrict communication inside a single Helper, with other shards.
+#[derive(Debug, Copy, Clone)]
+pub struct Sharding;
+
+/// This marker is used to restrict communication between Helpers. This communication usually has
+/// more restrictions. Three hosts with the same sharding index are connected in a Ring.
+#[derive(Debug, Copy, Clone)]
+pub struct Ring;
+
+impl TransportRestriction for Sharding {
+    type Identity = ShardIndex;
+}
+impl TransportRestriction for Ring {
+    type Identity = HelperIdentity;
+}
+
 /// A unique zero-based index of the helper shard.
-#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct ShardIndex(u32); +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)] +#[serde(from = "u32")] +pub struct ShardIndex(pub u32); + +impl From for u32 { + fn from(value: ShardIndex) -> Self { + value.0 + } +} #[derive(Debug, Copy, Clone)] pub struct Sharded { From fdd92791f78ace59a5326768a6d5326fb6e7b983 Mon Sep 17 00:00:00 2001 From: Christian Berkhoff Date: Mon, 14 Oct 2024 17:24:55 -0700 Subject: [PATCH 131/191] Fix: ClientIdentity::Header --- ipa-core/src/bin/helper.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ipa-core/src/bin/helper.rs b/ipa-core/src/bin/helper.rs index 41fcb88ef..1b2a50398 100644 --- a/ipa-core/src/bin/helper.rs +++ b/ipa-core/src/bin/helper.rs @@ -127,7 +127,7 @@ async fn server(args: ServerArgs) -> Result<(), BoxError> { }), ) } - (None, None) => (ClientIdentity::Helper(my_identity), None), + (None, None) => (ClientIdentity::Header(my_identity), None), _ => panic!("should have been rejected by clap"), }; From 2fee2b91f28d3a92a7844f6787254855a701a2b5 Mon Sep 17 00:00:00 2001 From: Christian Berkhoff Date: Mon, 14 Oct 2024 18:02:08 -0700 Subject: [PATCH 132/191] Fix: nits --- ipa-core/src/net/client/mod.rs | 3 ++- ipa-core/src/net/transport.rs | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/ipa-core/src/net/client/mod.rs b/ipa-core/src/net/client/mod.rs index 5b331520d..62cb04b75 100644 --- a/ipa-core/src/net/client/mod.rs +++ b/ipa-core/src/net/client/mod.rs @@ -520,6 +520,7 @@ pub(crate) mod tests { hpke_config: None, }; let client = MpcHelperClient::new( + IpaRuntime::current(), &ClientConfig::default(), peer_config, ClientIdentity::::None, @@ -691,7 +692,7 @@ pub(crate) mod tests { .await .unwrap(); - MpcHelperClient::::resp_ok(resp).await.unwrap(); + MpcHelperClient::::resp_ok(resp).await.unwrap(); let mut stream = Arc::clone(&transport) .receive(HelperIdentity::ONE, (QueryId, expected_step.clone())) diff --git a/ipa-core/src/net/transport.rs b/ipa-core/src/net/transport.rs index 990527b1d..9a7732a33 100644 --- a/ipa-core/src/net/transport.rs +++ b/ipa-core/src/net/transport.rs @@ -388,7 +388,7 @@ mod tests { zip(HelperIdentity::make_three(), zip(sockets, server_config)).map( |(id, (socket, server_config))| async move { let identity = if disable_https { - ClientIdentity::Helper(id) + ClientIdentity::Header(id) } else { get_test_identity(id) }; From dc4d96433425267f03c454737b36fc4618208b77 Mon Sep 17 00:00:00 2001 From: Thomas James Yurek Date: Tue, 15 Oct 2024 16:50:18 -0700 Subject: [PATCH 133/191] Hybrid Impression Stuff --- ipa-core/src/hpke/mod.rs | 55 +++++ ipa-core/src/report/hybrid.rs | 356 ++++++++++++++++++++++++++--- ipa-core/src/report/hybrid_info.rs | 67 ++++++ ipa-core/src/report/mod.rs | 1 + 4 files changed, 451 insertions(+), 28 deletions(-) create mode 100644 ipa-core/src/report/hybrid_info.rs diff --git a/ipa-core/src/hpke/mod.rs b/ipa-core/src/hpke/mod.rs index 19b21bf67..a1810aaa4 100644 --- a/ipa-core/src/hpke/mod.rs +++ b/ipa-core/src/hpke/mod.rs @@ -126,6 +126,36 @@ pub fn open_in_place<'a, R: PrivateKeyRegistry>( Ok(pt) } +/// Version of `open_in_place` that doesn't require Info struct. 
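+/// The info bytes are passed pre-serialized (e.g. the output of
+/// `HybridImpressionInfo::to_bytes`), so one decryption path can serve every
+/// hybrid report type.
+///
+/// ## Errors
+/// If the ciphertext cannot be opened or `key_id` does not identify a key.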
+pub fn hybrid_open_in_place<'a, R: PrivateKeyRegistry>(
+    key_registry: &R,
+    enc: &[u8],
+    ciphertext: &'a mut [u8],
+    key_id: u8,
+    info: &[u8],
+) -> Result<&'a [u8], CryptError> {
+    let encap_key = <IpaKem as hpke::Kem>::EncappedKey::from_bytes(enc)?;
+    let (ct, tag) = ciphertext.split_at_mut(ciphertext.len() - AeadTag::<IpaAead>::size());
+    let tag = AeadTag::<IpaAead>::from_bytes(tag)?;
+    let sk = key_registry
+        .private_key(key_id)
+        .ok_or(CryptError::NoSuchKey(key_id))?;
+
+    single_shot_open_in_place_detached::<_, IpaKdf, IpaKem>(
+        &OpModeR::Base,
+        sk,
+        &encap_key,
+        info,
+        ct,
+        &[],
+        &tag,
+    )?;
+
+    // at this point ct is no longer a pointer to the ciphertext.
+    let pt = ct;
+    Ok(pt)
+}
+
 // Avoids a clippy "complex type" warning on the return type from `seal_in_place`.
 // Not intended to be widely used.
 pub(crate) type Ciphertext<'a> = (
@@ -161,6 +191,31 @@ pub(crate) fn seal_in_place<'a, R: CryptoRng + RngCore, K: PublicKeyRegistry>(
     Ok((encap_key, plaintext, tag))
 }
 
+/// Version of `seal_in_place` that takes the HPKE info as pre-serialized bytes
+/// rather than an `Info` struct.
+pub(crate) fn hybrid_seal_in_place<'a, R: CryptoRng + RngCore, K: PublicKeyRegistry>(
+    key_registry: &K,
+    plaintext: &'a mut [u8],
+    key_id: u8,
+    info: &[u8],
+    rng: &mut R,
+) -> Result<Ciphertext<'a>, CryptError> {
+    let pk_r = key_registry
+        .public_key(key_id)
+        .ok_or(CryptError::NoSuchKey(key_id))?;
+
+    let (encap_key, tag) = single_shot_seal_in_place_detached::<IpaAead, IpaKdf, IpaKem, _>(
+        &OpModeS::Base,
+        pk_r,
+        info,
+        plaintext,
+        &[],
+        rng,
+    )?;
+
+    // at this point `plaintext` is no longer a pointer to the plaintext.
+    Ok((encap_key, plaintext, tag))
+}
+
 #[cfg(all(test, unit_test))]
 mod tests {
     use generic_array::GenericArray;
diff --git a/ipa-core/src/report/hybrid.rs b/ipa-core/src/report/hybrid.rs
index d322f92ae..ba7d799f4 100644
--- a/ipa-core/src/report/hybrid.rs
+++ b/ipa-core/src/report/hybrid.rs
@@ -1,18 +1,71 @@
-use std::{collections::HashSet, ops::Add};
+use std::{collections::HashSet, ops::{Add, Deref}};
+use std::marker::PhantomData;
+use std::fmt::{Display, Formatter};
 
-use bytes::Bytes;
-use generic_array::ArrayLength;
+use bytes::{BufMut, Bytes};
+use generic_array::{ArrayLength, GenericArray};
+use hpke::Serializable as _;
 use rand_core::{CryptoRng, RngCore};
 use typenum::{Sum, Unsigned, U16};
 
 use crate::{
-    error::Error,
+    error::{BoxError, Error},
     ff::{boolean_array::BA64, Serializable},
-    hpke::{EncapsulationSize, PrivateKeyRegistry, PublicKeyRegistry, TagSize},
+    hpke::{
+        hybrid_open_in_place, hybrid_seal_in_place, CryptError, EncapsulationSize,
+        PrivateKeyRegistry, PublicKeyRegistry, TagSize,
+    },
     report::{EncryptedOprfReport, EventType, InvalidReportError, KeyIdentifier},
+    report::hybrid_info::HybridImpressionInfo,
     secret_sharing::{replicated::semi_honest::AdditiveShare as Replicated, SharedValue},
 };
 
+// TODO(679): This needs to come from configuration.
+static HELPER_ORIGIN: &str = "github.com/private-attribution"; + +#[derive(Debug)] +pub struct NonAsciiStringError { + input: String, +} + +impl Display for NonAsciiStringError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "string contains non-ascii symbols: {}", self.input) + } +} + +impl std::error::Error for NonAsciiStringError {} + +impl From<&'_ [u8]> for NonAsciiStringError { + fn from(input: &[u8]) -> Self { + Self { + input: String::from_utf8( + input + .iter() + .copied() + .flat_map(std::ascii::escape_default) + .collect::>(), + ) + .unwrap(), + } + } +} + +impl From<&'_ str> for NonAsciiStringError { + fn from(input: &str) -> Self { + Self::from(input.as_bytes()) + } +} + +#[derive(Debug, thiserror::Error)] +pub enum InvalidHybridReportError { + #[error("bad site_domain: {0}")] + NonAsciiString(#[from] NonAsciiStringError), + #[error("en/decryption failure: {0}")] + Crypt(#[from] CryptError), + #[error("failed to deserialize field {0}: {1}")] + DeserializationError(&'static str, #[source] BoxError), + #[error("report is too short: {0}, expected length at least: {1}")] + Length(usize, usize), +} + #[derive(Clone, Debug, Eq, PartialEq)] pub struct HybridImpressionReport where @@ -22,6 +75,114 @@ where breakdown_key: Replicated, } +impl Serializable for HybridImpressionReport +where + BK: SharedValue, + Replicated: Serializable, + as Serializable>::Size: Add, + < as Serializable>::Size as Add< as Serializable>::Size>>:: Output: ArrayLength, +{ + type Size = < as Serializable>::Size as Add< as Serializable>::Size>>:: Output; + type DeserializationError = InvalidHybridReportError; // as Serializable>::DeserializationError; + + fn serialize(&self, buf: &mut GenericArray) { + let mk_sz = as Serializable>::Size::USIZE; + let bk_sz = as Serializable>::Size::USIZE; + + self.match_key + .serialize(GenericArray::from_mut_slice(&mut buf[..mk_sz])); + + self.breakdown_key + .serialize(GenericArray::from_mut_slice(&mut buf[mk_sz..mk_sz + bk_sz])); + } + fn deserialize(buf: &GenericArray) -> Result { + let mk_sz = as Serializable>::Size::USIZE; + let bk_sz = as Serializable>::Size::USIZE; + let match_key = + Replicated::::deserialize(GenericArray::from_slice(&buf[..mk_sz])) + .map_err(|e| InvalidHybridReportError::DeserializationError("match_key", e.into()))?; + let breakdown_key = + Replicated::::deserialize(GenericArray::from_slice(&buf[mk_sz..mk_sz + bk_sz])) + .map_err(|e| InvalidHybridReportError::DeserializationError("breakdown_key", e.into()))?; + Ok(Self { match_key, breakdown_key }) + } +} + +impl HybridImpressionReport +where + BK: SharedValue, + Replicated: Serializable, + as Serializable>::Size: Add, + < as Serializable>::Size as Add< as Serializable>::Size>>:: Output: ArrayLength, +{ + const BTT_END: usize = as Serializable>::Size::USIZE; + + /// # Panics + /// If report length does not fit in `u16`. + pub fn encrypted_len(&self) -> u16 { + let len = EncryptedHybridImpressionReport::::SITE_DOMAIN_OFFSET; + len.try_into().unwrap() + } + /// # Errors + /// If there is a problem encrypting the report. + pub fn encrypt( + &self, + key_id: KeyIdentifier, + key_registry: &impl PublicKeyRegistry, + rng: &mut R, + ) -> Result, InvalidHybridReportError> { + let mut out = Vec::with_capacity(usize::from(self.encrypted_len())); + self.encrypt_to(key_id, key_registry, rng, &mut out)?; + debug_assert_eq!(out.len(), usize::from(self.encrypted_len())); + Ok(out) + } + + /// # Errors + /// If there is a problem encrypting the report. 
+ pub fn encrypt_to( + &self, + key_id: KeyIdentifier, + key_registry: &impl PublicKeyRegistry, + rng: &mut R, + out: &mut B, + ) -> Result<(), InvalidHybridReportError> { + let info = HybridImpressionInfo::new(key_id, HELPER_ORIGIN)?; + + let mut plaintext_mk = GenericArray::default(); + self.match_key.serialize(&mut plaintext_mk); + + let mut plaintext_btt = vec![0u8; Self::BTT_END]; + self.breakdown_key + .serialize(GenericArray::from_mut_slice(&mut plaintext_btt[..])); + + let (encap_key_mk, ciphertext_mk, tag_mk) = hybrid_seal_in_place( + key_registry, + plaintext_mk.as_mut(), + key_id, + &info.to_bytes(), + rng, + )?; + + let (encap_key_btt, ciphertext_btt, tag_btt) = hybrid_seal_in_place( + key_registry, + plaintext_btt.as_mut(), + key_id, + &info.to_bytes(), + rng, + )?; + + out.put_slice(&encap_key_mk.to_bytes()); + out.put_slice(ciphertext_mk); + out.put_slice(&tag_mk.to_bytes()); + out.put_slice(&encap_key_btt.to_bytes()); + out.put_slice(ciphertext_btt); + out.put_slice(&tag_btt.to_bytes()); + out.put_slice(&[key_id]); + + Ok(()) + } +} + #[derive(Clone, Debug, Eq, PartialEq)] pub struct HybridConversionReport where @@ -41,28 +202,6 @@ where Conversion(HybridConversionReport), } -#[allow(dead_code)] -pub struct HybridImpressionInfo<'a> { - pub key_id: KeyIdentifier, - pub helper_origin: &'a str, -} - -#[allow(dead_code)] -pub struct HybridConversionInfo<'a> { - pub key_id: KeyIdentifier, - pub helper_origin: &'a str, - pub converion_site_domain: &'a str, - pub timestamp: u64, - pub epsilon: f64, - pub sensitivity: f64, -} - -#[allow(dead_code)] -pub enum HybridInfo<'a> { - Impression(HybridImpressionInfo<'a>), - Conversion(HybridConversionInfo<'a>), -} - impl HybridReport where BK: SharedValue, @@ -80,6 +219,117 @@ where } } +#[derive(Copy, Clone, Eq, PartialEq)] +pub struct EncryptedHybridImpressionReport +where + B: Deref, + BK: SharedValue, +{ + data: B, + phantom_data: PhantomData, +} + +impl EncryptedHybridImpressionReport +where + B: Deref, + BK: SharedValue, + Replicated: Serializable, + as Serializable>::Size: Add, + < as Serializable>::Size as Add>::Output: ArrayLength, +{ + const ENCAP_KEY_MK_OFFSET: usize = 0; + const CIPHERTEXT_MK_OFFSET: usize = Self::ENCAP_KEY_MK_OFFSET + EncapsulationSize::USIZE; + const ENCAP_KEY_BTT_OFFSET: usize = (Self::CIPHERTEXT_MK_OFFSET + + TagSize::USIZE + + as Serializable>::Size::USIZE); + const CIPHERTEXT_BTT_OFFSET: usize = Self::ENCAP_KEY_BTT_OFFSET + EncapsulationSize::USIZE; + + const KEY_IDENTIFIER_OFFSET: usize = (Self::CIPHERTEXT_BTT_OFFSET + + TagSize::USIZE + + as Serializable>::Size::USIZE); + const SITE_DOMAIN_OFFSET: usize = Self::KEY_IDENTIFIER_OFFSET + 1; + + pub fn encap_key_mk(&self) -> &[u8] { + &self.data[Self::ENCAP_KEY_MK_OFFSET..Self::CIPHERTEXT_MK_OFFSET] + } + + pub fn mk_ciphertext(&self) -> &[u8] { + &self.data[Self::CIPHERTEXT_MK_OFFSET..Self::ENCAP_KEY_BTT_OFFSET] + } + + pub fn encap_key_btt(&self) -> &[u8] { + &self.data[Self::ENCAP_KEY_BTT_OFFSET..Self::CIPHERTEXT_BTT_OFFSET] + } + + pub fn btt_ciphertext(&self) -> &[u8] { + &self.data[Self::CIPHERTEXT_BTT_OFFSET..Self::KEY_IDENTIFIER_OFFSET] + } + + pub fn key_id(&self) -> KeyIdentifier { + self.data[Self::KEY_IDENTIFIER_OFFSET] + } + + /// ## Errors + /// If the report contents are invalid. 
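+    ///
+    /// Only the overall length is validated here; the ciphertexts are not
+    /// authenticated until [`Self::decrypt`] is called.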
+ pub fn from_bytes(bytes: B) -> Result { + if bytes.len() < Self::SITE_DOMAIN_OFFSET { + return Err(InvalidHybridReportError::Length( + bytes.len(), + Self::SITE_DOMAIN_OFFSET, + )); + } + Ok(Self { + data: bytes, + phantom_data: PhantomData, + }) + } + + /// ## Errors + /// If the match key shares in the report cannot be decrypted (e.g. due to a + /// failure of the authenticated encryption). + /// ## Panics + /// Should not panic. Only panics if a `Report` constructor failed to validate the + /// contents properly, which would be a bug. + pub fn decrypt( + &self, + key_registry: &P, + ) -> Result, InvalidHybridReportError> { + type CTMKLength = Sum< as Serializable>::Size, TagSize>; + type CTBTTLength = < as Serializable>::Size as Add>::Output; + + let info = HybridImpressionInfo::new(self.key_id(), HELPER_ORIGIN).unwrap(); // validated on construction + + let mut ct_mk: GenericArray = + *GenericArray::from_slice(self.mk_ciphertext()); + let plaintext_mk = hybrid_open_in_place( + key_registry, + self.encap_key_mk(), + &mut ct_mk, + self.key_id(), + &info.to_bytes(), + )?; + let mut ct_btt: GenericArray> = + GenericArray::from_slice(self.btt_ciphertext()).clone(); + + let plaintext_btt = hybrid_open_in_place( + key_registry, + self.encap_key_btt(), + &mut ct_btt, + self.key_id(), + &info.to_bytes(), + )?; + + Ok(HybridImpressionReport:: { + match_key: Replicated::::deserialize(GenericArray::from_slice(plaintext_mk)) + .map_err(|e| InvalidHybridReportError::DeserializationError("matchkey", e.into()))?, + breakdown_key: Replicated::::deserialize(GenericArray::from_slice( + &plaintext_btt[..], + )) + .map_err(|e| InvalidHybridReportError::DeserializationError("is_trigger", e.into()))?, + }) + } +} + #[derive(Clone)] pub struct EncryptedHybridReport { bytes: Bytes, @@ -210,14 +460,16 @@ impl UniqueBytesValidator { mod test { use rand::{distributions::Alphanumeric, rngs::ThreadRng, thread_rng, Rng}; + use typenum::Unsigned; use super::{ - EncryptedHybridReport, HybridConversionReport, HybridImpressionReport, HybridReport, + EncryptedHybridReport, EncryptedHybridImpressionReport, GenericArray, HybridConversionReport, HybridImpressionReport, HybridReport, UniqueBytes, UniqueBytesValidator, }; use crate::{ error::Error, ff::boolean_array::{BA20, BA3, BA8}, + ff::Serializable, hpke::{KeyPair, KeyRegistry}, report::{EventType, OprfReport}, secret_sharing::replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing}, @@ -341,4 +593,52 @@ mod test { let expected_err = unique_bytes.check_duplicates(&[bytes4, bytes3]); assert!(matches!(expected_err, Err(Error::DuplicateBytes(6)))); } + + #[test] + fn serialization_hybrid_impression() { + let mut rng = thread_rng(); + let b = EventType::Source; + let oprf_report = build_oprf_report(b, &mut rng); + + let hybrid_impression_report = HybridImpressionReport:: { + match_key: oprf_report.match_key.clone(), + breakdown_key: oprf_report.breakdown_key.clone(), + }; + let mut hybrid_impression_report_bytes = + vec![0u8; as Serializable>::Size::USIZE]; + hybrid_impression_report.serialize(GenericArray::from_mut_slice( + &mut hybrid_impression_report_bytes[..], + )); + let hybrid_impression_report2 = HybridImpressionReport::::deserialize( + GenericArray::from_mut_slice(&mut hybrid_impression_report_bytes[..]), + ) + .unwrap(); + assert_eq!(hybrid_impression_report, hybrid_impression_report2); + } + + #[test] + fn enc_dec_roundtrip_hybrid_impression() { + let mut rng = thread_rng(); + let b = EventType::Source; + let oprf_report = build_oprf_report(b, &mut 
rng);
+
+        let hybrid_impression_report = HybridImpressionReport::<BA8> {
+            match_key: oprf_report.match_key.clone(),
+            breakdown_key: oprf_report.breakdown_key.clone(),
+        };
+
+        let key_registry = KeyRegistry::<KeyPair>::random(1, &mut rng);
+        let key_id = 0;
+
+        let enc_report_bytes = hybrid_impression_report
+            .encrypt(key_id, &key_registry, &mut rng)
+            .unwrap();
+
+        let enc_report =
+            EncryptedHybridImpressionReport::<BA8, &[u8]>::from_bytes(enc_report_bytes.as_slice())
+                .unwrap();
+        let dec_report: HybridImpressionReport<BA8> = enc_report.decrypt(&key_registry).unwrap();
+
+        assert_eq!(dec_report, hybrid_impression_report);
+    }
 }
diff --git a/ipa-core/src/report/hybrid_info.rs b/ipa-core/src/report/hybrid_info.rs
new file mode 100644
index 000000000..f0faf2b04
--- /dev/null
+++ b/ipa-core/src/report/hybrid_info.rs
@@ -0,0 +1,67 @@
+use crate::{
+    report::KeyIdentifier,
+    report::hybrid::NonAsciiStringError,
+};
+
+const DOMAIN: &str = "private-attribution";
+
+pub struct HybridImpressionInfo<'a> {
+    pub key_id: KeyIdentifier,
+    pub helper_origin: &'a str,
+}
+
+#[allow(dead_code)]
+pub struct HybridConversionInfo<'a> {
+    pub key_id: KeyIdentifier,
+    pub helper_origin: &'a str,
+    pub converion_site_domain: &'a str,
+    pub timestamp: u64,
+    pub epsilon: f64,
+    pub sensitivity: f64,
+}
+
+#[allow(dead_code)]
+pub enum HybridInfo<'a> {
+    Impression(HybridImpressionInfo<'a>),
+    Conversion(HybridConversionInfo<'a>),
+}
+
+impl<'a> HybridImpressionInfo<'a> {
+    /// Creates a new instance.
+    ///
+    /// ## Errors
+    /// if the helper origin is not a valid ASCII string.
+    pub fn new(key_id: KeyIdentifier, helper_origin: &'a str) -> Result<Self, NonAsciiStringError> {
+        // If the types of errors returned from this function change, then the validation in
+        // `EncryptedReport::from_bytes` may need to change as well.
+        if !helper_origin.is_ascii() {
+            return Err(helper_origin.into());
+        }
+
+        Ok(Self {
+            key_id,
+            helper_origin,
+        })
+    }
+
+    // Converts this instance into an owned byte slice that can further be used to create HPKE
+    // sender or receiver context.
+    pub(super) fn to_bytes(&self) -> Box<[u8]> {
+        let info_len = DOMAIN.len()
+            + self.helper_origin.len()
+            + 2 // the two 0-byte delimiters pushed below
+ + std::mem::size_of_val(&self.key_id); + let mut r = Vec::with_capacity(info_len); + + r.extend_from_slice(DOMAIN.as_bytes()); + r.push(0); + r.extend_from_slice(self.helper_origin.as_bytes()); + r.push(0); + + r.push(self.key_id); + + debug_assert_eq!(r.len(), info_len, "HPKE Info length estimation is incorrect and leads to extra allocation or wasted memory"); + + r.into_boxed_slice() + } +} diff --git a/ipa-core/src/report/mod.rs b/ipa-core/src/report/mod.rs index 28fd9e683..192c87aca 100644 --- a/ipa-core/src/report/mod.rs +++ b/ipa-core/src/report/mod.rs @@ -1,3 +1,4 @@ pub mod ipa; pub use self::ipa::*; pub mod hybrid; +pub mod hybrid_info; From 17510bccd6578698ceb55d50dacc36ad25885011 Mon Sep 17 00:00:00 2001 From: Thomas James Yurek Date: Wed, 16 Oct 2024 10:21:26 -0700 Subject: [PATCH 134/191] rustfmt --- ipa-core/src/report/hybrid.rs | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/ipa-core/src/report/hybrid.rs b/ipa-core/src/report/hybrid.rs index ba7d799f4..4ce50207d 100644 --- a/ipa-core/src/report/hybrid.rs +++ b/ipa-core/src/report/hybrid.rs @@ -1,19 +1,22 @@ -use std::{collections::HashSet, ops::Add, ops::Deref}; -use std::marker::PhantomData; use std::fmt::{Display, Formatter}; +use std::marker::PhantomData; +use std::{collections::HashSet, ops::Add, ops::Deref}; -use bytes::{Bytes, BufMut}; +use bytes::{BufMut, Bytes}; use generic_array::{ArrayLength, GenericArray}; use hpke::Serializable as _; use rand_core::{CryptoRng, RngCore}; use typenum::{Sum, Unsigned, U16}; use crate::{ - error::{Error, BoxError}, + error::{BoxError, Error}, ff::{boolean_array::BA64, Serializable}, - hpke::{hybrid_open_in_place, hybrid_seal_in_place, EncapsulationSize, PrivateKeyRegistry, PublicKeyRegistry, TagSize, CryptError}, - report::{EncryptedOprfReport, EventType, InvalidReportError, KeyIdentifier}, + hpke::{ + hybrid_open_in_place, hybrid_seal_in_place, CryptError, EncapsulationSize, + PrivateKeyRegistry, PublicKeyRegistry, TagSize, + }, report::hybrid_info::HybridImpressionInfo, + report::{EncryptedOprfReport, EventType, InvalidReportError, KeyIdentifier}, secret_sharing::{replicated::semi_honest::AdditiveShare as Replicated, SharedValue}, }; @@ -321,7 +324,9 @@ where Ok(HybridImpressionReport:: { match_key: Replicated::::deserialize(GenericArray::from_slice(plaintext_mk)) - .map_err(|e| InvalidHybridReportError::DeserializationError("matchkey", e.into()))?, + .map_err(|e| { + InvalidHybridReportError::DeserializationError("matchkey", e.into()) + })?, breakdown_key: Replicated::::deserialize(GenericArray::from_slice( &plaintext_btt[..], )) @@ -463,8 +468,9 @@ mod test { use typenum::Unsigned; use super::{ - EncryptedHybridReport, EncryptedHybridImpressionReport, GenericArray, HybridConversionReport, HybridImpressionReport, HybridReport, - UniqueBytes, UniqueBytesValidator, + EncryptedHybridImpressionReport, EncryptedHybridReport, GenericArray, + HybridConversionReport, HybridImpressionReport, HybridReport, UniqueBytes, + UniqueBytesValidator, }; use crate::{ error::Error, From 22a7eeb9e09bd241cc3095fead803aa87c100785 Mon Sep 17 00:00:00 2001 From: Thomas James Yurek Date: Wed, 16 Oct 2024 10:43:08 -0700 Subject: [PATCH 135/191] ./pre-commit --- ipa-core/src/hpke/mod.rs | 8 ++++++-- ipa-core/src/report/hybrid.rs | 31 ++++++++++++++++++------------ ipa-core/src/report/hybrid_info.rs | 5 +---- 3 files changed, 26 insertions(+), 18 deletions(-) diff --git a/ipa-core/src/hpke/mod.rs b/ipa-core/src/hpke/mod.rs index a1810aaa4..66991515a 100644 --- 
a/ipa-core/src/hpke/mod.rs +++ b/ipa-core/src/hpke/mod.rs @@ -127,6 +127,8 @@ pub fn open_in_place<'a, R: PrivateKeyRegistry>( } /// Version of `open_in_place` that doesn't require Info struct. +/// ## Errors +/// If ciphertext cannot be opened for any reason. pub fn hybrid_open_in_place<'a, R: PrivateKeyRegistry>( key_registry: &R, enc: &[u8], @@ -145,7 +147,7 @@ pub fn hybrid_open_in_place<'a, R: PrivateKeyRegistry>( &OpModeR::Base, sk, &encap_key, - &info, + info, ct, &[], &tag, @@ -192,6 +194,8 @@ pub(crate) fn seal_in_place<'a, R: CryptoRng + RngCore, K: PublicKeyRegistry>( } /// Version of `seal_in_place` that doesn't require Info struct. +/// ## Errors +/// If the match key cannot be sealed for any reason. pub(crate) fn hybrid_seal_in_place<'a, R: CryptoRng + RngCore, K: PublicKeyRegistry>( key_registry: &K, plaintext: &'a mut [u8], @@ -206,7 +210,7 @@ pub(crate) fn hybrid_seal_in_place<'a, R: CryptoRng + RngCore, K: PublicKeyRegis let (encap_key, tag) = single_shot_seal_in_place_detached::( &OpModeS::Base, pk_r, - &info, + info, plaintext, &[], rng, diff --git a/ipa-core/src/report/hybrid.rs b/ipa-core/src/report/hybrid.rs index 4ce50207d..6235b17cf 100644 --- a/ipa-core/src/report/hybrid.rs +++ b/ipa-core/src/report/hybrid.rs @@ -1,6 +1,9 @@ -use std::fmt::{Display, Formatter}; -use std::marker::PhantomData; -use std::{collections::HashSet, ops::Add, ops::Deref}; +use std::{ + collections::HashSet, + fmt::{Display, Formatter}, + marker::PhantomData, + ops::{Add, Deref}, +}; use bytes::{BufMut, Bytes}; use generic_array::{ArrayLength, GenericArray}; @@ -15,8 +18,10 @@ use crate::{ hybrid_open_in_place, hybrid_seal_in_place, CryptError, EncapsulationSize, PrivateKeyRegistry, PublicKeyRegistry, TagSize, }, - report::hybrid_info::HybridImpressionInfo, - report::{EncryptedOprfReport, EventType, InvalidReportError, KeyIdentifier}, + report::{ + hybrid_info::HybridImpressionInfo, EncryptedOprfReport, EventType, InvalidReportError, + KeyIdentifier, + }, secret_sharing::{replicated::semi_honest::AdditiveShare as Replicated, SharedValue}, }; @@ -327,10 +332,10 @@ where .map_err(|e| { InvalidHybridReportError::DeserializationError("matchkey", e.into()) })?, - breakdown_key: Replicated::::deserialize(GenericArray::from_slice( - &plaintext_btt[..], - )) - .map_err(|e| InvalidHybridReportError::DeserializationError("is_trigger", e.into()))?, + breakdown_key: Replicated::::deserialize(GenericArray::from_slice(plaintext_btt)) + .map_err(|e| { + InvalidHybridReportError::DeserializationError("is_trigger", e.into()) + })?, }) } } @@ -474,8 +479,10 @@ mod test { }; use crate::{ error::Error, - ff::boolean_array::{BA20, BA3, BA8}, - ff::Serializable, + ff::{ + boolean_array::{BA20, BA3, BA8}, + Serializable, + }, hpke::{KeyPair, KeyRegistry}, report::{EventType, OprfReport}, secret_sharing::replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing}, @@ -611,7 +618,7 @@ mod test { breakdown_key: oprf_report.breakdown_key.clone(), }; let mut hybrid_impression_report_bytes = - vec![0u8; as Serializable>::Size::USIZE]; + [0u8; as Serializable>::Size::USIZE]; hybrid_impression_report.serialize(GenericArray::from_mut_slice( &mut hybrid_impression_report_bytes[..], )); diff --git a/ipa-core/src/report/hybrid_info.rs b/ipa-core/src/report/hybrid_info.rs index f0faf2b04..f6bb657c1 100644 --- a/ipa-core/src/report/hybrid_info.rs +++ b/ipa-core/src/report/hybrid_info.rs @@ -1,7 +1,4 @@ -use crate::{ - report::KeyIdentifier, - report::hybrid::NonAsciiStringError, -}; +use 
crate::report::{hybrid::NonAsciiStringError, KeyIdentifier}; const DOMAIN: &str = "private-attribution"; From da336afcd4c4a1b0636fc89192b774c67a7dd4c0 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Wed, 16 Oct 2024 13:05:34 -0700 Subject: [PATCH 136/191] Fixes for aggregation batching * Don't start all batches from record 0. * Fix a bug in insert_segment_large when segments are added out-of-order. * Remove unused `last_record` and tweak `is_empty`. * Add a few tests. --- .../src/protocol/context/dzkp_validator.rs | 310 ++++++++++++------ 1 file changed, 218 insertions(+), 92 deletions(-) diff --git a/ipa-core/src/protocol/context/dzkp_validator.rs b/ipa-core/src/protocol/context/dzkp_validator.rs index 70b963fd3..d21f1d1c3 100644 --- a/ipa-core/src/protocol/context/dzkp_validator.rs +++ b/ipa-core/src/protocol/context/dzkp_validator.rs @@ -1,4 +1,4 @@ -use std::{cmp, collections::BTreeMap, fmt::Debug, future::ready}; +use std::{collections::BTreeMap, fmt::Debug, future::ready}; use async_trait::async_trait; use bitvec::prelude::{BitArray, BitSlice, Lsb0}; @@ -109,6 +109,31 @@ impl MultiplicationInputsBlock { }) } + /// set using bitslices + /// ## Errors + /// Errors when length of slices is not 256 bit + #[allow(clippy::too_many_arguments)] + fn set( + &mut self, + x_left: &BitSliceType, + x_right: &BitSliceType, + y_left: &BitSliceType, + y_right: &BitSliceType, + prss_left: &BitSliceType, + prss_right: &BitSliceType, + z_right: &BitSliceType, + ) -> Result<(), BoxError> { + self.x_left = BitArray::try_from(x_left)?; + self.x_right = BitArray::try_from(x_right)?; + self.y_left = BitArray::try_from(y_left)?; + self.y_right = BitArray::try_from(y_right)?; + self.prss_left = BitArray::try_from(prss_left)?; + self.prss_right = BitArray::try_from(prss_right)?; + self.z_right = BitArray::try_from(z_right)?; + + Ok(()) + } + /// `Convert` allows to convert `MultiplicationInputs` into a format compatible with DZKPs /// This is the convert function called by the prover. fn convert_prover(&self) -> Vec> { @@ -237,29 +262,29 @@ impl<'a> SegmentEntry<'a> { /// `MultiplicationInputsBatch` stores a batch of multiplication inputs in a vector of `MultiplicationInputsBlock`. /// `first_record` is the first `RecordId` for the current batch. -/// `last_record` keeps track of the highest record that has been added to the batch. /// `max_multiplications` is the maximum amount of multiplications performed within a this batch. /// It is used to determine the vector length during the allocation. /// If there are more multiplications, it will cause a panic! /// `multiplication_bit_size` is the bit size of a single multiplication. The size will be consistent /// across all multiplications of a gate. -/// `is_empty` keeps track of whether any value has been added #[derive(Clone, Debug)] struct MultiplicationInputsBatch { - first_record: RecordId, - last_record: RecordId, + first_record: Option, max_multiplications: usize, multiplication_bit_size: usize, - is_empty: bool, vec: Vec, } impl MultiplicationInputsBatch { - /// Creates a new store for multiplication intermediates for records starting from - /// `first_record`. The size of the allocated vector is + /// Creates a new store for multiplication intermediates. The first record is + /// specified by `first_record`, or if that is `None`, is set automatically the + /// first time a segment is added to the batch. Once the first record is set, + /// attempting to add a segment before the first record will panic. 
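+    /// Deferring the first record to the first insertion supports callers, such
+    /// as the aggregation protocol, that do not know in advance at which record
+    /// a batch starts.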
+ /// + /// The size of the allocated vector is /// `ceil((max_multiplications * multiplication_bit_size) / BIT_ARRAY_LEN)`. fn new( - first_record: RecordId, + first_record: Option, max_multiplications: usize, multiplication_bit_size: usize, ) -> Self { @@ -277,10 +302,8 @@ impl MultiplicationInputsBatch { ); Self { first_record, - last_record: first_record, max_multiplications, multiplication_bit_size, - is_empty: false, vec: Vec::with_capacity((capacity_bits + BIT_ARRAY_MASK) >> BIT_ARRAY_SHIFT), } } @@ -293,7 +316,7 @@ impl MultiplicationInputsBatch { /// returns whether the store is empty fn is_empty(&self) -> bool { - self.is_empty + self.vec.is_empty() } /// `insert_segment` allows to include a new segment in `MultiplicationInputsBatch`. @@ -308,21 +331,27 @@ impl MultiplicationInputsBatch { // check segment size debug_assert_eq!(segment.len(), self.multiplication_bit_size); + let first_record = *self.first_record.get_or_insert(record_id); + // panics when record_id is out of bounds - assert!(record_id >= self.first_record); assert!( - usize::from(record_id) < self.max_multiplications + usize::from(self.first_record), + record_id >= first_record, + "record_id out of range in insert_segment. record {record_id} is before \ + first record {first_record}", + ); + assert!( + usize::from(record_id) + < self + .max_multiplications + .saturating_add(usize::from(first_record)), "record_id out of range in insert_segment. record {record_id} is beyond \ segment of length {} starting at {}", self.max_multiplications, - self.first_record, + first_record, ); - // update last record - self.last_record = cmp::max(self.last_record, record_id); - // panics when record_id is too large to fit in, i.e. when it is out of bounds - if segment.len() <= 256 { + if segment.len() < 256 { self.insert_segment_small(record_id, segment); } else { self.insert_segment_large(record_id, &segment); @@ -337,15 +366,8 @@ impl MultiplicationInputsBatch { /// than the first record of the batch, i.e. `first_record` /// or too large, i.e. `first_record+max_multiplications` fn insert_segment_small(&mut self, record_id: RecordId, segment: Segment) { - // check length - debug_assert!(segment.len() <= 256); - - // panics when record_id is out of bounds - assert!(record_id >= self.first_record); - assert!(usize::from(record_id) < self.max_multiplications + usize::from(self.first_record)); - // panics when record_id is less than first_record - let id_within_batch = usize::from(record_id) - usize::from(self.first_record); + let id_within_batch = usize::from(record_id) - usize::from(self.first_record.unwrap()); // round up segment length to a power of two since we want to have divisors of 256 let length = segment.len().next_power_of_two(); @@ -386,14 +408,7 @@ impl MultiplicationInputsBatch { /// than the first record of the batch, i.e. `first_record` /// or too large, i.e. 
`first_record+max_multiplications` fn insert_segment_large(&mut self, record_id: RecordId, segment: &Segment) { - // check length - debug_assert_eq!(segment.len() % 256, 0); - - // panics when record_id is out of bounds - assert!(record_id >= self.first_record); - assert!(usize::from(record_id) < self.max_multiplications + usize::from(self.first_record)); - - let id_within_batch = usize::from(record_id) - usize::from(self.first_record); + let id_within_batch = usize::from(record_id) - usize::from(self.first_record.unwrap()); let block_id = (segment.len() * id_within_batch) >> BIT_ARRAY_SHIFT; let length_in_blocks = segment.len() >> BIT_ARRAY_SHIFT; if self.vec.len() < block_id { @@ -402,8 +417,9 @@ impl MultiplicationInputsBatch { } for i in 0..length_in_blocks { - self.vec.push( - MultiplicationInputsBlock::clone_from( + if self.vec.len() > block_id + i { + MultiplicationInputsBlock::set( + &mut self.vec[block_id + i], &segment.x_left.0[256 * i..256 * (i + 1)], &segment.x_right.0[256 * i..256 * (i + 1)], &segment.y_left.0[256 * i..256 * (i + 1)], @@ -412,8 +428,21 @@ impl MultiplicationInputsBatch { &segment.prss_right.0[256 * i..256 * (i + 1)], &segment.z_right.0[256 * i..256 * (i + 1)], ) - .unwrap(), - ); + .unwrap(); + } else { + self.vec.push( + MultiplicationInputsBlock::clone_from( + &segment.x_left.0[256 * i..256 * (i + 1)], + &segment.x_right.0[256 * i..256 * (i + 1)], + &segment.y_left.0[256 * i..256 * (i + 1)], + &segment.y_right.0[256 * i..256 * (i + 1)], + &segment.prss_left.0[256 * i..256 * (i + 1)], + &segment.prss_right.0[256 * i..256 * (i + 1)], + &segment.z_right.0[256 * i..256 * (i + 1)], + ) + .unwrap(), + ); + } } } @@ -457,12 +486,15 @@ impl MultiplicationInputsBatch { #[derive(Debug)] pub(super) struct Batch { max_multiplications_per_gate: usize, - first_record: RecordId, + first_record: Option, inner: BTreeMap, } impl Batch { - fn new(first_record: RecordId, max_multiplications_per_gate: usize) -> Self { + /// Creates a new `Batch` for multiplication intermediates from multiple gates. The + /// first record is specified by `first_record`, or if that is `None`, is set + /// automatically for each gate the first time a segment from that gate is added. 
+ fn new(first_record: Option, max_multiplications_per_gate: usize) -> Self { Self { max_multiplications_per_gate, first_record, @@ -802,10 +834,9 @@ impl<'a, B: ShardBinding> MaliciousDZKPValidator<'a, B> { max_multiplications_per_gate, ctx.total_records(), Box::new(move |batch_index| { - Batch::new( - RecordId::from(batch_index * max_multiplications_per_gate), - max_multiplications_per_gate, - ) + let first_record = (max_multiplications_per_gate != usize::MAX) + .then(|| RecordId::from(batch_index * max_multiplications_per_gate)); + Batch::new(first_record, max_multiplications_per_gate) }), ); let inner = Arc::new(MaliciousDZKPValidatorInner { @@ -1362,22 +1393,27 @@ mod tests { .await; } + fn segment_from_entry(entry: SegmentEntry) -> Segment { + Segment::from_entries( + entry.clone(), + entry.clone(), + entry.clone(), + entry.clone(), + entry.clone(), + entry.clone(), + entry, + ) + } + #[test] fn batch_allocation_small() { const SIZE: usize = 1; - let mut batch = Batch::new(RecordId::FIRST, SIZE); + let mut batch = Batch::new(None, SIZE); let zero = Boolean::ZERO; let zero_vec: >::Array = zero.into_array(); - let segment_entry = >::as_segment_entry(&zero_vec); - let segment = Segment::from_entries( - segment_entry.clone(), - segment_entry.clone(), - segment_entry.clone(), - segment_entry.clone(), - segment_entry.clone(), - segment_entry.clone(), - segment_entry, - ); + let segment = segment_from_entry(>::as_segment_entry( + &zero_vec, + )); batch.push(Gate::default(), RecordId::FIRST, segment); assert_eq!(batch.inner.get(&Gate::default()).unwrap().vec.len(), 1); assert!(batch.inner.get(&Gate::default()).unwrap().vec.capacity() >= SIZE); @@ -1387,19 +1423,12 @@ mod tests { #[test] fn batch_allocation_big() { const SIZE: usize = 2 * TARGET_PROOF_SIZE; - let mut batch = Batch::new(RecordId::FIRST, SIZE); + let mut batch = Batch::new(None, SIZE); let zero = Boolean::ZERO; let zero_vec: >::Array = zero.into_array(); - let segment_entry = >::as_segment_entry(&zero_vec); - let segment = Segment::from_entries( - segment_entry.clone(), - segment_entry.clone(), - segment_entry.clone(), - segment_entry.clone(), - segment_entry.clone(), - segment_entry.clone(), - segment_entry, - ); + let segment = segment_from_entry(>::as_segment_entry( + &zero_vec, + )); batch.push(Gate::default(), RecordId::FIRST, segment); assert_eq!(batch.inner.get(&Gate::default()).unwrap().vec.len(), 1); assert!( @@ -1415,19 +1444,12 @@ mod tests { #[test] fn batch_fill() { const SIZE: usize = 10; - let mut batch = Batch::new(RecordId::FIRST, SIZE); + let mut batch = Batch::new(None, SIZE); let zero = Boolean::ZERO; let zero_vec: >::Array = zero.into_array(); - let segment_entry = >::as_segment_entry(&zero_vec); - let segment = Segment::from_entries( - segment_entry.clone(), - segment_entry.clone(), - segment_entry.clone(), - segment_entry.clone(), - segment_entry.clone(), - segment_entry.clone(), - segment_entry, - ); + let segment = segment_from_entry(>::as_segment_entry( + &zero_vec, + )); for i in 0..SIZE { batch.push(Gate::default(), RecordId::from(i), segment.clone()); } @@ -1436,25 +1458,129 @@ mod tests { assert!(batch.inner.get(&Gate::default()).unwrap().vec.capacity() <= 2); } + #[test] + fn batch_fill_out_of_order() { + let mut batch = Batch::new(None, 3); + let ba0 = BA256::from((0, 0)); + let ba1 = BA256::from((0, 1)); + let ba2 = BA256::from((0, 2)); + let segment = segment_from_entry(>::as_segment_entry( + &ba0, + )); + batch.push(Gate::default(), RecordId::from(0), segment.clone()); + let segment = 
segment_from_entry(>::as_segment_entry( + &ba2, + )); + batch.push(Gate::default(), RecordId::from(2), segment.clone()); + let segment = segment_from_entry(>::as_segment_entry( + &ba1, + )); + batch.push(Gate::default(), RecordId::from(1), segment.clone()); + assert_eq!(batch.inner.get(&Gate::default()).unwrap().vec.len(), 3); + assert_eq!( + batch.inner.get(&Gate::default()).unwrap().vec[0].x_left, + ba0.as_bitslice() + ); + assert_eq!( + batch.inner.get(&Gate::default()).unwrap().vec[1].x_left, + ba1.as_bitslice() + ); + assert_eq!( + batch.inner.get(&Gate::default()).unwrap().vec[2].x_left, + ba2.as_bitslice() + ); + } + + #[test] + fn batch_fill_at_offset() { + let mut batch = Batch::new(None, 3); + let ba0 = BA256::from((0, 0)); + let ba1 = BA256::from((0, 1)); + let ba2 = BA256::from((0, 2)); + let segment = segment_from_entry(>::as_segment_entry( + &ba0, + )); + batch.push(Gate::default(), RecordId::from(4), segment.clone()); + let segment = segment_from_entry(>::as_segment_entry( + &ba1, + )); + batch.push(Gate::default(), RecordId::from(5), segment.clone()); + let segment = segment_from_entry(>::as_segment_entry( + &ba2, + )); + batch.push(Gate::default(), RecordId::from(6), segment.clone()); + assert_eq!(batch.inner.get(&Gate::default()).unwrap().vec.len(), 3); + assert_eq!( + batch.inner.get(&Gate::default()).unwrap().vec[0].x_left, + ba0.as_bitslice() + ); + assert_eq!( + batch.inner.get(&Gate::default()).unwrap().vec[1].x_left, + ba1.as_bitslice() + ); + assert_eq!( + batch.inner.get(&Gate::default()).unwrap().vec[2].x_left, + ba2.as_bitslice() + ); + } + + #[test] + fn batch_explicit_first_record() { + let mut batch = Batch::new(Some(RecordId::from(4)), 3); + let ba6 = BA256::from((0, 6)); + let segment = segment_from_entry(>::as_segment_entry( + &ba6, + )); + batch.push(Gate::default(), RecordId::from(6), segment.clone()); + assert_eq!(batch.inner.get(&Gate::default()).unwrap().vec.len(), 3); + assert_eq!( + batch.inner.get(&Gate::default()).unwrap().vec[2].x_left, + ba6.as_bitslice() + ); + } + + #[test] + fn batch_is_empty() { + const SIZE: usize = 10; + let mut batch = Batch::new(None, SIZE); + assert!(batch.is_empty()); + let zero = Boolean::ZERO; + let zero_vec: >::Array = zero.into_array(); + let segment = segment_from_entry(>::as_segment_entry( + &zero_vec, + )); + batch.push(Gate::default(), RecordId::FIRST, segment); + assert!(!batch.is_empty()); + } + + #[test] + #[should_panic( + expected = "record_id out of range in insert_segment. record 0 is before first record 10" + )] + fn batch_underflow() { + const SIZE: usize = 10; + let mut batch = Batch::new(None, SIZE); + let zero = Boolean::ZERO; + let zero_vec: >::Array = zero.into_array(); + let segment = segment_from_entry(>::as_segment_entry( + &zero_vec, + )); + batch.push(Gate::default(), RecordId::from(10), segment.clone()); + batch.push(Gate::default(), RecordId::from(0), segment.clone()); + } + #[test] #[should_panic( expected = "record_id out of range in insert_segment. 
record 10 is beyond segment of length 10 starting at 0" )] fn batch_overflow() { const SIZE: usize = 10; - let mut batch = Batch::new(RecordId::FIRST, SIZE); + let mut batch = Batch::new(None, SIZE); let zero = Boolean::ZERO; let zero_vec: >::Array = zero.into_array(); - let segment_entry = >::as_segment_entry(&zero_vec); - let segment = Segment::from_entries( - segment_entry.clone(), - segment_entry.clone(), - segment_entry.clone(), - segment_entry.clone(), - segment_entry.clone(), - segment_entry.clone(), - segment_entry, - ); + let segment = segment_from_entry(>::as_segment_entry( + &zero_vec, + )); for i in 0..=SIZE { batch.push(Gate::default(), RecordId::from(i), segment.clone()); } @@ -1585,13 +1711,13 @@ mod tests { // test for small and large segments, i.e. 8bit and 512 bit for segment_size in [8usize, 512usize] { // generate batch for the prover - let mut batch_prover = Batch::new(RecordId::FIRST, 1024 / segment_size); + let mut batch_prover = Batch::new(None, 1024 / segment_size); // generate batch for the verifier on the left of the prover - let mut batch_left = Batch::new(RecordId::FIRST, 1024 / segment_size); + let mut batch_left = Batch::new(None, 1024 / segment_size); // generate batch for the verifier on the right of the prover - let mut batch_right = Batch::new(RecordId::FIRST, 1024 / segment_size); + let mut batch_right = Batch::new(None, 1024 / segment_size); // fill the batches with random values populate_batch( From 8c672f8fa8344c98af4897206d63d6fe1d7864d0 Mon Sep 17 00:00:00 2001 From: Thomas James Yurek Date: Wed, 16 Oct 2024 14:32:40 -0700 Subject: [PATCH 137/191] impl UniqueBytes, test non-ascii strings --- ipa-core/src/report/hybrid.rs | 28 ++++++++++++++++++++++++++-- ipa-core/src/report/hybrid_info.rs | 1 + 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/ipa-core/src/report/hybrid.rs b/ipa-core/src/report/hybrid.rs index 6235b17cf..b9c7c5014 100644 --- a/ipa-core/src/report/hybrid.rs +++ b/ipa-core/src/report/hybrid.rs @@ -91,7 +91,7 @@ where < as Serializable>::Size as Add< as Serializable>::Size>>:: Output: ArrayLength, { type Size = < as Serializable>::Size as Add< as Serializable>::Size>>:: Output; - type DeserializationError = InvalidHybridReportError; // as Serializable>::DeserializationError; + type DeserializationError = InvalidHybridReportError; fn serialize(&self, buf: &mut GenericArray) { let mk_sz = as Serializable>::Size::USIZE; @@ -422,6 +422,21 @@ impl UniqueBytes for EncryptedHybridReport { } } +impl UniqueBytes for EncryptedHybridImpressionReport +where + B: Deref, + BK: SharedValue, + Replicated: Serializable, + as Serializable>::Size: Add, + < as Serializable>::Size as Add>::Output: ArrayLength, +{ + /// We use the `TagSize` (the first 16 bytes of the ciphertext) for collision-detection + /// See [analysis here for uniqueness](https://eprint.iacr.org/2019/624) + fn unique_bytes(&self) -> Vec { + self.mk_ciphertext()[0..TagSize::USIZE].to_vec() + } +} + #[derive(Debug)] pub struct UniqueBytesValidator { hash_set: HashSet>, @@ -484,7 +499,9 @@ mod test { Serializable, }, hpke::{KeyPair, KeyRegistry}, - report::{EventType, OprfReport}, + report::{ + hybrid::NonAsciiStringError, hybrid_info::HybridImpressionInfo, EventType, OprfReport, + }, secret_sharing::replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing}, }; @@ -654,4 +671,11 @@ mod test { assert_eq!(dec_report, hybrid_impression_report); } + + #[test] + fn non_ascii_string() { + let non_ascii_string = "☃️☃️☃️"; + let err = HybridImpressionInfo::new(0, 
non_ascii_string).unwrap_err(); + assert!(matches!(err, NonAsciiStringError { input: _ })); + } } diff --git a/ipa-core/src/report/hybrid_info.rs b/ipa-core/src/report/hybrid_info.rs index f6bb657c1..c41849121 100644 --- a/ipa-core/src/report/hybrid_info.rs +++ b/ipa-core/src/report/hybrid_info.rs @@ -2,6 +2,7 @@ use crate::report::{hybrid::NonAsciiStringError, KeyIdentifier}; const DOMAIN: &str = "private-attribution"; +#[derive(Debug)] pub struct HybridImpressionInfo<'a> { pub key_id: KeyIdentifier, pub helper_origin: &'a str, From fee714bed743fcb14b3c06daaf1745eae193eb62 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Wed, 16 Oct 2024 15:58:13 -0700 Subject: [PATCH 138/191] Additional comments about aggregation proof batching --- .../src/protocol/context/dzkp_validator.rs | 9 +++++++++ .../ipa_prf/aggregation/breakdown_reveal.rs | 18 ++++++++++++++---- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/ipa-core/src/protocol/context/dzkp_validator.rs b/ipa-core/src/protocol/context/dzkp_validator.rs index d21f1d1c3..3e83d797f 100644 --- a/ipa-core/src/protocol/context/dzkp_validator.rs +++ b/ipa-core/src/protocol/context/dzkp_validator.rs @@ -494,6 +494,15 @@ impl Batch { /// Creates a new `Batch` for multiplication intermediates from multiple gates. The /// first record is specified by `first_record`, or if that is `None`, is set /// automatically for each gate the first time a segment from that gate is added. + /// + /// Once the first record is set, attempting to add a segment before the first + /// record will panic. It is likely, but not guaranteed, that protocol execution + /// proceeds in order, so a problem here can easily escape testing. + /// * When using the `Batcher` in multi-batch mode, `first_record` is calculated + /// from the batch index and the number of records in a batch, so there is no + /// possibility of attempting to add a record before the start of the batch. + /// * The only protocol that manages batches explicitly is the aggregation protocol + /// (`breakdown_reveal_aggregation`). It is structured to operate in order. fn new(first_record: Option, max_multiplications_per_gate: usize) -> Self { Self { max_multiplications_per_gate, diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs b/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs index 65ec97230..7e10adf57 100644 --- a/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs +++ b/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs @@ -50,6 +50,19 @@ use crate::{ /// 2. Reveal breakdown keys. This is the key difference to the previous /// aggregation (see [`reveal_breakdowns`]). /// 3. Add all values for each breakdown. +/// +/// This protocol explicitly manages proof batches for DZKP-based malicious security by +/// processing chunks of values from `intermediate_results.chunks()`. Procession +/// through record IDs is not uniform for all of the gates in the protocol. The first +/// layer of the reduction adds N pairs of records, the second layer adds N/2 pairs of +/// records, etc. This has a few consequences: +/// * We must specify a batch size of `usize::MAX` when calling `dzkp_validator`. +/// * We must track record IDs across chunks, so that subsequent chunks can +/// start from the last record ID that was used in the previous chunk. +/// * Because the first record ID in the proof batch is set implicitly, we must +/// guarantee that it submits multiplication intermediates before any other +/// record. 
This is currently ensured by the serial operation of the aggregation +/// protocol (i.e. by not using `seq_join`). #[tracing::instrument(name = "breakdown_reveal_aggregation", skip_all, fields(total = attributed_values.len()))] pub async fn breakdown_reveal_aggregation( ctx: C, @@ -107,10 +120,7 @@ where protocol: &Step::aggregate(depth), validate: &Step::aggregate_validate(chunk_counter), }, - // We have to specify usize::MAX here because the procession through - // record IDs is different at each step of the reduction. The batch - // size is limited by `intermediate_results.chunks()`, above. - usize::MAX, + usize::MAX, // See note about batching above. ); let result = aggregate_values::<_, HV, B>( validator.context(), From e43ca3d55858d6c4ef322196897e4f26319a76c8 Mon Sep 17 00:00:00 2001 From: Thomas James Yurek Date: Wed, 16 Oct 2024 16:04:57 -0700 Subject: [PATCH 139/191] condense seal_in_place and open_in_place methods --- ipa-core/src/hpke/info.rs | 2 +- ipa-core/src/hpke/mod.rs | 74 +++++------------------------------ ipa-core/src/report/hybrid.rs | 12 +++--- ipa-core/src/report/ipa.rs | 34 +++++++++++++--- 4 files changed, 45 insertions(+), 77 deletions(-) diff --git a/ipa-core/src/hpke/info.rs b/ipa-core/src/hpke/info.rs index 584f46525..e0b7a794b 100644 --- a/ipa-core/src/hpke/info.rs +++ b/ipa-core/src/hpke/info.rs @@ -52,7 +52,7 @@ impl<'a> Info<'a> { /// Converts this instance into an owned byte slice that can further be used to create HPKE /// sender or receiver context. - pub(super) fn to_bytes(&self) -> Box<[u8]> { + pub(crate) fn to_bytes(&self) -> Box<[u8]> { let info_len = DOMAIN.len() + self.helper_origin.len() + self.site_domain.len() diff --git a/ipa-core/src/hpke/mod.rs b/ipa-core/src/hpke/mod.rs index 66991515a..2b7f2bb80 100644 --- a/ipa-core/src/hpke/mod.rs +++ b/ipa-core/src/hpke/mod.rs @@ -97,39 +97,6 @@ impl From for CryptError { /// /// [`HPKE decryption`]: https://datatracker.ietf.org/doc/html/rfc9180#name-encryption-and-decryption pub fn open_in_place<'a, R: PrivateKeyRegistry>( - key_registry: &R, - enc: &[u8], - ciphertext: &'a mut [u8], - info: &Info, -) -> Result<&'a [u8], CryptError> { - let key_id = info.key_id; - let info = info.to_bytes(); - let encap_key = ::EncappedKey::from_bytes(enc)?; - let (ct, tag) = ciphertext.split_at_mut(ciphertext.len() - AeadTag::::size()); - let tag = AeadTag::::from_bytes(tag)?; - let sk = key_registry - .private_key(key_id) - .ok_or(CryptError::NoSuchKey(key_id))?; - - single_shot_open_in_place_detached::<_, IpaKdf, IpaKem>( - &OpModeR::Base, - sk, - &encap_key, - &info, - ct, - &[], - &tag, - )?; - - // at this point ct is no longer a pointer to the ciphertext. - let pt = ct; - Ok(pt) -} - -/// Version of `open_in_place` that doesn't require Info struct. -/// ## Errors -/// If ciphertext cannot be opened for any reason. -pub fn hybrid_open_in_place<'a, R: PrivateKeyRegistry>( key_registry: &R, enc: &[u8], ciphertext: &'a mut [u8], @@ -169,34 +136,6 @@ pub(crate) type Ciphertext<'a> = ( /// ## Errors /// If the match key cannot be sealed for any reason. 
pub(crate) fn seal_in_place<'a, R: CryptoRng + RngCore, K: PublicKeyRegistry>( - key_registry: &K, - plaintext: &'a mut [u8], - info: &'a Info, - rng: &mut R, -) -> Result, CryptError> { - let key_id = info.key_id; - let info = info.to_bytes(); - let pk_r = key_registry - .public_key(key_id) - .ok_or(CryptError::NoSuchKey(key_id))?; - - let (encap_key, tag) = single_shot_seal_in_place_detached::( - &OpModeS::Base, - pk_r, - &info, - plaintext, - &[], - rng, - )?; - - // at this point `plaintext` is no longer a pointer to the plaintext. - Ok((encap_key, plaintext, tag)) -} - -/// Version of `seal_in_place` that doesn't require Info struct. -/// ## Errors -/// If the match key cannot be sealed for any reason. -pub(crate) fn hybrid_seal_in_place<'a, R: CryptoRng + RngCore, K: PublicKeyRegistry>( key_registry: &K, plaintext: &'a mut [u8], key_id: u8, @@ -292,7 +231,8 @@ mod tests { let (encap_key, ciphertext, tag) = seal_in_place( &self.registry, plaintext.as_mut_slice(), - &info, + info.key_id, + &info.to_bytes(), &mut self.rng, ) .unwrap(); @@ -341,7 +281,13 @@ mod tests { Self::SITE_DOMAIN, ) .unwrap(); - open_in_place(&self.registry, &enc.enc, enc.ct.as_mut(), &info)?; + open_in_place( + &self.registry, + &enc.enc, + enc.ct.as_mut(), + info.key_id, + &info.to_bytes(), + )?; // TODO: fix once array split is a thing. Ok(XorReplicated::deserialize_infallible( @@ -526,7 +472,7 @@ mod tests { _ => panic!("bad test setup: only 5 fields can be corrupted, asked to corrupt: {corrupted_info_field}") }; - open_in_place(&suite.registry, &encryption.enc, &mut encryption.ct, &info).unwrap_err(); + open_in_place(&suite.registry, &encryption.enc, &mut encryption.ct, info.key_id, &info.to_bytes()).unwrap_err(); } } } diff --git a/ipa-core/src/report/hybrid.rs b/ipa-core/src/report/hybrid.rs index b9c7c5014..81c5cc9eb 100644 --- a/ipa-core/src/report/hybrid.rs +++ b/ipa-core/src/report/hybrid.rs @@ -15,8 +15,8 @@ use crate::{ error::{BoxError, Error}, ff::{boolean_array::BA64, Serializable}, hpke::{ - hybrid_open_in_place, hybrid_seal_in_place, CryptError, EncapsulationSize, - PrivateKeyRegistry, PublicKeyRegistry, TagSize, + open_in_place, seal_in_place, CryptError, EncapsulationSize, PrivateKeyRegistry, + PublicKeyRegistry, TagSize, }, report::{ hybrid_info::HybridImpressionInfo, EncryptedOprfReport, EventType, InvalidReportError, @@ -163,7 +163,7 @@ where self.breakdown_key .serialize(GenericArray::from_mut_slice(&mut plaintext_btt[..])); - let (encap_key_mk, ciphertext_mk, tag_mk) = hybrid_seal_in_place( + let (encap_key_mk, ciphertext_mk, tag_mk) = seal_in_place( key_registry, plaintext_mk.as_mut(), key_id, @@ -171,7 +171,7 @@ where rng, )?; - let (encap_key_btt, ciphertext_btt, tag_btt) = hybrid_seal_in_place( + let (encap_key_btt, ciphertext_btt, tag_btt) = seal_in_place( key_registry, plaintext_btt.as_mut(), key_id, @@ -309,7 +309,7 @@ where let mut ct_mk: GenericArray = *GenericArray::from_slice(self.mk_ciphertext()); - let plaintext_mk = hybrid_open_in_place( + let plaintext_mk = open_in_place( key_registry, self.encap_key_mk(), &mut ct_mk, @@ -319,7 +319,7 @@ where let mut ct_btt: GenericArray> = GenericArray::from_slice(self.btt_ciphertext()).clone(); - let plaintext_btt = hybrid_open_in_place( + let plaintext_btt = open_in_place( key_registry, self.encap_key_btt(), &mut ct_btt, diff --git a/ipa-core/src/report/ipa.rs b/ipa-core/src/report/ipa.rs index a9da93454..a71358f64 100644 --- a/ipa-core/src/report/ipa.rs +++ b/ipa-core/src/report/ipa.rs @@ -407,11 +407,23 @@ where let mut ct_mk: 
GenericArray = *GenericArray::from_slice(self.mk_ciphertext()); - let plaintext_mk = open_in_place(key_registry, self.encap_key_mk(), &mut ct_mk, &info)?; + let plaintext_mk = open_in_place( + key_registry, + self.encap_key_mk(), + &mut ct_mk, + self.key_id(), + &info.to_bytes(), + )?; let mut ct_btt: GenericArray> = GenericArray::from_slice(self.btt_ciphertext()).clone(); - let plaintext_btt = open_in_place(key_registry, self.encap_key_btt(), &mut ct_btt, &info)?; + let plaintext_btt = open_in_place( + key_registry, + self.encap_key_btt(), + &mut ct_btt, + self.key_id(), + &info.to_bytes(), + )?; Ok(OprfReport:: { timestamp: Replicated::::deserialize(GenericArray::from_slice( @@ -577,11 +589,21 @@ where ..(Self::TV_OFFSET + as Serializable>::Size::USIZE)], )); - let (encap_key_mk, ciphertext_mk, tag_mk) = - seal_in_place(key_registry, plaintext_mk.as_mut(), &info, rng)?; + let (encap_key_mk, ciphertext_mk, tag_mk) = seal_in_place( + key_registry, + plaintext_mk.as_mut(), + key_id, + &info.to_bytes(), + rng, + )?; - let (encap_key_btt, ciphertext_btt, tag_btt) = - seal_in_place(key_registry, plaintext_btt.as_mut(), &info, rng)?; + let (encap_key_btt, ciphertext_btt, tag_btt) = seal_in_place( + key_registry, + plaintext_btt.as_mut(), + key_id, + &info.to_bytes(), + rng, + )?; out.put_slice(&encap_key_mk.to_bytes()); out.put_slice(ciphertext_mk); From 94d8fde2320e3d54bde52ccc554e7882060a52af Mon Sep 17 00:00:00 2001 From: Thomas James Yurek Date: Wed, 16 Oct 2024 17:22:42 -0700 Subject: [PATCH 140/191] remove from [u8] for NonAsciiStringError --- ipa-core/src/report/hybrid.rs | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/ipa-core/src/report/hybrid.rs b/ipa-core/src/report/hybrid.rs index 81c5cc9eb..e12c92ba4 100644 --- a/ipa-core/src/report/hybrid.rs +++ b/ipa-core/src/report/hybrid.rs @@ -41,24 +41,11 @@ impl Display for NonAsciiStringError { impl std::error::Error for NonAsciiStringError {} -impl From<&'_ [u8]> for NonAsciiStringError { - fn from(input: &[u8]) -> Self { - Self { - input: String::from_utf8( - input - .iter() - .copied() - .flat_map(std::ascii::escape_default) - .collect::>(), - ) - .unwrap(), - } - } -} - impl From<&'_ str> for NonAsciiStringError { fn from(input: &str) -> Self { - Self::from(input.as_bytes()) + Self { + input: input.to_string(), + } } } From ad77ff828ce194278bd25fdf6da43ee5429c5ba7 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Wed, 16 Oct 2024 19:09:44 -0700 Subject: [PATCH 141/191] Add ipa-metrics crate Follows the design described here https://docs.google.com/document/d/1cRnE024bi7KJYeqMT8yMYluIplA8992bNIaeKBmrXp4/edit#heading=h.d8g3vju6yaqd --- Cargo.toml | 2 +- ipa-metrics/Cargo.toml | 20 +++ ipa-metrics/src/collector.rs | 162 +++++++++++++++++++ ipa-metrics/src/context.rs | 151 +++++++++++++++++ ipa-metrics/src/controller.rs | 30 ++++ ipa-metrics/src/key.rs | 287 +++++++++++++++++++++++++++++++++ ipa-metrics/src/kind.rs | 6 + ipa-metrics/src/label.rs | 160 ++++++++++++++++++ ipa-metrics/src/lib.rs | 56 +++++++ ipa-metrics/src/partitioned.rs | 211 ++++++++++++++++++++++++ ipa-metrics/src/producer.rs | 37 +++++ ipa-metrics/src/store.rs | 201 +++++++++++++++++++++++ 12 files changed, 1322 insertions(+), 1 deletion(-) create mode 100644 ipa-metrics/Cargo.toml create mode 100644 ipa-metrics/src/collector.rs create mode 100644 ipa-metrics/src/context.rs create mode 100644 ipa-metrics/src/controller.rs create mode 100644 ipa-metrics/src/key.rs create mode 100644 ipa-metrics/src/kind.rs create mode 100644 
ipa-metrics/src/label.rs create mode 100644 ipa-metrics/src/lib.rs create mode 100644 ipa-metrics/src/partitioned.rs create mode 100644 ipa-metrics/src/producer.rs create mode 100644 ipa-metrics/src/store.rs diff --git a/Cargo.toml b/Cargo.toml index 377020368..1aed2b4b7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [workspace] resolver = "2" -members = ["ipa-core", "ipa-step", "ipa-step-derive", "ipa-step-test"] +members = ["ipa-core", "ipa-step", "ipa-step-derive", "ipa-step-test", "ipa-metrics"] [profile.release] incremental = true diff --git a/ipa-metrics/Cargo.toml b/ipa-metrics/Cargo.toml new file mode 100644 index 000000000..ebaeb9473 --- /dev/null +++ b/ipa-metrics/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "ipa-metrics" +version = "0.1.0" +edition = "2021" + +[features] +default = [] +# support metric partitioning +partitions = [] + +[dependencies] +# crossbeam channels are faster than std +crossbeam-channel = "0.5" +# This crate uses raw entry API that is unstable in stdlib +hashbrown = "0.15" +# Fast non-collision-resistant hashing +rustc-hash = "2.0.0" +# logging +tracing = "0.1" + diff --git a/ipa-metrics/src/collector.rs b/ipa-metrics/src/collector.rs new file mode 100644 index 000000000..bf0a6bc06 --- /dev/null +++ b/ipa-metrics/src/collector.rs @@ -0,0 +1,162 @@ +use std::cell::RefCell; + +use crossbeam_channel::{Receiver, Select}; + +use crate::{ControllerCommand, MetricsStore}; + +thread_local! { + /// Collector that is installed in a thread. It is responsible for receiving metrics from + /// all threads and aggregating them. + static COLLECTOR: RefCell> = const { RefCell::new(None) } +} + +pub struct Installed; + +impl Installed { + pub fn block_until_shutdown(&self) -> MetricsStore { + MetricsCollector::with_current_mut(|c| { + c.event_loop(); + + std::mem::take(&mut c.local_store) + }) + } +} + +pub struct MetricsCollector { + pub(super) rx: Receiver, + pub(super) local_store: MetricsStore, + pub(super) command_rx: Receiver, +} + +impl MetricsCollector { + pub fn install(self) -> Installed { + COLLECTOR.with_borrow_mut(|c| { + assert!(c.replace(self).is_none(), "Already initialized"); + }); + + Installed + } + + fn event_loop(&mut self) { + let mut select = Select::new(); + let data_idx = select.recv(&self.rx); + let command_idx = select.recv(&self.command_rx); + + loop { + let next_op = select.select(); + match next_op.index() { + i if i == data_idx => match next_op.recv(&self.rx) { + Ok(store) => { + tracing::trace!("Collector received more data: {store:?}"); + println!("Collector received more data: {store:?}"); + self.local_store.merge(store) + } + Err(e) => { + tracing::debug!("No more threads collecting metrics. 
Disconnected: {e}"); + select.remove(data_idx); + } + }, + i if i == command_idx => match next_op.recv(&self.command_rx) { + Ok(ControllerCommand::Snapshot(tx)) => { + tracing::trace!("Snapshot request received"); + println!("snapshot request received"); + tx.send(self.local_store.clone()).unwrap(); + } + Ok(ControllerCommand::Stop(tx)) => { + tx.send(()).unwrap(); + break; + } + Err(e) => { + tracing::debug!("Metric controller is disconnected: {e}"); + break; + } + }, + _ => unreachable!(), + } + } + } + + pub fn with_current_mut T, T>(f: F) -> T { + COLLECTOR.with_borrow_mut(|c| { + let collector = c.as_mut().expect("Collector is installed"); + f(collector) + }) + } +} + +impl Drop for MetricsCollector { + fn drop(&mut self) { + tracing::debug!("Collector is dropped"); + } +} + +#[cfg(test)] +mod tests { + use std::{ + thread, + thread::{Scope, ScopedJoinHandle}, + }; + + use crate::{counter, installer, producer::Producer, thread_installer}; + + struct MeteredScope<'scope, 'env: 'scope>(&'scope Scope<'scope, 'env>, Producer); + + impl<'scope, 'env: 'scope> MeteredScope<'scope, 'env> { + fn spawn(&self, f: F) -> ScopedJoinHandle<'scope, T> + where + F: FnOnce() -> T + Send + 'scope, + T: Send + 'scope, + { + let producer = self.1.clone(); + + self.0.spawn(move || { + producer.install(); + let r = f(); + let _ = producer.drop_handle(); + + r + }) + } + } + + trait IntoMetered<'scope, 'env: 'scope> { + fn metered(&'scope self, meter: Producer) -> MeteredScope<'scope, 'env>; + } + + impl<'scope, 'env: 'scope> IntoMetered<'scope, 'env> for Scope<'scope, 'env> { + fn metered(&'scope self, meter: Producer) -> MeteredScope<'scope, 'env> { + MeteredScope(self, meter) + } + } + + #[test] + fn start_stop() { + let (collector, producer, controller) = installer(); + let handle = thread::spawn(|| { + let store = collector.install().block_until_shutdown(); + store.counter_value(counter!("foo")) + }); + + thread::scope(move |s| { + let s = s.metered(producer); + s.spawn(|| counter!("foo", 3)).join().unwrap(); + s.spawn(|| counter!("foo", 5)).join().unwrap(); + controller.stop().unwrap(); + }); + + assert_eq!(8, handle.join().unwrap()) + } + + #[test] + fn with_thread() { + let (producer, controller, handle) = thread_installer().unwrap(); + thread::scope(move |s| { + let s = s.metered(producer); + s.spawn(|| counter!("baz", 4)); + s.spawn(|| counter!("bar", 1)); + s.spawn(|| controller.stop().unwrap()); + }); + + handle.join().unwrap() // Collector thread should be terminated by now + } +} diff --git a/ipa-metrics/src/context.rs b/ipa-metrics/src/context.rs new file mode 100644 index 000000000..2680c0876 --- /dev/null +++ b/ipa-metrics/src/context.rs @@ -0,0 +1,151 @@ +use std::{cell::RefCell, mem}; + +use crossbeam_channel::Sender; + +use crate::MetricsStore; + +thread_local! { + pub(crate) static METRICS_CTX: RefCell = const { RefCell::new(MetricsContext::new()) } +} + +#[macro_export] +macro_rules! counter { + ($metric:expr, $val:expr $(, $l:expr => $v:expr)*) => {{ + let name = $crate::metric_name!($metric $(, $l => $v)*); + $crate::MetricsCurrentThreadContext::store_mut(|store| store.counter(&name).inc($val)) + }}; + ($metric:expr $(, $l:expr => $v:expr)*) => {{ + $crate::metric_name!($metric $(, $l => $v)*) + }}; +} + +/// Provides access to the metric store associated with the current thread. +/// If there is no store associated with the current thread, it will create a new one. 
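+///
+/// A minimal usage sketch (`tx` is assumed to be the sender obtained when the
+/// collector/producer pair is created):
+/// ```ignore
+/// CurrentThreadContext::init(tx);
+/// counter!("bytes.sent", 42);
+/// CurrentThreadContext::flush();
+/// ```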
+pub struct CurrentThreadContext;
+
+impl CurrentThreadContext {
+    pub fn init(tx: Sender<MetricsStore>) {
+        METRICS_CTX.with_borrow_mut(|ctx| ctx.init(tx));
+    }
+
+    pub fn flush() {
+        METRICS_CTX.with_borrow_mut(|ctx| ctx.flush());
+    }
+
+    pub fn is_connected() -> bool {
+        METRICS_CTX.with_borrow(|ctx| ctx.is_connected())
+    }
+
+    pub fn store<F: FnOnce(&MetricsStore) -> T, T>(f: F) -> T {
+        METRICS_CTX.with_borrow(|ctx| f(ctx.store()))
+    }
+
+    pub fn store_mut<F: FnOnce(&mut MetricsStore) -> T, T>(f: F) -> T {
+        METRICS_CTX.with_borrow_mut(|ctx| f(ctx.store_mut()))
+    }
+}
+
+/// This context is used inside thread-local storage,
+/// so it must be wrapped inside [`std::cell::RefCell`].
+///
+/// For single-threaded applications, it is possible
+/// to use it without connecting to the collector thread.
+pub struct MetricsContext {
+    store: MetricsStore,
+    /// Handle to send metrics to the collector thread
+    tx: Option<Sender<MetricsStore>>,
+}
+
+impl Default for MetricsContext {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl MetricsContext {
+    pub const fn new() -> Self {
+        Self {
+            store: MetricsStore::new(),
+            tx: None,
+        }
+    }
+
+    /// Connects this context to the collector thread.
+    /// The sender will be used to send data from this thread.
+    fn init(&mut self, tx: Sender<MetricsStore>) {
+        assert!(self.tx.is_none(), "Already connected");
+
+        self.tx = Some(tx);
+    }
+
+    pub fn store(&self) -> &MetricsStore {
+        &self.store
+    }
+
+    pub fn store_mut(&mut self) -> &mut MetricsStore {
+        &mut self.store
+    }
+
+    fn is_connected(&self) -> bool {
+        self.tx.is_some()
+    }
+
+    fn flush(&mut self) {
+        if self.is_connected() {
+            let store = mem::take(&mut self.store);
+            match self.tx.as_ref().unwrap().send(store) {
+                Ok(_) => {}
+                Err(e) => {
+                    tracing::warn!("Failed to send metrics to the collector: {e}");
+                }
+            }
+        } else {
+            tracing::warn!("MetricsContext is not connected");
+        }
+    }
+}
+
+impl Drop for MetricsContext {
+    fn drop(&mut self) {
+        if !self.store.is_empty() {
+            tracing::warn!(
+                "Non-empty metric store is dropped: {} metrics lost",
+                self.store.len()
+            );
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::MetricsContext;
+
+    /// Each thread has its local store by default, and it is exclusive to it
+    #[test]
+    #[cfg(feature = "partitions")]
+    fn local_store() {
+        use crate::context::CurrentThreadContext;
+
+        crate::set_partition(0xdeadbeef);
+        counter!("foo", 7);
+
+        std::thread::spawn(|| {
+            counter!("foo", 1);
+            counter!("foo", 5);
+            assert_eq!(
+                5,
+                CurrentThreadContext::store(|store| store.counter_value(&counter!("foo")))
+            );
+        });
+
+        assert_eq!(
+            7,
+            CurrentThreadContext::store(|store| store.counter_value(&counter!("foo")))
+        );
+    }
+
+    #[test]
+    fn default() {
+        assert_eq!(0, MetricsContext::default().store().len())
+    }
+}
diff --git a/ipa-metrics/src/controller.rs b/ipa-metrics/src/controller.rs
new file mode 100644
index 000000000..a70802f38
--- /dev/null
+++ b/ipa-metrics/src/controller.rs
@@ -0,0 +1,30 @@
+use crossbeam_channel::Sender;
+
+use crate::MetricsStore;
+
+pub enum Command {
+    Snapshot(Sender<MetricsStore>),
+    Stop(Sender<()>),
+}
+
+pub struct Controller {
+    pub(super) tx: Sender<Command>,
+}
+
+impl Controller {
+    pub fn snapshot(&self) -> Result<MetricsStore, String> {
+        let (tx, rx) = crossbeam_channel::bounded(0);
+        self.tx
+            .send(Command::Snapshot(tx))
+            .map_err(|e| format!("An error occurred while requesting metrics snapshot: {e}"))?;
+        rx.recv().map_err(|e| format!("Disconnected channel: {e}"))
+    }
+
+    pub fn stop(self) -> Result<(), String> {
+        let (tx, rx) = crossbeam_channel::bounded(0);
+        self.tx
+            .send(Command::Stop(tx))
+            .map_err(|e| format!("An error occurred while requesting the collector to stop: 
{e}"))?; + rx.recv().map_err(|e| format!("Disconnected channel: {e}")) + } +} diff --git a/ipa-metrics/src/key.rs b/ipa-metrics/src/key.rs new file mode 100644 index 000000000..dec06a108 --- /dev/null +++ b/ipa-metrics/src/key.rs @@ -0,0 +1,287 @@ +//! Metric names supported by this crate. +//! +//! Providing a good use for metrics is a tradeoff between +//! performance and ergonomics. Flexible metric engines support +//! dynamic names, like "bytes.sent.{ip}" or "cpu.{core}.instructions" +//! but that comes with a significant performance cost. +//! String interning helps to mitigate this on the storage site +//! but callsites need to allocate those at every call. +//! +//! IPA metrics can be performance sensitive. There are counters +//! incremented on every send and receive operation, so they need +//! to be fast. For this reason, dynamic metric names are not supported. +//! Metric name can only be a string, known at compile time. +//! +//! However, it is not flexible enough. Most metrics have dimensions +//! attached to them. IPA example is `bytes.sent` metric with step breakdown. +//! It is very useful to know the required throughput per circuit. +//! +//! This metric engine supports up to 5 dimensions attached to every metric, +//! again trying to strike a good balance between performance and usability. + +use std::{ + array, + hash::{Hash, Hasher}, + iter, + iter::repeat, +}; + +pub use Name as MetricName; +pub(super) use OwnedName as OwnedMetricName; + +use crate::label::{Label, OwnedLabel, MAX_LABELS}; + +#[macro_export] +macro_rules! metric_name { + // Match when two key-value pairs are provided + // TODO: enforce uniqueness at compile time + ($metric:expr, $l1:expr => $v1:expr, $l2:expr => $v2:expr) => {{ + use $crate::UniqueElements; + let labels = [ + $crate::Label { + name: $l1, + val: $v1, + }, + $crate::Label { + name: $l2, + val: $v2, + }, + ] + .enforce_unique(); + $crate::MetricName::from_parts($metric, labels) + }}; + // Match when one key-value pair is provided + ($metric:expr, $l1:expr => $v1:expr) => {{ + $crate::MetricName::from_parts( + $metric, + [$crate::Label { + name: $l1, + val: $v1, + }], + ) + }}; + // Match when no key-value pairs are provided + ($metric:expr) => {{ + $crate::MetricName::from_parts($metric, []) + }}; +} + +/// Metric name that is created at callsite on each metric invocation. +/// For this reason, it is performance sensitive - it tries to borrow +/// whatever it can from callee stack. +#[derive(Debug, PartialEq)] +pub struct Name<'lv, const LABELS: usize = 0> { + pub(super) key: &'static str, + labels: [Label<'lv>; LABELS], +} + +impl<'lv, const LABELS: usize> Name<'lv, LABELS> { + pub fn from_parts>(key: I, labels: [Label<'lv>; LABELS]) -> Self { + assert!( + LABELS <= MAX_LABELS, + "Maximum 5 labels per metric is supported" + ); + + Self { + key: key.into(), + labels, + } + } + + /// [`ToOwned`] trait does not work because of + /// extra [`Borrow`] requirement + pub(super) fn to_owned(&self) -> OwnedName { + let labels: [_; 5] = array::from_fn(|i| { + if i < self.labels.len() { + Some(self.labels[i].to_owned()) + } else { + None + } + }); + + OwnedName { + key: self.key, + labels, + } + } +} + +/// Same as [`Name`], but intended for internal use. This is an owned +/// version of it, that does not borrow anything from outside. +/// This is the key inside metric stores which are simple hashmaps. 
+#[derive(Debug, Clone)] +pub struct OwnedName { + pub(super) key: &'static str, + labels: [Option; 5], +} + +impl OwnedName { + pub fn key(&self) -> &'static str { + self.key + } + + pub fn labels(&self) -> impl Iterator { + self.labels.iter().filter_map(|l| l.as_ref()) + } +} + +impl Hash for Name<'_, LABELS> { + fn hash(&self, state: &mut H) { + state.write(self.key.as_bytes()); + for label in &self.labels { + label.hash(state) + } + } +} + +impl From<&'static str> for Name<'_, 0> { + fn from(value: &'static str) -> Self { + Self { + key: value, + labels: [], + } + } +} + +pub trait UniqueElements { + fn enforce_unique(self) -> Self; +} + +impl UniqueElements for [Label<'_>; 2] { + fn enforce_unique(self) -> Self { + if self[0].name == self[1].name { + panic!("label names must be unique") + } + + self + } +} + +impl<'a, const LABELS: usize> PartialEq> for OwnedName { + fn eq(&self, other: &Name<'a, LABELS>) -> bool { + self.key == other.key + && iter::zip( + &self.labels, + other.labels.iter().map(Some).chain(repeat(None)), + ) + .all(|(a, b)| match (a, b) { + (Some(a), Some(b)) => a.as_borrowed() == *b, + (None, None) => true, + _ => false, + }) + } +} + +impl PartialEq for OwnedName { + fn eq(&self, other: &OwnedName) -> bool { + self.key == other.key + && iter::zip(&self.labels, &other.labels).all(|(a, b)| match (a, b) { + (Some(a), Some(b)) => a == b, + (None, None) => true, + _ => false, + }) + } +} + +impl Eq for OwnedName {} + +impl Hash for OwnedName { + fn hash(&self, state: &mut H) { + state.write(self.key.as_bytes()); + for label in self.labels.iter().flatten() { + label.hash(state) + } + } +} + +#[cfg(test)] +pub fn compute_hash(value: V) -> u64 { + let mut hasher = std::hash::DefaultHasher::default(); + value.hash(&mut hasher); + + hasher.finish() +} + +#[cfg(test)] +mod tests { + + use crate::{ + key::{compute_hash, Name}, + label::Label, + }; + + #[test] + fn eq() { + let name = Name::from("foo"); + assert_eq!(name.to_owned(), name); + } + + #[test] + fn hash_eq() { + let a = Name::from("foo"); + let b = Name::from("foo"); + assert_eq!(compute_hash(&a), compute_hash(b)); + assert_eq!(compute_hash(&a), compute_hash(a.to_owned())); + } + + #[test] + fn not_eq() { + let foo = Name::from("foo"); + let bar = Name::from("bar"); + assert_ne!(foo.to_owned(), bar); + } + + #[test] + fn hash_not_eq() { + let foo = Name::from("foo"); + let bar = Name::from("bar"); + assert_ne!(compute_hash(&foo), compute_hash(&bar)); + assert_ne!(compute_hash(foo), compute_hash(bar.to_owned())); + } + + #[test] + #[should_panic(expected = "Maximum 5 labels per metric is supported")] + fn more_than_5_labels() { + let _ = Name::from_parts( + "foo", + [ + Label { + name: "label_1", + val: &1, + }, + Label { + name: "label_2", + val: &1, + }, + Label { + name: "label_3", + val: &1, + }, + Label { + name: "label_4", + val: &1, + }, + Label { + name: "label_5", + val: &1, + }, + Label { + name: "label_6", + val: &1, + }, + ], + ); + } + + #[test] + fn eq_is_consistent() { + let a_name = metric_name!("foo", "label_1" => &1); + let b_name = metric_name!("foo", "label_1" => &1, "label_2" => &2); + + assert_eq!(a_name, a_name); + assert_eq!(a_name.to_owned(), a_name); + + assert_ne!(b_name.to_owned(), a_name); + assert_ne!(a_name.to_owned(), b_name); + } +} diff --git a/ipa-metrics/src/kind.rs b/ipa-metrics/src/kind.rs new file mode 100644 index 000000000..b6abf7e4e --- /dev/null +++ b/ipa-metrics/src/kind.rs @@ -0,0 +1,6 @@ +//! Different metric types supported by this crate. +//! 
Currently, only counters are supported. +//! TODO: add more + +/// Counters are simple 8 byte values. +pub type CounterValue = u64; diff --git a/ipa-metrics/src/label.rs b/ipa-metrics/src/label.rs new file mode 100644 index 000000000..d35680e08 --- /dev/null +++ b/ipa-metrics/src/label.rs @@ -0,0 +1,160 @@ +use std::{ + fmt::{Debug, Display, Formatter}, + hash::{Hash, Hasher}, +}; + +pub use Value as LabelValue; + +pub const MAX_LABELS: usize = 5; + +/// Dimension value (or label value) must be sendable to another thread +/// and there must be a way to show it +pub trait Value: Display + Send { + /// Creates a unique hash for this value. + /// It is easy to create collisions, so better avoid them, + /// by assigning a unique integer to each value + fn hash(&self) -> u64; + + /// Creates an owned copy of this value. Dynamic dispatch + /// is required, because values are stored in a generic store + /// that can't be specialized for value types. + fn boxed(&self) -> Box; +} + +impl LabelValue for u32 { + fn hash(&self) -> u64 { + u64::from(*self) + } + + fn boxed(&self) -> Box { + Box::new(*self) + } +} + +#[derive()] +pub struct Label<'lv> { + pub name: &'static str, + pub val: &'lv dyn Value, +} + +impl Label<'_> { + pub fn to_owned(&self) -> OwnedLabel { + OwnedLabel { + name: self.name, + val: self.val.boxed(), + } + } +} + +impl Debug for Label<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Label") + .field("name", &self.name) + .field("val", &format!("{}", self.val)) + .finish() + } +} + +impl Hash for Label<'_> { + fn hash(&self, state: &mut H) { + state.write(self.name.as_bytes()); + state.write_u64(self.val.hash()); + } +} + +impl PartialEq for Label<'_> { + fn eq(&self, other: &Self) -> bool { + // name check should be fast - just pointer comparison. 
+ // val check is more involved with dynamic dispatch, so we can consider + // making label immutable and storing a hash of the value in place + self.name == other.name && self.val.hash() == other.val.hash() + } +} + +pub struct OwnedLabel { + pub name: &'static str, + pub val: Box, +} + +impl Clone for OwnedLabel { + fn clone(&self) -> Self { + Self { + name: self.name, + val: self.val.boxed(), + } + } +} + +impl OwnedLabel { + pub fn as_borrowed(&self) -> Label<'_> { + Label { + name: self.name, + val: self.val.as_ref(), + } + } + + pub fn name(&self) -> &'static str { + self.name + } + + pub fn str_value(&self) -> String { + self.val.to_string() + } +} + +impl Hash for OwnedLabel { + fn hash(&self, state: &mut H) { + self.as_borrowed().hash(state) + } +} + +impl Debug for OwnedLabel { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("OwnedLabel") + .field("name", &self.name) + .field("val", &format!("{}", self.val)) + .finish() + } +} + +impl PartialEq for OwnedLabel { + fn eq(&self, other: &Self) -> bool { + self.name == other.name && self.val.hash() == other.val.hash() + } +} + +#[cfg(test)] +mod tests { + + use crate::{key::compute_hash, metric_name}; + + #[test] + fn one_label() { + let foo_1 = metric_name!("foo", "l1" => &1); + let foo_2 = metric_name!("foo", "l1" => &2); + + assert_ne!(foo_1.to_owned(), foo_2); + assert_ne!(compute_hash(&foo_1), compute_hash(&foo_2)); + assert_ne!(foo_2.to_owned(), foo_1); + + assert_eq!(compute_hash(&foo_1), compute_hash(foo_1.to_owned())) + } + + #[test] + #[should_panic(expected = "label names must be unique")] + fn unique() { + metric_name!("foo", "l1" => &1, "l1" => &0); + } + + #[test] + fn non_commutative() { + assert_ne!( + compute_hash(&metric_name!("foo", "l1" => &1, "l2" => &0)), + compute_hash(&metric_name!("foo", "l1" => &0, "l2" => &1)), + ); + assert_ne!( + compute_hash(&metric_name!("foo", "l1" => &1)), + compute_hash(&metric_name!("foo", "l1" => &1, "l2" => &1)), + ); + } +} diff --git a/ipa-metrics/src/lib.rs b/ipa-metrics/src/lib.rs new file mode 100644 index 000000000..87786e91d --- /dev/null +++ b/ipa-metrics/src/lib.rs @@ -0,0 +1,56 @@ +mod collector; +mod context; +mod controller; +mod key; +mod kind; +mod label; +#[cfg(feature = "partitions")] +mod partitioned; +mod producer; +mod store; + +use std::{io, thread::JoinHandle}; + +pub use collector::MetricsCollector; +pub use context::{CurrentThreadContext as MetricsCurrentThreadContext, MetricsContext}; +pub use controller::{Command as ControllerCommand, Controller as MetricsCollectorController}; +pub use key::{MetricName, OwnedName, UniqueElements}; +pub use label::{Label, LabelValue}; +#[cfg(feature = "partitions")] +pub use partitioned::{ + current_partition, set_or_unset_partition, set_partition, Partition as MetricPartition, + PartitionedStore as MetricsStore, +}; +pub use producer::Producer as MetricsProducer; +#[cfg(not(feature = "partitions"))] +pub use store::Store as MetricsStore; + +pub fn installer() -> ( + MetricsCollector, + MetricsProducer, + MetricsCollectorController, +) { + let (command_tx, command_rx) = crossbeam_channel::unbounded(); + let (tx, rx) = crossbeam_channel::unbounded(); + ( + MetricsCollector { + rx, + local_store: MetricsStore::default(), + command_rx, + }, + MetricsProducer { tx }, + MetricsCollectorController { tx: command_tx }, + ) +} + +pub fn thread_installer( +) -> io::Result<(MetricsProducer, MetricsCollectorController, JoinHandle<()>)> { + let (collector, producer, controller) = installer(); + let handle = 
std::thread::Builder::new() + .name("metric-collector".to_string()) + .spawn(|| { + collector.install().block_until_shutdown(); + })?; + + Ok((producer, controller, handle)) +} diff --git a/ipa-metrics/src/partitioned.rs b/ipa-metrics/src/partitioned.rs new file mode 100644 index 000000000..6e97b2446 --- /dev/null +++ b/ipa-metrics/src/partitioned.rs @@ -0,0 +1,211 @@ +use std::{borrow::Borrow, cell::Cell}; + +use hashbrown::hash_map::Entry; +use rustc_hash::FxBuildHasher; + +use crate::{ + kind::CounterValue, + store::{CounterHandle, Store}, + MetricName, +}; + +/// Each partition is a unique 16 byte value. +pub type Partition = u128; + +pub fn set_partition(new: Partition) { + PARTITION.set(Some(new)); +} + +pub fn set_or_unset_partition(new: Option) { + PARTITION.set(new); +} + +pub fn current_partition() -> Option { + PARTITION.get() +} + +thread_local! { + static PARTITION: Cell> = Cell::new(None); +} + +/// Provides the same functionality as [`Store`], but partitioned +/// across many dimensions. There is an extra price for it, so +/// don't use it, unless you need it. +/// The dimension is set through [`std::thread::LocalKey`], so +/// each thread can set only one dimension at a time. +/// +/// The API of this struct will match [`Store`] as they +/// can be used interchangeably. +#[derive(Clone, Debug)] +pub struct PartitionedStore { + inner: hashbrown::HashMap, + default_store: Store, +} + +impl Default for PartitionedStore { + fn default() -> Self { + Self::new() + } +} + +impl PartitionedStore { + pub const fn new() -> Self { + Self { + inner: hashbrown::HashMap::with_hasher(FxBuildHasher), + default_store: Store::new(), + } + } + + pub fn with_current_partition T, T>(&mut self, f: F) -> T { + let mut store = self.get_mut(current_partition()); + f(&mut store) + } + + pub fn with_partition T, T>( + &self, + partition: Partition, + f: F, + ) -> Option { + let store = self.inner.get(&partition); + store.map(f) + } + + pub fn with_partition_mut T, T>( + &mut self, + partition: Partition, + f: F, + ) -> T { + let mut store = self.get_mut(Some(partition)); + f(&mut store) + } + + pub fn is_empty(&self) -> bool { + self.inner.is_empty() && self.default_store.is_empty() + } + + pub fn merge(&mut self, other: Self) { + for (partition, store) in other.inner { + self.get_mut(Some(partition)).merge(store); + } + self.default_store.merge(other.default_store); + } + + pub fn counter_value>(&self, name: B) -> CounterValue { + let name = name.borrow(); + if let Some(partition) = current_partition() { + self.inner + .get(&partition) + .map(|store| store.counter_value(name)) + .unwrap_or_default() + } else { + self.default_store.counter_value(name) + } + } + + pub fn counter( + &mut self, + key: &MetricName<'_, LABELS>, + ) -> CounterHandle<'_, LABELS> { + if let Some(partition) = current_partition() { + self.inner + .entry(partition) + .or_insert_with(|| Store::default()) + .counter(key) + } else { + self.default_store.counter(key) + } + } + + pub fn len(&self) -> usize { + self.inner.len() + self.default_store.len() + } + + fn get_mut(&mut self, partition: Option) -> &mut Store { + if let Some(v) = partition { + match self.inner.entry(v) { + Entry::Occupied(entry) => entry.into_mut(), + Entry::Vacant(entry) => entry.insert(Store::default()), + } + } else { + &mut self.default_store + } + } +} + +#[cfg(test)] +mod tests { + use crate::{ + metric_name, + partitioned::{set_partition, PartitionedStore}, + }; + + #[test] + fn unique_partition() { + let metric = metric_name!("foo"); + let mut store = 
PartitionedStore::new(); + store.with_partition_mut(1, |store| { + store.counter(&metric).inc(1); + }); + store.with_partition_mut(5, |store| { + store.counter(&metric).inc(5); + }); + + assert_eq!( + 5, + store.with_partition_mut(5, |store| store.counter(&metric).get()) + ); + assert_eq!( + 1, + store.with_partition_mut(1, |store| store.counter(&metric).get()) + ); + assert_eq!( + 0, + store.with_partition_mut(10, |store| store.counter(&metric).get()) + ); + } + + #[test] + fn current_partition() { + let metric = metric_name!("foo"); + let mut store = PartitionedStore::new(); + set_partition(4); + + store.with_current_partition(|store| { + store.counter(&metric).inc(1); + }); + store.with_current_partition(|store| { + store.counter(&metric).inc(5); + }); + + assert_eq!( + 6, + store.with_current_partition(|store| store.counter(&metric).get()) + ); + } + + #[test] + fn empty() { + let mut store = PartitionedStore::new(); + store.with_current_partition(|store| { + store.counter(&metric_name!("foo")).inc(1); + }); + + assert!(!store.is_empty()); + } + + #[test] + fn len() { + let mut store = PartitionedStore::new(); + store.with_current_partition(|store| { + store.counter(&metric_name!("foo")).inc(1); + }); + set_partition(4); + store.with_current_partition(|store| { + store.counter(&metric_name!("foo")).inc(1); + }); + + // one metric in partition 4, another one in default. Even that they are the same, + // partitioned store cannot distinguish between them + assert_eq!(2, store.len()); + } +} diff --git a/ipa-metrics/src/producer.rs b/ipa-metrics/src/producer.rs new file mode 100644 index 000000000..27925bed2 --- /dev/null +++ b/ipa-metrics/src/producer.rs @@ -0,0 +1,37 @@ +use crossbeam_channel::Sender; + +use crate::{context::CurrentThreadContext, MetricsStore}; + +#[derive(Clone)] +pub struct Producer { + pub(super) tx: Sender, +} + +impl Producer { + pub fn install(&self) { + CurrentThreadContext::init(self.tx.clone()); + } + + /// Returns a drop handle that should be used when thread is stopped. + /// In an ideal world, a destructor on [`MetricsContext`] could do this, + /// but as pointed in [`LocalKey`] documentation, deadlocks are possible + /// if another TLS storage is accessed at destruction time. + /// + /// I actually ran into this problem with crossbeam channels. Send operation + /// requires access to `thread::current` and that panics at runtime if called + /// from inside `Drop`. + /// + /// [`LocalKey`]: + pub fn drop_handle(&self) -> ProducerDropHandle { + ProducerDropHandle + } +} + +#[must_use] +pub struct ProducerDropHandle; + +impl Drop for ProducerDropHandle { + fn drop(&mut self) { + CurrentThreadContext::flush() + } +} diff --git a/ipa-metrics/src/store.rs b/ipa-metrics/src/store.rs new file mode 100644 index 000000000..6a18b7b6b --- /dev/null +++ b/ipa-metrics/src/store.rs @@ -0,0 +1,201 @@ +use std::{borrow::Borrow, hash::BuildHasher}; + +use hashbrown::hash_map::RawEntryMut; +use rustc_hash::FxBuildHasher; + +use crate::{ + key::{OwnedMetricName, OwnedName}, + kind::CounterValue, + MetricName, +}; + +/// A basic store. Currently only supports counters. +#[derive(Clone, Debug)] +pub struct Store { + // Counters and other metrics are stored to optimize writes. That means, one lookup + // per write. 
The cost of assembling the total count across all dimensions is absorbed + // by readers + counters: hashbrown::HashMap, +} + +impl Default for Store { + fn default() -> Self { + Self::new() + } +} + +impl Store { + pub const fn new() -> Self { + Self { + counters: hashbrown::HashMap::with_hasher(FxBuildHasher), + } + } + + pub(crate) fn merge(&mut self, other: Self) { + for (k, v) in other.counters { + let hash_builder = self.counters.hasher(); + let hash = hash_builder.hash_one(&k); + *self + .counters + .raw_entry_mut() + .from_hash(hash, |other| other.eq(&k)) + .or_insert(k, 0) + .1 += v; + } + } + + pub fn is_empty(&self) -> bool { + self.counters.is_empty() + } +} + +impl Store { + pub fn counter( + &mut self, + key: &MetricName<'_, LABELS>, + ) -> CounterHandle<'_, LABELS> { + let hash_builder = self.counters.hasher(); + let hash = hash_builder.hash_one(key); + let entry = self + .counters + .raw_entry_mut() + .from_hash(hash, |key_found| key_found.eq(key)); + match entry { + RawEntryMut::Occupied(slot) => CounterHandle { + val: slot.into_mut(), + }, + RawEntryMut::Vacant(slot) => { + let (_, val) = slot.insert_hashed_nocheck(hash, key.to_owned(), Default::default()); + CounterHandle { val } + } + } + } + + /// Returns the value for the specified metric across all dimensions. + /// The cost of this operation is `O(N*M)` where `N` - number of unique metrics + /// and `M` - number of all dimensions across all metrics. + /// + /// Note that the cost can be improved if it ever becomes a bottleneck by + /// creating a specialized two-level map (metric -> label -> value). + pub fn counter_value<'a, B: Borrow>>(&'a self, key: B) -> CounterValue { + let key = key.borrow(); + let mut answer = 0; + for (metric, value) in &self.counters { + if metric.key == key.key { + answer += value + } + } + + answer + } + + pub fn counters(&self) -> impl Iterator { + self.counters.iter().map(|(key, value)| (key, *value)) + } + + pub fn len(&self) -> usize { + self.counters.len() + } +} + +pub struct CounterHandle<'a, const LABELS: usize> { + val: &'a mut CounterValue, +} + +impl CounterHandle<'_, LABELS> { + pub fn inc(&mut self, inc: CounterValue) { + *self.val += inc; + } + + pub fn get(&self) -> CounterValue { + *self.val + } +} + +#[cfg(test)] +mod tests { + use std::hash::{DefaultHasher, Hash, Hasher}; + + use crate::{metric_name, store::Store, LabelValue}; + + impl LabelValue for &'static str { + fn hash(&self) -> u64 { + // TODO: use fast hashing here + let mut hasher = DefaultHasher::default(); + Hash::hash(self, &mut hasher); + + hasher.finish() + } + + fn boxed(&self) -> Box { + Box::new(*self) + } + } + + #[test] + fn counter() { + let mut store = Store::default(); + let name = metric_name!("foo"); + { + let mut handle = store.counter(&name); + assert_eq!(0, handle.get()); + handle.inc(3); + assert_eq!(3, handle.get()); + } + + { + store.counter(&name).inc(0); + assert_eq!(3, store.counter(&name).get()); + } + } + + #[test] + fn with_labels() { + let mut store = Store::default(); + let valid_name = metric_name!("foo", "h1" => &1, "h2" => &"2"); + let wrong_name = metric_name!("foo", "h1" => &2, "h2" => &"2"); + store.counter(&valid_name).inc(2); + + assert_eq!(2, store.counter(&valid_name).get()); + assert_eq!(0, store.counter(&wrong_name).get()); + } + + #[test] + fn merge() { + let mut store1 = Store::default(); + let mut store2 = Store::default(); + let foo = metric_name!("foo", "h1" => &1, "h2" => &"2"); + let bar = metric_name!("bar", "h2" => &"2"); + let baz = metric_name!("baz"); + 
store1.counter(&foo).inc(2); + store2.counter(&foo).inc(1); + + store1.counter(&bar).inc(7); + store2.counter(&baz).inc(3); + + store1.merge(store2); + + assert_eq!(3, store1.counter(&foo).get()); + assert_eq!(7, store1.counter(&bar).get()); + assert_eq!(3, store1.counter(&baz).get()); + } + + #[test] + fn counter_value() { + let mut store = Store::default(); + store + .counter(&metric_name!("foo", "h1" => &1, "h2" => &"1")) + .inc(1); + store + .counter(&metric_name!("foo", "h1" => &1, "h2" => &"2")) + .inc(1); + store + .counter(&metric_name!("foo", "h1" => &2, "h2" => &"1")) + .inc(1); + store + .counter(&metric_name!("foo", "h1" => &2, "h2" => &"2")) + .inc(1); + + assert_eq!(4, store.counter_value(&metric_name!("foo"))); + } +} From e2d82c7107e15450f9a9f1fe2c1954ff4d8c37ae Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Wed, 16 Oct 2024 23:36:58 -0700 Subject: [PATCH 142/191] Add compile time function to compute the size of AdditiveShare The benefit of using it is it does not require trait bounds math as `Serializable` trait and it can be used in const context --- .../replicated/semi_honest/additive_share.rs | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/ipa-core/src/secret_sharing/replicated/semi_honest/additive_share.rs b/ipa-core/src/secret_sharing/replicated/semi_honest/additive_share.rs index faca5d570..78d393a07 100644 --- a/ipa-core/src/secret_sharing/replicated/semi_honest/additive_share.rs +++ b/ipa-core/src/secret_sharing/replicated/semi_honest/additive_share.rs @@ -76,6 +76,13 @@ impl, const N: usize> AdditiveShare { >::Array::ZERO_ARRAY, >::Array::ZERO_ARRAY, ); + + /// Returns the size this instance would occupy on the wire or disk. + /// In other words, it does not include padding/alignment. + #[must_use] + pub const fn size() -> usize { + 2 * <>::Array as Serializable>::Size::USIZE + } } impl AdditiveShare { @@ -636,6 +643,14 @@ mod tests { mult_by_constant_test_case((0, 0, 0), 2, 0); } + #[test] + fn test_size() { + const FP31_SZ: usize = AdditiveShare::::size(); + const VEC_FP32: usize = AdditiveShare::::size(); + assert_eq!(2, FP31_SZ); + assert_eq!(256, VEC_FP32); + } + impl Arbitrary for AdditiveShare where V: Vectorizable>, From e130a30790d614aa2fd44548714446883fa0387a Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Thu, 17 Oct 2024 10:40:39 -0700 Subject: [PATCH 143/191] Improve documentation and API for partitioned store --- ipa-metrics/src/collector.rs | 2 +- ipa-metrics/src/context.rs | 29 +++++- ipa-metrics/src/key.rs | 13 +++ ipa-metrics/src/lib.rs | 2 +- ipa-metrics/src/partitioned.rs | 176 ++++++++++++++++++++------------- ipa-metrics/src/producer.rs | 13 ++- ipa-metrics/src/store.rs | 61 +++++++----- 7 files changed, 196 insertions(+), 100 deletions(-) diff --git a/ipa-metrics/src/collector.rs b/ipa-metrics/src/collector.rs index bf0a6bc06..8cd4f105a 100644 --- a/ipa-metrics/src/collector.rs +++ b/ipa-metrics/src/collector.rs @@ -134,7 +134,7 @@ mod tests { let (collector, producer, controller) = installer(); let handle = thread::spawn(|| { let store = collector.install().block_until_shutdown(); - store.counter_value(counter!("foo")) + store.counter_val(counter!("foo")) }); thread::scope(move |s| { diff --git a/ipa-metrics/src/context.rs b/ipa-metrics/src/context.rs index 2680c0876..c91fc7906 100644 --- a/ipa-metrics/src/context.rs +++ b/ipa-metrics/src/context.rs @@ -91,6 +91,10 @@ impl MetricsContext { } fn flush(&mut self) { + if self.store.is_empty() { + return; + } + if self.is_connected() { let store = mem::take(&mut 
self.store); match self.tx.as_ref().unwrap().send(store) { @@ -118,7 +122,9 @@ impl Drop for MetricsContext { #[cfg(test)] mod tests { - use crate::MetricsContext; + use std::thread; + + use crate::{CurrentThreadPartitionContext, MetricsContext}; /// Each thread has its local store by default, and it is exclusive to it #[test] @@ -126,7 +132,7 @@ mod tests { fn local_store() { use crate::context::CurrentThreadContext; - crate::set_partition(0xdeadbeef); + CurrentThreadPartitionContext::set(0xdeadbeef); counter!("foo", 7); std::thread::spawn(|| { @@ -134,13 +140,13 @@ mod tests { counter!("foo", 5); assert_eq!( 5, - CurrentThreadContext::store(|store| store.counter_value(&counter!("foo"))) + CurrentThreadContext::store(|store| store.counter_val(counter!("foo"))) ); }); assert_eq!( 7, - CurrentThreadContext::store(|store| store.counter_value(&counter!("foo"))) + CurrentThreadContext::store(|store| store.counter_val(counter!("foo"))) ); } @@ -148,4 +154,19 @@ mod tests { fn default() { assert_eq!(0, MetricsContext::default().store().len()) } + + #[test] + fn ignore_empty_store_on_flush() { + let (tx, rx) = crossbeam_channel::unbounded(); + let mut ctx = MetricsContext::new(); + ctx.init(tx); + let handle = thread::spawn(move || { + if let Ok(_) = rx.recv() { + panic!("Context sent empty store"); + } + }); + ctx.flush(); + drop(ctx); + handle.join().unwrap(); + } } diff --git a/ipa-metrics/src/key.rs b/ipa-metrics/src/key.rs index dec06a108..50a34b187 100644 --- a/ipa-metrics/src/key.rs +++ b/ipa-metrics/src/key.rs @@ -123,6 +123,19 @@ impl OwnedName { pub fn labels(&self) -> impl Iterator { self.labels.iter().filter_map(|l| l.as_ref()) } + + /// Checks that a subset of labels in `self` matches all values in `other`. + pub fn partial_match(&self, other: &Name<'_, LABELS>) -> bool { + if self.key != other.key { + false + } else { + other.labels.iter().all(|l| self.find_label(l)) + } + } + + fn find_label(&self, label: &Label<'_>) -> bool { + self.labels().any(|l| l.as_borrowed().eq(label)) + } } impl Hash for Name<'_, LABELS> { diff --git a/ipa-metrics/src/lib.rs b/ipa-metrics/src/lib.rs index 87786e91d..843f327a2 100644 --- a/ipa-metrics/src/lib.rs +++ b/ipa-metrics/src/lib.rs @@ -18,7 +18,7 @@ pub use key::{MetricName, OwnedName, UniqueElements}; pub use label::{Label, LabelValue}; #[cfg(feature = "partitions")] pub use partitioned::{ - current_partition, set_or_unset_partition, set_partition, Partition as MetricPartition, + CurrentThreadContext as CurrentThreadPartitionContext, Partition as MetricPartition, PartitionedStore as MetricsStore, }; pub use producer::Producer as MetricsProducer; diff --git a/ipa-metrics/src/partitioned.rs b/ipa-metrics/src/partitioned.rs index 6e97b2446..9e4653992 100644 --- a/ipa-metrics/src/partitioned.rs +++ b/ipa-metrics/src/partitioned.rs @@ -1,3 +1,19 @@ +//! This module enables metric partitioning that can be useful +//! when threads that emit metrics are shared across multiple executions. +//! A typical example for it are unit tests in Rust that share threads. +//! Having a global per-thread store would mean that it is not possible +//! to distinguish between different runs. +//! +//! Partitioning attempts to solve this with a global 16 byte identifier that +//! is set in thread local storage and read automatically by [`PartitionedStore`] +//! +//! Note that this module does not provide means to automatically set and unset +//! partitions. `ipa-metrics-tracing` defines a way to do it via tracing context +//! 
that is good enough for the vast majority of use cases. +//! +//! Because partitioned stores carry additional cost of extra lookup (partition -> store), +//! it is disabled by default and requires explicit opt-in via `partitioning` feature. + use std::{borrow::Borrow, cell::Cell}; use hashbrown::hash_map::Entry; @@ -9,23 +25,28 @@ use crate::{ MetricName, }; -/// Each partition is a unique 16 byte value. -pub type Partition = u128; - -pub fn set_partition(new: Partition) { - PARTITION.set(Some(new)); +thread_local! { + static PARTITION: Cell> = Cell::new(None); } -pub fn set_or_unset_partition(new: Option) { - PARTITION.set(new); -} +/// Each partition is a unique 8 byte value, meaning roughly 1B partitions +/// can be supported and the limiting factor is birthday bound. +pub type Partition = u64; -pub fn current_partition() -> Option { - PARTITION.get() -} +pub struct CurrentThreadContext; -thread_local! { - static PARTITION: Cell> = Cell::new(None); +impl CurrentThreadContext { + pub fn set(new: Partition) { + Self::toggle(Some(new)) + } + + pub fn toggle(new: Option) { + PARTITION.set(new); + } + + pub fn get() -> Option { + PARTITION.get() + } } /// Provides the same functionality as [`Store`], but partitioned @@ -38,7 +59,10 @@ thread_local! { /// can be used interchangeably. #[derive(Clone, Debug)] pub struct PartitionedStore { + /// Set of stores partitioned by [`Partition`] inner: hashbrown::HashMap, + /// We don't want to lose metrics that are emitted when partitions are not set. + /// So we provide a default store for those default_store: Store, } @@ -56,11 +80,6 @@ impl PartitionedStore { } } - pub fn with_current_partition T, T>(&mut self, f: F) -> T { - let mut store = self.get_mut(current_partition()); - f(&mut store) - } - pub fn with_partition T, T>( &self, partition: Partition, @@ -70,19 +89,6 @@ impl PartitionedStore { store.map(f) } - pub fn with_partition_mut T, T>( - &mut self, - partition: Partition, - f: F, - ) -> T { - let mut store = self.get_mut(Some(partition)); - f(&mut store) - } - - pub fn is_empty(&self) -> bool { - self.inner.is_empty() && self.default_store.is_empty() - } - pub fn merge(&mut self, other: Self) { for (partition, store) in other.inner { self.get_mut(Some(partition)).merge(store); @@ -90,23 +96,26 @@ impl PartitionedStore { self.default_store.merge(other.default_store); } - pub fn counter_value>(&self, name: B) -> CounterValue { - let name = name.borrow(); - if let Some(partition) = current_partition() { + pub fn counter_val<'a, const LABELS: usize, B: Borrow>>( + &'a self, + key: B, + ) -> CounterValue { + let name = key.borrow(); + if let Some(partition) = CurrentThreadContext::get() { self.inner .get(&partition) - .map(|store| store.counter_value(name)) + .map(|store| store.counter_val(name)) .unwrap_or_default() } else { - self.default_store.counter_value(name) + self.default_store.counter_val(name) } } - pub fn counter( - &mut self, - key: &MetricName<'_, LABELS>, - ) -> CounterHandle<'_, LABELS> { - if let Some(partition) = current_partition() { + pub fn counter<'a, const LABELS: usize, B: Borrow>>( + &'a mut self, + key: B, + ) -> CounterHandle<'a, LABELS> { + if let Some(partition) = CurrentThreadContext::get() { self.inner .entry(partition) .or_insert_with(|| Store::default()) @@ -120,7 +129,20 @@ impl PartitionedStore { self.inner.len() + self.default_store.len() } - fn get_mut(&mut self, partition: Option) -> &mut Store { + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + fn with_partition_mut T, T>( + &mut self, + 
partition: Partition, + f: F, + ) -> T { + let mut store = self.get_mut(Some(partition)); + f(&mut store) + } + + fn get_mut(&mut self, partition: Option) -> &mut Store { if let Some(v) = partition { match self.inner.entry(v) { Entry::Occupied(entry) => entry.into_mut(), @@ -135,8 +157,8 @@ impl PartitionedStore { #[cfg(test)] mod tests { use crate::{ - metric_name, - partitioned::{set_partition, PartitionedStore}, + counter, metric_name, + partitioned::{CurrentThreadContext, PartitionedStore}, }; #[test] @@ -168,27 +190,23 @@ mod tests { fn current_partition() { let metric = metric_name!("foo"); let mut store = PartitionedStore::new(); - set_partition(4); + store.counter(&metric).inc(7); - store.with_current_partition(|store| { - store.counter(&metric).inc(1); - }); - store.with_current_partition(|store| { - store.counter(&metric).inc(5); - }); + CurrentThreadContext::set(4); - assert_eq!( - 6, - store.with_current_partition(|store| store.counter(&metric).get()) - ); + store.counter(&metric).inc(1); + store.counter(&metric).inc(5); + + assert_eq!(6, store.counter_val(&metric)); + CurrentThreadContext::toggle(None); + assert_eq!(7, store.counter_val(&metric)); } #[test] fn empty() { - let mut store = PartitionedStore::new(); - store.with_current_partition(|store| { - store.counter(&metric_name!("foo")).inc(1); - }); + let mut store = PartitionedStore::default(); + assert!(store.is_empty()); + store.counter(&metric_name!("foo")).inc(1); assert!(!store.is_empty()); } @@ -196,16 +214,42 @@ mod tests { #[test] fn len() { let mut store = PartitionedStore::new(); - store.with_current_partition(|store| { - store.counter(&metric_name!("foo")).inc(1); - }); - set_partition(4); - store.with_current_partition(|store| { - store.counter(&metric_name!("foo")).inc(1); - }); + assert_eq!(0, store.len()); + + store.counter(metric_name!("foo")).inc(1); + CurrentThreadContext::set(4); + store.counter(metric_name!("foo")).inc(1); // one metric in partition 4, another one in default. Even that they are the same, // partitioned store cannot distinguish between them assert_eq!(2, store.len()); } + + #[test] + fn merge() { + let mut store1 = PartitionedStore::new(); + let mut store2 = PartitionedStore::new(); + store1.with_partition_mut(1, |store| store.counter(counter!("foo")).inc(1)); + store2.with_partition_mut(1, |store| store.counter(counter!("foo")).inc(1)); + store1.with_partition_mut(2, |store| store.counter(counter!("foo")).inc(2)); + store2.with_partition_mut(2, |store| store.counter(counter!("foo")).inc(2)); + + store1.counter(counter!("foo")).inc(3); + store2.counter(counter!("foo")).inc(3); + + store1.merge(store2); + assert_eq!( + 2, + store1 + .with_partition(1, |store| store.counter_val(counter!("foo"))) + .unwrap() + ); + assert_eq!( + 4, + store1 + .with_partition(2, |store| store.counter_val(counter!("foo"))) + .unwrap() + ); + assert_eq!(6, store1.counter_val(counter!("foo"))); + } } diff --git a/ipa-metrics/src/producer.rs b/ipa-metrics/src/producer.rs index 27925bed2..ddd445922 100644 --- a/ipa-metrics/src/producer.rs +++ b/ipa-metrics/src/producer.rs @@ -2,6 +2,17 @@ use crossbeam_channel::Sender; use crate::{context::CurrentThreadContext, MetricsStore}; +/// A handle to enable centralized metrics collection from the current thread. +/// +/// This is a cloneable handle, so it can be installed in multiple threads. +/// The handle is installed by calling [`install`], which returns a drop handle. 
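The intended per-thread lifecycle, then, is install, work, flush-on-drop. A sketch of wrapping a worker thread this way (the helper name is illustrative; it mirrors the `MeteredScope` helper in the collector tests):

```rust
use ipa_metrics::{counter, MetricsProducer};

fn spawn_metered(producer: MetricsProducer) -> std::thread::JoinHandle<()> {
    std::thread::spawn(move || {
        producer.install();
        counter!("work.items", 1);
        // The local store is flushed to the collector when this handle drops.
        let _flush = producer.drop_handle();
    })
}
```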
+/// When the drop handle is dropped, the context of local store is flushed +/// to the collector thread. +/// +/// Thread local store is always enabled by [`MetricsContext`], so it is always +/// possible to have a local view of metrics emitted by this thread. +/// +/// [`install`]: Producer::install #[derive(Clone)] pub struct Producer { pub(super) tx: Sender, @@ -13,7 +24,7 @@ impl Producer { } /// Returns a drop handle that should be used when thread is stopped. - /// In an ideal world, a destructor on [`MetricsContext`] could do this, + /// One may think destructor on [`MetricsContext`] could do this, /// but as pointed in [`LocalKey`] documentation, deadlocks are possible /// if another TLS storage is accessed at destruction time. /// diff --git a/ipa-metrics/src/store.rs b/ipa-metrics/src/store.rs index 6a18b7b6b..c662eca76 100644 --- a/ipa-metrics/src/store.rs +++ b/ipa-metrics/src/store.rs @@ -3,18 +3,14 @@ use std::{borrow::Borrow, hash::BuildHasher}; use hashbrown::hash_map::RawEntryMut; use rustc_hash::FxBuildHasher; -use crate::{ - key::{OwnedMetricName, OwnedName}, - kind::CounterValue, - MetricName, -}; +use crate::{key::OwnedMetricName, kind::CounterValue, MetricName}; /// A basic store. Currently only supports counters. +/// Counters and other metrics are stored to optimize writes. That means, one lookup +/// per write. The cost of assembling the total count across all dimensions is absorbed +/// by readers #[derive(Clone, Debug)] pub struct Store { - // Counters and other metrics are stored to optimize writes. That means, one lookup - // per write. The cost of assembling the total count across all dimensions is absorbed - // by readers counters: hashbrown::HashMap, } @@ -50,10 +46,11 @@ impl Store { } impl Store { - pub fn counter( - &mut self, - key: &MetricName<'_, LABELS>, - ) -> CounterHandle<'_, LABELS> { + pub fn counter<'a, const LABELS: usize, B: Borrow>>( + &'a mut self, + key: B, + ) -> CounterHandle<'a, LABELS> { + let key = key.borrow(); let hash_builder = self.counters.hasher(); let hash = hash_builder.hash_one(key); let entry = self @@ -71,17 +68,22 @@ impl Store { } } - /// Returns the value for the specified metric across all dimensions. + /// Returns the value for the specified metric taking into account + /// its dimensionality. That is (foo, dim1 = 1, dim2 = 2) will be + /// different from (foo, dim1 = 1). /// The cost of this operation is `O(N*M)` where `N` - number of unique metrics - /// and `M` - number of all dimensions across all metrics. + /// registered in this store and `M` number of dimensions. /// /// Note that the cost can be improved if it ever becomes a bottleneck by /// creating a specialized two-level map (metric -> label -> value). 
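In practice this means one store answers both rolled-up and per-dimension queries. A sketch of the semantics after this change (assuming the crate's `u32` `LabelValue` impl):

```rust
use ipa_metrics::{counter, MetricsStore};

fn partial_queries() {
    let mut store = MetricsStore::default();
    store.counter(counter!("bytes.sent", "step" => &1_u32)).inc(100);
    store.counter(counter!("bytes.sent", "step" => &2_u32)).inc(200);

    // Fully qualified: matches exactly one dimension.
    assert_eq!(100, store.counter_val(counter!("bytes.sent", "step" => &1_u32)));
    // Partial: sums across every value of the omitted label.
    assert_eq!(300, store.counter_val(counter!("bytes.sent")));
}
```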
- pub fn counter_value<'a, B: Borrow>>(&'a self, key: B) -> CounterValue { + pub fn counter_val<'a, const LABELS: usize, B: Borrow>>( + &'a self, + key: B, + ) -> CounterValue { let key = key.borrow(); let mut answer = 0; for (metric, value) in &self.counters { - if metric.key == key.key { + if metric.partial_match(key) { answer += value } } @@ -89,10 +91,6 @@ impl Store { answer } - pub fn counters(&self) -> impl Iterator { - self.counters.iter().map(|(key, value)| (key, *value)) - } - pub fn len(&self) -> usize { self.counters.len() } @@ -116,7 +114,7 @@ impl CounterHandle<'_, LABELS> { mod tests { use std::hash::{DefaultHasher, Hash, Hasher}; - use crate::{metric_name, store::Store, LabelValue}; + use crate::{counter, metric_name, store::Store, LabelValue}; impl LabelValue for &'static str { fn hash(&self) -> u64 { @@ -184,18 +182,27 @@ mod tests { fn counter_value() { let mut store = Store::default(); store - .counter(&metric_name!("foo", "h1" => &1, "h2" => &"1")) + .counter(counter!("foo", "h1" => &1, "h2" => &"1")) .inc(1); store - .counter(&metric_name!("foo", "h1" => &1, "h2" => &"2")) + .counter(counter!("foo", "h1" => &1, "h2" => &"2")) .inc(1); store - .counter(&metric_name!("foo", "h1" => &2, "h2" => &"1")) + .counter(counter!("foo", "h1" => &2, "h2" => &"1")) .inc(1); store - .counter(&metric_name!("foo", "h1" => &2, "h2" => &"2")) + .counter(counter!("foo", "h1" => &2, "h2" => &"2")) .inc(1); - - assert_eq!(4, store.counter_value(&metric_name!("foo"))); + store + .counter(counter!("bar", "h1" => &1, "h2" => &"1")) + .inc(3); + + assert_eq!(4, store.counter_val(counter!("foo"))); + assert_eq!( + 1, + store.counter_val(&counter!("foo", "h1" => &1, "h2" => &"2")) + ); + assert_eq!(2, store.counter_val(&counter!("foo", "h1" => &1))); + assert_eq!(2, store.counter_val(&counter!("foo", "h2" => &"2"))); } } From 98a0feb7c88d79e0b746b3d82e5fc8c20460f792 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Thu, 17 Oct 2024 10:44:26 -0700 Subject: [PATCH 144/191] Clarify `Batch::new(None, ...)` --- .../src/protocol/context/dzkp_validator.rs | 32 ++++++++++++------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/ipa-core/src/protocol/context/dzkp_validator.rs b/ipa-core/src/protocol/context/dzkp_validator.rs index 3e83d797f..a734ce629 100644 --- a/ipa-core/src/protocol/context/dzkp_validator.rs +++ b/ipa-core/src/protocol/context/dzkp_validator.rs @@ -1414,10 +1414,16 @@ mod tests { ) } + impl Batch { + fn with_implicit_first_record(max_multiplications_per_gate: usize) -> Self { + Batch::new(None, max_multiplications_per_gate) + } + } + #[test] fn batch_allocation_small() { const SIZE: usize = 1; - let mut batch = Batch::new(None, SIZE); + let mut batch = Batch::with_implicit_first_record(SIZE); let zero = Boolean::ZERO; let zero_vec: >::Array = zero.into_array(); let segment = segment_from_entry(>::as_segment_entry( @@ -1432,7 +1438,7 @@ mod tests { #[test] fn batch_allocation_big() { const SIZE: usize = 2 * TARGET_PROOF_SIZE; - let mut batch = Batch::new(None, SIZE); + let mut batch = Batch::with_implicit_first_record(SIZE); let zero = Boolean::ZERO; let zero_vec: >::Array = zero.into_array(); let segment = segment_from_entry(>::as_segment_entry( @@ -1453,7 +1459,7 @@ mod tests { #[test] fn batch_fill() { const SIZE: usize = 10; - let mut batch = Batch::new(None, SIZE); + let mut batch = Batch::with_implicit_first_record(SIZE); let zero = Boolean::ZERO; let zero_vec: >::Array = zero.into_array(); let segment = segment_from_entry(>::as_segment_entry( @@ -1469,7 +1475,7 
@@ mod tests { #[test] fn batch_fill_out_of_order() { - let mut batch = Batch::new(None, 3); + let mut batch = Batch::with_implicit_first_record(3); let ba0 = BA256::from((0, 0)); let ba1 = BA256::from((0, 1)); let ba2 = BA256::from((0, 2)); @@ -1502,7 +1508,8 @@ mod tests { #[test] fn batch_fill_at_offset() { - let mut batch = Batch::new(None, 3); + const SIZE: usize = 3; + let mut batch = Batch::with_implicit_first_record(SIZE); let ba0 = BA256::from((0, 0)); let ba1 = BA256::from((0, 1)); let ba2 = BA256::from((0, 2)); @@ -1535,7 +1542,8 @@ mod tests { #[test] fn batch_explicit_first_record() { - let mut batch = Batch::new(Some(RecordId::from(4)), 3); + const SIZE: usize = 3; + let mut batch = Batch::new(Some(RecordId::from(4)), SIZE); let ba6 = BA256::from((0, 6)); let segment = segment_from_entry(>::as_segment_entry( &ba6, @@ -1551,7 +1559,7 @@ mod tests { #[test] fn batch_is_empty() { const SIZE: usize = 10; - let mut batch = Batch::new(None, SIZE); + let mut batch = Batch::with_implicit_first_record(SIZE); assert!(batch.is_empty()); let zero = Boolean::ZERO; let zero_vec: >::Array = zero.into_array(); @@ -1568,7 +1576,7 @@ mod tests { )] fn batch_underflow() { const SIZE: usize = 10; - let mut batch = Batch::new(None, SIZE); + let mut batch = Batch::with_implicit_first_record(SIZE); let zero = Boolean::ZERO; let zero_vec: >::Array = zero.into_array(); let segment = segment_from_entry(>::as_segment_entry( @@ -1584,7 +1592,7 @@ mod tests { )] fn batch_overflow() { const SIZE: usize = 10; - let mut batch = Batch::new(None, SIZE); + let mut batch = Batch::with_implicit_first_record(SIZE); let zero = Boolean::ZERO; let zero_vec: >::Array = zero.into_array(); let segment = segment_from_entry(>::as_segment_entry( @@ -1720,13 +1728,13 @@ mod tests { // test for small and large segments, i.e. 
8bit and 512 bit for segment_size in [8usize, 512usize] { // generate batch for the prover - let mut batch_prover = Batch::new(None, 1024 / segment_size); + let mut batch_prover = Batch::with_implicit_first_record(1024 / segment_size); // generate batch for the verifier on the left of the prover - let mut batch_left = Batch::new(None, 1024 / segment_size); + let mut batch_left = Batch::with_implicit_first_record(1024 / segment_size); // generate batch for the verifier on the right of the prover - let mut batch_right = Batch::new(None, 1024 / segment_size); + let mut batch_right = Batch::with_implicit_first_record(1024 / segment_size); // fill the batches with random values populate_batch( From e289209c82812f06fb533afff84442066c45dc8c Mon Sep 17 00:00:00 2001 From: Thomas James Yurek Date: Thu, 17 Oct 2024 11:37:52 -0700 Subject: [PATCH 145/191] unique bytes -> &[u8], use unwrap_infallible() --- ipa-core/src/report/hybrid.rs | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/ipa-core/src/report/hybrid.rs b/ipa-core/src/report/hybrid.rs index e12c92ba4..c3b037aa1 100644 --- a/ipa-core/src/report/hybrid.rs +++ b/ipa-core/src/report/hybrid.rs @@ -12,7 +12,7 @@ use rand_core::{CryptoRng, RngCore}; use typenum::{Sum, Unsigned, U16}; use crate::{ - error::{BoxError, Error}, + error::{BoxError, Error, UnwrapInfallible}, ff::{boolean_array::BA64, Serializable}, hpke::{ open_in_place, seal_in_place, CryptError, EncapsulationSize, PrivateKeyRegistry, @@ -316,9 +316,10 @@ where Ok(HybridImpressionReport:: { match_key: Replicated::::deserialize(GenericArray::from_slice(plaintext_mk)) - .map_err(|e| { - InvalidHybridReportError::DeserializationError("matchkey", e.into()) - })?, + .unwrap_infallible(), + /* .map_err(|e| { + InvalidHybridReportError::DeserializationError("matchkey", e.into()) + })?,*/ breakdown_key: Replicated::::deserialize(GenericArray::from_slice(plaintext_btt)) .map_err(|e| { InvalidHybridReportError::DeserializationError("is_trigger", e.into()) @@ -398,14 +399,14 @@ impl TryFrom for EncryptedHybridReport { } pub trait UniqueBytes { - fn unique_bytes(&self) -> Vec; + fn unique_bytes(&self) -> &[u8]; } impl UniqueBytes for EncryptedHybridReport { /// We use the `TagSize` (the first 16 bytes of the ciphertext) for collision-detection /// See [analysis here for uniqueness](https://eprint.iacr.org/2019/624) - fn unique_bytes(&self) -> Vec { - self.mk_ciphertext()[0..TagSize::USIZE].to_vec() + fn unique_bytes(&self) -> &[u8] { + &self.mk_ciphertext()[0..TagSize::USIZE] } } @@ -419,8 +420,8 @@ where { /// We use the `TagSize` (the first 16 bytes of the ciphertext) for collision-detection /// See [analysis here for uniqueness](https://eprint.iacr.org/2019/624) - fn unique_bytes(&self) -> Vec { - self.mk_ciphertext()[0..TagSize::USIZE].to_vec() + fn unique_bytes(&self) -> &[u8] { + &self.mk_ciphertext()[0..TagSize::USIZE] } } @@ -449,7 +450,7 @@ impl UniqueBytesValidator { /// if the item inserted is not unique among all checked thus far pub fn check_duplicate(&mut self, item: &U) -> Result<(), Error> { self.check_counter += 1; - if self.insert(item.unique_bytes()) { + if self.insert(item.unique_bytes().to_vec()) { Ok(()) } else { Err(Error::DuplicateBytes(self.check_counter)) @@ -587,8 +588,8 @@ mod test { } impl UniqueBytes for UniqueByteHolder { - fn unique_bytes(&self) -> Vec { - self.bytes.clone() + fn unique_bytes(&self) -> &[u8] { + &self.bytes } } From 5a5a6f5ee79df6d68001c2874f88d9be2abb0700 Mon Sep 17 00:00:00 2001 From: Thomas James Yurek 
Date: Thu, 17 Oct 2024 12:04:23 -0700 Subject: [PATCH 146/191] commented code --- ipa-core/src/report/hybrid.rs | 3 --- 1 file changed, 3 deletions(-) diff --git a/ipa-core/src/report/hybrid.rs b/ipa-core/src/report/hybrid.rs index c3b037aa1..edb614714 100644 --- a/ipa-core/src/report/hybrid.rs +++ b/ipa-core/src/report/hybrid.rs @@ -317,9 +317,6 @@ where Ok(HybridImpressionReport:: { match_key: Replicated::::deserialize(GenericArray::from_slice(plaintext_mk)) .unwrap_infallible(), - /* .map_err(|e| { - InvalidHybridReportError::DeserializationError("matchkey", e.into()) - })?,*/ breakdown_key: Replicated::::deserialize(GenericArray::from_slice(plaintext_btt)) .map_err(|e| { InvalidHybridReportError::DeserializationError("is_trigger", e.into()) From 07be2603f3d350fdcd65df919ad885e67771c4c5 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Thu, 17 Oct 2024 15:07:59 -0700 Subject: [PATCH 147/191] Clippy --- ipa-metrics/src/collector.rs | 27 +++++++++++++++++++-------- ipa-metrics/src/context.rs | 22 +++++++++++----------- ipa-metrics/src/controller.rs | 34 +++++++++++++++++++++++++++++++++- ipa-metrics/src/key.rs | 20 ++++++++++++-------- ipa-metrics/src/kind.rs | 1 - ipa-metrics/src/label.rs | 5 +++-- ipa-metrics/src/lib.rs | 15 ++++++++++++--- ipa-metrics/src/producer.rs | 2 +- ipa-metrics/src/store.rs | 17 +++++++++-------- 9 files changed, 100 insertions(+), 43 deletions(-) diff --git a/ipa-metrics/src/collector.rs b/ipa-metrics/src/collector.rs index 8cd4f105a..872d1ccbe 100644 --- a/ipa-metrics/src/collector.rs +++ b/ipa-metrics/src/collector.rs @@ -10,9 +10,11 @@ thread_local! { static COLLECTOR: RefCell> = const { RefCell::new(None) } } +/// Convenience struct to block the current thread on metric collection pub struct Installed; impl Installed { + #[allow(clippy::unused_self)] pub fn block_until_shutdown(&self) -> MetricsStore { MetricsCollector::with_current_mut(|c| { c.event_loop(); @@ -29,6 +31,11 @@ pub struct MetricsCollector { } impl MetricsCollector { + /// This installs metrics collection mechanism to current thread. + /// + /// ## Panics + /// It panics if there is another collector system already installed. + #[allow(clippy::must_use_candidate)] pub fn install(self) -> Installed { COLLECTOR.with_borrow_mut(|c| { assert!(c.replace(self).is_none(), "Already initialized"); @@ -49,7 +56,7 @@ impl MetricsCollector { Ok(store) => { tracing::trace!("Collector received more data: {store:?}"); println!("Collector received more data: {store:?}"); - self.local_store.merge(store) + self.local_store.merge(store); } Err(e) => { tracing::debug!("No more threads collecting metrics. 
Disconnected: {e}"); @@ -76,7 +83,7 @@ impl MetricsCollector { } } - pub fn with_current_mut T, T>(f: F) -> T { + fn with_current_mut T, T>(f: F) -> T { COLLECTOR.with_borrow_mut(|c| { let collector = c.as_mut().expect("Collector is installed"); f(collector) @@ -97,7 +104,7 @@ mod tests { thread::{Scope, ScopedJoinHandle}, }; - use crate::{counter, installer, producer::Producer, thread_installer}; + use crate::{counter, install, install_new_thread, producer::Producer}; struct MeteredScope<'scope, 'env: 'scope>(&'scope Scope<'scope, 'env>, Producer); @@ -131,7 +138,7 @@ mod tests { #[test] fn start_stop() { - let (collector, producer, controller) = installer(); + let (collector, producer, controller) = install(); let handle = thread::spawn(|| { let store = collector.install().block_until_shutdown(); store.counter_val(counter!("foo")) @@ -144,19 +151,23 @@ mod tests { controller.stop().unwrap(); }); - assert_eq!(8, handle.join().unwrap()) + assert_eq!(8, handle.join().unwrap()); } #[test] fn with_thread() { - let (producer, controller, handle) = thread_installer().unwrap(); + let (producer, controller, handle) = install_new_thread().unwrap(); thread::scope(move |s| { let s = s.metered(producer); s.spawn(|| counter!("baz", 4)); s.spawn(|| counter!("bar", 1)); - s.spawn(|| controller.stop().unwrap()); + s.spawn(|| { + let snapshot = controller.snapshot().unwrap(); + println!("snapshot: {snapshot:?}"); + controller.stop().unwrap(); + }); }); - handle.join().unwrap() // Collector thread should be terminated by now + handle.join().unwrap(); // Collector thread should be terminated by now } } diff --git a/ipa-metrics/src/context.rs b/ipa-metrics/src/context.rs index c91fc7906..ed5061415 100644 --- a/ipa-metrics/src/context.rs +++ b/ipa-metrics/src/context.rs @@ -29,11 +29,11 @@ impl CurrentThreadContext { } pub fn flush() { - METRICS_CTX.with_borrow_mut(|ctx| ctx.flush()); + METRICS_CTX.with_borrow_mut(MetricsContext::flush); } pub fn is_connected() -> bool { - METRICS_CTX.with_borrow(|ctx| ctx.is_connected()) + METRICS_CTX.with_borrow(MetricsContext::is_connected) } pub fn store T, T>(f: F) -> T { @@ -63,6 +63,7 @@ impl Default for MetricsContext { } impl MetricsContext { + #[must_use] pub const fn new() -> Self { Self { store: MetricsStore::new(), @@ -78,6 +79,7 @@ impl MetricsContext { self.tx = Some(tx); } + #[must_use] pub fn store(&self) -> &MetricsStore { &self.store } @@ -98,7 +100,7 @@ impl MetricsContext { if self.is_connected() { let store = mem::take(&mut self.store); match self.tx.as_ref().unwrap().send(store) { - Ok(_) => {} + Ok(()) => {} Err(e) => { tracing::warn!("MetricsContext is not connected: {e}"); } @@ -124,13 +126,13 @@ impl Drop for MetricsContext { mod tests { use std::thread; - use crate::{CurrentThreadPartitionContext, MetricsContext}; + use crate::MetricsContext; /// Each thread has its local store by default, and it is exclusive to it #[test] #[cfg(feature = "partitions")] fn local_store() { - use crate::context::CurrentThreadContext; + use crate::{context::CurrentThreadContext, CurrentThreadPartitionContext}; CurrentThreadPartitionContext::set(0xdeadbeef); counter!("foo", 7); @@ -152,7 +154,7 @@ mod tests { #[test] fn default() { - assert_eq!(0, MetricsContext::default().store().len()) + assert_eq!(0, MetricsContext::default().store().len()); } #[test] @@ -160,11 +162,9 @@ mod tests { let (tx, rx) = crossbeam_channel::unbounded(); let mut ctx = MetricsContext::new(); ctx.init(tx); - let handle = thread::spawn(move || { - if let Ok(_) = rx.recv() { - panic!("Context 
sent empty store"); - } - }); + let handle = + thread::spawn(move || assert!(rx.recv().is_err(), "Context sent non-empty store")); + ctx.flush(); drop(ctx); handle.join().unwrap(); diff --git a/ipa-metrics/src/controller.rs b/ipa-metrics/src/controller.rs index a70802f38..2f6e31194 100644 --- a/ipa-metrics/src/controller.rs +++ b/ipa-metrics/src/controller.rs @@ -7,11 +7,27 @@ pub enum Command { Stop(Sender<()>), } +/// Handle to communicate with centralized metrics collection system. pub struct Controller { pub(super) tx: Sender, } impl Controller { + /// Request new metric snapshot from the collector thread. + /// Blocks current thread until the snapshot is received + /// + /// ## Errors + /// If collector thread is disconnected or an error occurs during snapshot request + /// + /// ## Example + /// ```rust + /// use ipa_metrics::{install_new_thread, MetricsStore}; + /// + /// let (_, controller, _handle) = install_new_thread().unwrap(); + /// let snapshot = controller.snapshot().unwrap(); + /// println!("Current metrics: {snapshot:?}"); + /// ``` + #[inline] pub fn snapshot(&self) -> Result { let (tx, rx) = crossbeam_channel::bounded(0); self.tx @@ -20,11 +36,27 @@ impl Controller { rx.recv().map_err(|e| format!("Disconnected channel: {e}")) } + /// Send request to terminate the collector thread. + /// Blocks current thread until the snapshot is received. + /// If this request is successful, any subsequent snapshot + /// or stop requests will return an error. + /// + /// ## Errors + /// If collector thread is disconnected or an error occurs while sending + /// or receiving data from the collector thread. + /// + /// ## Example + /// ```rust + /// use ipa_metrics::{install_new_thread, MetricsStore}; + /// + /// let (_, controller, _handle) = install_new_thread().unwrap(); + /// controller.stop().unwrap(); + /// ``` pub fn stop(self) -> Result<(), String> { let (tx, rx) = crossbeam_channel::bounded(0); self.tx .send(Command::Stop(tx)) - .map_err(|e| format!("An error occurred while requesting metrics snapshot: {e}"))?; + .map_err(|e| format!("An error occurred while requesting termination: {e}"))?; rx.recv().map_err(|e| format!("Disconnected channel: {e}")) } } diff --git a/ipa-metrics/src/key.rs b/ipa-metrics/src/key.rs index 50a34b187..d5ec9644a 100644 --- a/ipa-metrics/src/key.rs +++ b/ipa-metrics/src/key.rs @@ -76,6 +76,9 @@ pub struct Name<'lv, const LABELS: usize = 0> { } impl<'lv, const LABELS: usize> Name<'lv, LABELS> { + /// Constructs this instance from key and labels. + /// ## Panics + /// If number of labels exceeds `MAX_LABELS`. pub fn from_parts>(key: I, labels: [Label<'lv>; LABELS]) -> Self { assert!( LABELS <= MAX_LABELS, @@ -116,6 +119,7 @@ pub struct OwnedName { } impl OwnedName { + #[must_use] pub fn key(&self) -> &'static str { self.key } @@ -125,11 +129,12 @@ impl OwnedName { } /// Checks that a subset of labels in `self` matches all values in `other`. 
+    #[must_use]
     pub fn partial_match<const LABELS: usize>(&self, other: &Name<'_, LABELS>) -> bool {
-        if self.key != other.key {
-            false
-        } else {
+        if self.key == other.key {
             other.labels.iter().all(|l| self.find_label(l))
+        } else {
+            false
         }
     }
 
@@ -142,7 +147,7 @@ impl Hash for Name<'_, LABELS> {
     fn hash<H: Hasher>(&self, state: &mut H) {
         state.write(self.key.as_bytes());
         for label in &self.labels {
-            label.hash(state)
+            label.hash(state);
         }
     }
 }
@@ -157,14 +162,13 @@ impl From<&'static str> for Name<'_, 0> {
 }
 
 pub trait UniqueElements {
+    #[must_use]
     fn enforce_unique(self) -> Self;
 }
 
 impl UniqueElements for [Label<'_>; 2] {
     fn enforce_unique(self) -> Self {
-        if self[0].name == self[1].name {
-            panic!("label names must be unique")
-        }
+        assert_ne!(self[0].name, self[1].name, "label names must be unique");
 
         self
     }
@@ -202,7 +206,7 @@ impl Hash for OwnedName {
     fn hash<H: Hasher>(&self, state: &mut H) {
         state.write(self.key.as_bytes());
         for label in self.labels.iter().flatten() {
-            label.hash(state)
+            label.hash(state);
         }
     }
 }
diff --git a/ipa-metrics/src/kind.rs b/ipa-metrics/src/kind.rs
index b6abf7e4e..3a48d105b 100644
--- a/ipa-metrics/src/kind.rs
+++ b/ipa-metrics/src/kind.rs
@@ -1,6 +1,5 @@
 //! Different metric types supported by this crate.
 //! Currently, only counters are supported.
-//! TODO: add more
 
 /// Counters are simple 8 byte values.
 pub type CounterValue = u64;
diff --git a/ipa-metrics/src/label.rs b/ipa-metrics/src/label.rs
index d35680e08..f2ac183e9 100644
--- a/ipa-metrics/src/label.rs
+++ b/ipa-metrics/src/label.rs
@@ -38,6 +38,7 @@ pub struct Label<'lv> {
 }
 
 impl Label<'_> {
+    #[must_use]
     pub fn to_owned(&self) -> OwnedLabel {
         OwnedLabel {
             name: self.name,
@@ -104,7 +105,7 @@ impl OwnedLabel {
 
 impl Hash for OwnedLabel {
     fn hash<H: Hasher>(&self, state: &mut H) {
-        self.as_borrowed().hash(state)
+        self.as_borrowed().hash(state);
     }
 }
 
@@ -137,7 +138,7 @@ mod tests {
         assert_ne!(compute_hash(&foo_1), compute_hash(&foo_2));
         assert_ne!(foo_2.to_owned(), foo_1);
 
-        assert_eq!(compute_hash(&foo_1), compute_hash(foo_1.to_owned()))
+        assert_eq!(compute_hash(&foo_1), compute_hash(foo_1.to_owned()));
     }
 
     #[test]
diff --git a/ipa-metrics/src/lib.rs b/ipa-metrics/src/lib.rs
index 843f327a2..552fac5a1 100644
--- a/ipa-metrics/src/lib.rs
+++ b/ipa-metrics/src/lib.rs
@@ -1,3 +1,7 @@
+#![deny(clippy::pedantic)]
+#![allow(clippy::similar_names)]
+#![allow(clippy::module_name_repetitions)]
+
 mod collector;
 mod context;
 mod controller;
@@ -25,7 +29,8 @@ pub use producer::Producer as MetricsProducer;
 #[cfg(not(feature = "partitions"))]
 pub use store::Store as MetricsStore;
 
-pub fn installer() -> (
+#[must_use]
+pub fn install() -> (
     MetricsCollector,
     MetricsProducer,
     MetricsCollectorController,
@@ -43,9 +48,13 @@ pub fn installer() -> (
     )
 }
 
-pub fn thread_installer(
+/// Same as [`install`] but spawns a new thread to run the collector.
+/// +/// ## Errors +/// if thread cannot be started +pub fn install_new_thread( ) -> io::Result<(MetricsProducer, MetricsCollectorController, JoinHandle<()>)> { - let (collector, producer, controller) = installer(); + let (collector, producer, controller) = install(); let handle = std::thread::Builder::new() .name("metric-collector".to_string()) .spawn(|| { diff --git a/ipa-metrics/src/producer.rs b/ipa-metrics/src/producer.rs index ddd445922..f9ee42cc3 100644 --- a/ipa-metrics/src/producer.rs +++ b/ipa-metrics/src/producer.rs @@ -43,6 +43,6 @@ pub struct ProducerDropHandle; impl Drop for ProducerDropHandle { fn drop(&mut self) { - CurrentThreadContext::flush() + CurrentThreadContext::flush(); } } diff --git a/ipa-metrics/src/store.rs b/ipa-metrics/src/store.rs index c662eca76..34982a0b1 100644 --- a/ipa-metrics/src/store.rs +++ b/ipa-metrics/src/store.rs @@ -21,13 +21,14 @@ impl Default for Store { } impl Store { + #[must_use] pub const fn new() -> Self { Self { counters: hashbrown::HashMap::with_hasher(FxBuildHasher), } } - pub(crate) fn merge(&mut self, other: Self) { + pub fn merge(&mut self, other: Self) { for (k, v) in other.counters { let hash_builder = self.counters.hasher(); let hash = hash_builder.hash_one(&k); @@ -40,12 +41,6 @@ impl Store { } } - pub fn is_empty(&self) -> bool { - self.counters.is_empty() - } -} - -impl Store { pub fn counter<'a, const LABELS: usize, B: Borrow>>( &'a mut self, key: B, @@ -84,16 +79,22 @@ impl Store { let mut answer = 0; for (metric, value) in &self.counters { if metric.partial_match(key) { - answer += value + answer += value; } } answer } + #[must_use] pub fn len(&self) -> usize { self.counters.len() } + + #[must_use] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } } pub struct CounterHandle<'a, const LABELS: usize> { From f064fb90916d4362df859ce8307271edba04f478 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Thu, 17 Oct 2024 15:53:27 -0700 Subject: [PATCH 148/191] Final touches --- ipa-metrics/src/collector.rs | 1 - ipa-metrics/src/context.rs | 4 ---- ipa-metrics/src/key.rs | 7 +------ ipa-metrics/src/label.rs | 25 +++++++++++++++++-------- ipa-metrics/src/store.rs | 29 ++++++++++++++++++++++------- 5 files changed, 40 insertions(+), 26 deletions(-) diff --git a/ipa-metrics/src/collector.rs b/ipa-metrics/src/collector.rs index 872d1ccbe..50f2b9b8f 100644 --- a/ipa-metrics/src/collector.rs +++ b/ipa-metrics/src/collector.rs @@ -66,7 +66,6 @@ impl MetricsCollector { i if i == command_idx => match next_op.recv(&self.command_rx) { Ok(ControllerCommand::Snapshot(tx)) => { tracing::trace!("Snapshot request received"); - println!("snapshot request received"); tx.send(self.local_store.clone()).unwrap(); } Ok(ControllerCommand::Stop(tx)) => { diff --git a/ipa-metrics/src/context.rs b/ipa-metrics/src/context.rs index ed5061415..1c1ff3d7f 100644 --- a/ipa-metrics/src/context.rs +++ b/ipa-metrics/src/context.rs @@ -32,10 +32,6 @@ impl CurrentThreadContext { METRICS_CTX.with_borrow_mut(MetricsContext::flush); } - pub fn is_connected() -> bool { - METRICS_CTX.with_borrow(MetricsContext::is_connected) - } - pub fn store T, T>(f: F) -> T { METRICS_CTX.with_borrow(|ctx| f(ctx.store())) } diff --git a/ipa-metrics/src/key.rs b/ipa-metrics/src/key.rs index d5ec9644a..eadb5e392 100644 --- a/ipa-metrics/src/key.rs +++ b/ipa-metrics/src/key.rs @@ -114,16 +114,11 @@ impl<'lv, const LABELS: usize> Name<'lv, LABELS> { /// This is the key inside metric stores which are simple hashmaps. 
#[derive(Debug, Clone)] pub struct OwnedName { - pub(super) key: &'static str, + key: &'static str, labels: [Option; 5], } impl OwnedName { - #[must_use] - pub fn key(&self) -> &'static str { - self.key - } - pub fn labels(&self) -> impl Iterator { self.labels.iter().filter_map(|l| l.as_ref()) } diff --git a/ipa-metrics/src/label.rs b/ipa-metrics/src/label.rs index f2ac183e9..9ee414e26 100644 --- a/ipa-metrics/src/label.rs +++ b/ipa-metrics/src/label.rs @@ -72,6 +72,8 @@ impl PartialEq for Label<'_> { } } +/// Same as [`Label`] but owns the values. This instance is stored +/// inside metric hashmaps as they need to own the keys. pub struct OwnedLabel { pub name: &'static str, pub val: Box, @@ -93,14 +95,6 @@ impl OwnedLabel { val: self.val.as_ref(), } } - - pub fn name(&self) -> &'static str { - self.name - } - - pub fn str_value(&self) -> String { - self.val.to_string() - } } impl Hash for OwnedLabel { @@ -158,4 +152,19 @@ mod tests { compute_hash(&metric_name!("foo", "l1" => &1, "l2" => &1)), ); } + + #[test] + fn clone() { + let metric = metric_name!("foo", "l1" => &1).to_owned(); + assert_eq!(&metric.labels().next(), &metric.labels().next().clone()); + } + + #[test] + fn fields() { + let metric = metric_name!("foo", "l1" => &1).to_owned(); + let label = metric.labels().next().unwrap().to_owned(); + + assert_eq!(label.name, "l1"); + assert_eq!(label.val.to_string(), "1"); + } } diff --git a/ipa-metrics/src/store.rs b/ipa-metrics/src/store.rs index 34982a0b1..58160ed4c 100644 --- a/ipa-metrics/src/store.rs +++ b/ipa-metrics/src/store.rs @@ -76,14 +76,12 @@ impl Store { key: B, ) -> CounterValue { let key = key.borrow(); - let mut answer = 0; - for (metric, value) in &self.counters { - if metric.partial_match(key) { - answer += value; - } - } - answer + self.counters + .iter() + .filter(|(counter, _)| counter.partial_match(key)) + .map(|(_, val)| val) + .sum() } #[must_use] @@ -206,4 +204,21 @@ mod tests { assert_eq!(2, store.counter_val(&counter!("foo", "h1" => &1))); assert_eq!(2, store.counter_val(&counter!("foo", "h2" => &"2"))); } + + #[test] + fn len_empty() { + let mut store = Store::default(); + assert!(store.is_empty()); + assert_eq!(0, store.len()); + + store.counter(counter!("foo")).inc(1); + assert!(!store.is_empty()); + assert_eq!(1, store.len()); + + store.counter(counter!("foo")).inc(1); + assert_eq!(1, store.len()); + + store.counter(counter!("bar")).inc(1); + assert_eq!(2, store.len()); + } } From be3740cb06298f52d2fb5c6aaae1b681ec06c6e8 Mon Sep 17 00:00:00 2001 From: Christian Berkhoff Date: Thu, 17 Oct 2024 16:09:44 -0700 Subject: [PATCH 149/191] Addressing comments --- ipa-core/src/net/client/mod.rs | 77 ++++++++++++++++------------------ ipa-core/src/net/mod.rs | 60 +++++++++++++++++++++++--- ipa-core/src/net/server/mod.rs | 33 ++++++++------- ipa-core/src/net/transport.rs | 9 ++-- ipa-core/src/sharding.rs | 40 +----------------- 5 files changed, 112 insertions(+), 107 deletions(-) diff --git a/ipa-core/src/net/client/mod.rs b/ipa-core/src/net/client/mod.rs index 62cb04b75..8fed45a8b 100644 --- a/ipa-core/src/net/client/mod.rs +++ b/ipa-core/src/net/client/mod.rs @@ -25,6 +25,7 @@ use pin_project::pin_project; use rustls::RootCertStore; use tracing::error; +use super::{ConnectionFlavor, Helper}; use crate::{ config::{ ClientConfig, HyperClientConfigurator, NetworkConfig, OwnedCertificate, OwnedPrivateKey, @@ -35,13 +36,12 @@ use crate::{ query::{PrepareQuery, QueryConfig, QueryInput}, TransportIdentity, }, - net::{http_serde, Error, CRYPTO_PROVIDER, 
HTTP_CLIENT_ID_HEADER}, + net::{http_serde, Error, CRYPTO_PROVIDER}, protocol::{Gate, QueryId}, - sharding::{Ring, TransportRestriction}, }; -#[derive(Default, Debug)] -pub enum ClientIdentity { +#[derive(Default)] +pub enum ClientIdentity { /// Claim the specified helper identity without any additional authentication. /// /// This is only supported for HTTP clients. @@ -57,7 +57,7 @@ pub enum ClientIdentity { None, } -impl ClientIdentity { +impl ClientIdentity { /// Authenticates clients with an X.509 certificate using the provided certificate and private /// key. Certificate must be in PEM format, private key encoding must be [`PKCS8`]. /// @@ -149,13 +149,33 @@ impl Future for ResponseFuture { } } +/// Helper to read a possible error response to a request that returns nothing on success +/// +/// # Errors +/// If there was an error reading the response body or if the request itself failed. +pub async fn resp_ok(resp: ResponseFromEndpoint) -> Result<(), Error> { + if resp.status().is_success() { + Ok(()) + } else { + Err(Error::from_failed_resp(resp).await) + } +} + +/// Reads the entire response from the server into Bytes +/// +/// # Errors +/// If there was an error collecting the response stream. +async fn response_to_bytes(resp: ResponseFromEndpoint) -> Result { + Ok(resp.into_body().collect().await?.to_bytes()) +} + /// TODO: we need a client that can be used by any system that is not aware of the internals /// of the helper network. That means that create query and send inputs API need to be /// separated from prepare/step data etc. /// TODO: It probably isn't necessary to always use `[MpcHelperClient; 3]`. Instead, a single /// client can be configured to talk to all three helpers. #[derive(Debug, Clone)] -pub struct MpcHelperClient { +pub struct MpcHelperClient { client: Client, Body>, scheme: uri::Scheme, authority: uri::Authority, @@ -163,7 +183,7 @@ pub struct MpcHelperClient { _restriction: PhantomData, } -impl MpcHelperClient { +impl MpcHelperClient { /// Create a new client with the given configuration /// /// `identity`, if present, configures whether and how the client will authenticate to the server @@ -177,7 +197,6 @@ impl MpcHelperClient { client_config: &ClientConfig, peer_config: PeerConfig, identity: ClientIdentity, - header_name: &'static HeaderName, ) -> Self { let (connector, auth_header) = if peer_config.url.scheme() == Some(&Scheme::HTTP) { // This connector works for both http and https. A regular HttpConnector would suffice, @@ -188,7 +207,7 @@ impl MpcHelperClient { None } ClientIdentity::Header(id) => Some(( - header_name.clone(), + R::identity_header(), HeaderValue::from_str(id.as_str().as_ref()).unwrap(), )), ClientIdentity::None => None, @@ -292,26 +311,6 @@ impl MpcHelperClient { } } - /// Helper to read a possible error response to a request that returns nothing on success - /// - /// # Errors - /// If there was an error reading the response body or if the request itself failed. - pub async fn resp_ok(resp: ResponseFromEndpoint) -> Result<(), Error> { - if resp.status().is_success() { - Ok(()) - } else { - Err(Error::from_failed_resp(resp).await) - } - } - - /// Reads the entire response from the server into Bytes - /// - /// # Errors - /// If there was an error collecting the response stream. 
- async fn response_to_bytes(resp: ResponseFromEndpoint) -> Result { - Ok(resp.into_body().collect().await?.to_bytes()) - } - /// Responds with whatever input is passed to it /// # Errors /// If the request has illegal arguments, or fails to deliver to helper @@ -323,7 +322,7 @@ impl MpcHelperClient { let resp = self.request(req).await?; let status = resp.status(); if status.is_success() { - let bytes = Self::response_to_bytes(resp).await?; + let bytes = response_to_bytes(resp).await?; let http_serde::echo::Request { mut query_params, .. } = serde_json::from_slice(&bytes)?; @@ -368,11 +367,11 @@ impl MpcHelperClient { let req = http_serde::query::prepare::Request::new(data); let req = req.try_into_http_request(self.scheme.clone(), self.authority.clone())?; let resp = self.request(req).await?; - Self::resp_ok(resp).await + resp_ok(resp).await } } -impl MpcHelperClient { +impl MpcHelperClient { /// Create a set of clients for the MPC helpers in the supplied helper network configuration. /// /// This function returns a set of three clients, which may be used to talk to each of the @@ -386,7 +385,7 @@ impl MpcHelperClient { pub fn from_conf( runtime: &IpaRuntime, conf: &NetworkConfig, - identity: &ClientIdentity, + identity: &ClientIdentity, ) -> [Self; 3] { conf.peers().each_ref().map(|peer_conf| { Self::new( @@ -394,7 +393,6 @@ impl MpcHelperClient { &conf.client, peer_conf.clone(), identity.clone_with_key(), - &HTTP_CLIENT_ID_HEADER, ) }) } @@ -408,7 +406,7 @@ impl MpcHelperClient { let req = req.try_into_http_request(self.scheme.clone(), self.authority.clone())?; let resp = self.request(req).await?; if resp.status().is_success() { - let bytes = Self::response_to_bytes(resp).await?; + let bytes = response_to_bytes(resp).await?; let http_serde::query::create::ResponseBody { query_id } = serde_json::from_slice(&bytes)?; Ok(query_id) @@ -426,7 +424,7 @@ impl MpcHelperClient { let req = http_serde::query::input::Request::new(data); let req = req.try_into_http_request(self.scheme.clone(), self.authority.clone())?; let resp = self.request(req).await?; - Self::resp_ok(resp).await + resp_ok(resp).await } /// Retrieve the status of a query. 
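(A brief orientation note on the two free functions introduced above: `resp_ok` and `response_to_bytes` replace the former `MpcHelperClient` associated helpers, so response handling no longer has to name a concrete client flavor. The following is an illustrative sketch of the resulting calling pattern, not something this patch adds; it assumes an async context, a `client: MpcHelperClient<Helper>`, and placeholder requests `prepare_req` and `status_req`, with errors propagated via `?`.)

    // Option 1: fire-and-forget endpoint. resp_ok folds the status check
    // and the error-body decoding into a single call.
    resp_ok(client.request(prepare_req).await?).await?;

    // Option 2: endpoint that returns a body. Check the status, then
    // collect the stream into Bytes and deserialize.
    let resp = client.request(status_req).await?;
    if !resp.status().is_success() {
        return Err(Error::from_failed_resp(resp).await);
    }
    let bytes = response_to_bytes(resp).await?;
    let http_serde::query::status::ResponseBody { status } = serde_json::from_slice(&bytes)?;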
@@ -443,7 +441,7 @@ impl MpcHelperClient {
         let resp = self.request(req).await?;
 
         if resp.status().is_success() {
-            let bytes = Self::response_to_bytes(resp).await?;
+            let bytes = response_to_bytes(resp).await?;
             let http_serde::query::status::ResponseBody { status } =
                 serde_json::from_slice(&bytes)?;
             Ok(status)
@@ -523,8 +521,7 @@
             IpaRuntime::current(),
             &ClientConfig::default(),
             peer_config,
-            ClientIdentity::<Ring>::None,
-            &HTTP_CLIENT_ID_HEADER,
+            ClientIdentity::<Helper>::None,
         );
 
         // The server's self-signed test cert is not in the system truststore, and we didn't supply
@@ -692,7 +689,7 @@
             .await
             .unwrap();
 
-        MpcHelperClient::<Ring>::resp_ok(resp).await.unwrap();
+        resp_ok(resp).await.unwrap();
 
         let mut stream = Arc::clone(&transport)
             .receive(HelperIdentity::ONE, (QueryId, expected_step.clone()))
diff --git a/ipa-core/src/net/mod.rs b/ipa-core/src/net/mod.rs
index 332af7f49..104d33408 100644
--- a/ipa-core/src/net/mod.rs
+++ b/ipa-core/src/net/mod.rs
@@ -1,14 +1,19 @@
 use std::{
+    fmt::Debug,
     io::{self, BufRead},
     sync::Arc,
 };
 
-use axum::http::HeaderName;
+use hyper::header::HeaderName;
 use once_cell::sync::Lazy;
 use rustls::crypto::CryptoProvider;
 use rustls_pki_types::CertificateDer;
 
-use crate::config::{OwnedCertificate, OwnedPrivateKey};
+use crate::{
+    config::{OwnedCertificate, OwnedPrivateKey},
+    helpers::{HelperIdentity, TransportIdentity},
+    sharding::ShardIndex,
+};
 
 mod client;
 mod error;
@@ -23,10 +28,9 @@ pub use error::Error;
 pub use server::{MpcHelperServer, TracingSpanMaker};
 pub use transport::{HttpShardTransport, HttpTransport};
 
-pub const APPLICATION_JSON: &str = "application/json";
-pub const APPLICATION_OCTET_STREAM: &str = "application/octet-stream";
-pub static HTTP_CLIENT_ID_HEADER: HeaderName =
-    HeaderName::from_static("x-unverified-client-identity");
+const APPLICATION_JSON: &str = "application/json";
+const APPLICATION_OCTET_STREAM: &str = "application/octet-stream";
+static HTTP_HELPER_ID_HEADER: HeaderName = HeaderName::from_static("x-unverified-helper-identity");
+pub static HTTP_SHARD_INDEX_HEADER: HeaderName =
+    HeaderName::from_static("x-unverified-shard-index");
 
 /// This has the same meaning as const defined in h2 crate, but we don't import it directly.
 /// According to the [`spec`] it cannot exceed 2^31 - 1.
@@ -43,6 +47,50 @@ pub(crate) const MAX_HTTP2_CONCURRENT_STREAMS: u32 = 5000;
 static CRYPTO_PROVIDER: Lazy<Arc<CryptoProvider>> =
     Lazy::new(|| Arc::new(rustls::crypto::aws_lc_rs::default_provider()));
 
+/// This simple trait is used to make code aware of what transport dimension it is running on. Structs like
+/// [`MpcHelperClient`] use it to know whether they are talking to other servers as Shards
+/// inside a Helper or as a Helper talking to another Helper in a Ring. This trait can be used to
+/// limit the functions exposed by a struct impl depending on the context that it's being used.
+/// Continuing the previous example, the functions a [`MpcHelperClient`] provides are dependent
+/// on whether it's communicating with another Shard or another Helper.
+///
+/// This trait is a safety restriction so that structs or traits only expose an API that's
+/// meaningful for their specific context. When used as a generic bound, it also spreads through
+/// the types, making it harder to misuse them or to combine incompatible types, e.g. using a
+/// [`ShardIndex`] with a [`Shard`].
+pub trait ConnectionFlavor: Debug + Send + Sync + Clone + 'static {
+    /// The meaningful identity used in this transport dimension.
+    type Identity: TransportIdentity;
+
+    fn identity_header() -> HeaderName;
+}
+
+/// Shard-to-shard communication marker.
+/// This marker is used to restrict communication inside a single Helper, with other shards.
+#[derive(Debug, Copy, Clone)]
+pub struct Shard;
+
+/// Helper-to-helper communication marker.
+/// This marker is used to restrict communication between Helpers. This communication usually has
+/// more restrictions. 3 hosts with the same sharding index are connected in a Ring.
+#[derive(Debug, Copy, Clone)]
+pub struct Helper;
+
+impl ConnectionFlavor for Shard {
+    type Identity = ShardIndex;
+
+    fn identity_header() -> HeaderName {
+        HTTP_SHARD_INDEX_HEADER.clone()
+    }
+}
+impl ConnectionFlavor for Helper {
+    type Identity = HelperIdentity;
+
+    fn identity_header() -> HeaderName {
+        HTTP_HELPER_ID_HEADER.clone()
+    }
+}
+
 /// Reads certificates and a private key from the corresponding readers
 ///
 /// # Errors
diff --git a/ipa-core/src/net/server/mod.rs b/ipa-core/src/net/server/mod.rs
index 1f7d9b6ec..0aee2d832 100644
--- a/ipa-core/src/net/server/mod.rs
+++ b/ipa-core/src/net/server/mod.rs
@@ -31,7 +31,7 @@ use futures::{
     future::{ready, BoxFuture, Either, Ready},
     FutureExt,
 };
-use hyper::{body::Incoming, header::HeaderName, Request};
+use hyper::{body::Incoming, Request};
 use metrics::increment_counter;
 use rustls::{server::WebPkiClientVerifier, RootCertStore};
@@ -40,6 +40,7 @@ use tower::{layer::layer_fn, Service};
 use tower_http::trace::TraceLayer;
 use tracing::{error, Span};
 
+use super::HTTP_HELPER_ID_HEADER;
 use crate::{
     config::{NetworkConfig, OwnedCertificate, OwnedPrivateKey, ServerConfig, TlsConfig},
     error::BoxError,
@@ -325,6 +326,17 @@ impl Deref for ClientIdentity {
     }
 }
 
+impl TryFrom<HeaderValue> for ClientIdentity {
+    type Error = Error;
+
+    fn try_from(value: HeaderValue) -> Result<Self, Self::Error> {
+        let header_str = value.to_str()?;
+        HelperIdentity::from_str(header_str)
+            .map_err(|e| Error::InvalidHeader(Box::new(e)))
+            .map(ClientIdentity)
+    }
+}
+
 /// `Accept`or that sets an axum `Extension` indicating the authenticated remote helper identity.
 #[derive(Clone)]
 struct ClientCertRecognizingAcceptor {
@@ -427,10 +439,6 @@ impl>> Service> for SetClientIdentityFromCer
     }
 }
 
-/// Name of the header that passes the client identity when not using HTTPS.
-pub static HTTP_CLIENT_ID_HEADER: HeaderName =
-    HeaderName::from_static("x-unverified-client-identity");
-
 /// Service wrapper that gets a client helper identity from a header.
/// /// Since this allows a client to claim any identity, it is completely @@ -444,13 +452,6 @@ impl SetClientIdentityFromHeader { fn new(inner: S) -> Self { Self { inner } } - - fn parse_client_id(header_value: &HeaderValue) -> Result { - let header_str = header_value.to_str()?; - HelperIdentity::from_str(header_str) - .map_err(|e| Error::InvalidHeader(Box::new(e))) - .map(ClientIdentity) - } } impl, Response = Response>> Service> @@ -466,9 +467,9 @@ impl, Response = Response>> Service> } fn call(&mut self, mut req: Request) -> Self::Future { - if let Some(header_value) = req.headers().get(&HTTP_CLIENT_ID_HEADER) { - let id_result = Self::parse_client_id(header_value) - .map_err(|e| Error::InvalidHeader(format!("{HTTP_CLIENT_ID_HEADER}: {e}").into())); + if let Some(header_value) = req.headers().get(&HTTP_HELPER_ID_HEADER) { + let id_result = ClientIdentity::try_from(header_value.clone()) + .map_err(|e| Error::InvalidHeader(format!("{HTTP_HELPER_ID_HEADER}: {e}").into())); match id_result { Ok(id) => req.extensions_mut().insert(id), Err(err) => return ready(Ok(err.into_response())).right_future(), @@ -730,7 +731,7 @@ mod e2e_tests { let expected = expected_req(addr.to_string()); let req = http_req(&expected, uri::Scheme::HTTP, addr.to_string()); let response = client.request(req).await.unwrap(); - println!("{}", response.status()); + assert_eq!(response.status(), StatusCode::OK); assert_eq!( diff --git a/ipa-core/src/net/transport.rs b/ipa-core/src/net/transport.rs index 9a7732a33..9fbdc8103 100644 --- a/ipa-core/src/net/transport.rs +++ b/ipa-core/src/net/transport.rs @@ -9,6 +9,7 @@ use async_trait::async_trait; use futures::{Stream, TryFutureExt}; use pin_project::{pin_project, pinned_drop}; +use super::client::resp_ok; use crate::{ config::{NetworkConfig, ServerConfig}, executor::IpaRuntime, @@ -21,7 +22,7 @@ use crate::{ }, net::{client::MpcHelperClient, error::Error, MpcHelperServer}, protocol::{Gate, QueryId}, - sharding::{Ring, ShardIndex}, + sharding::ShardIndex, sync::Arc, }; @@ -205,11 +206,7 @@ impl Transport for Arc { // - avoid blocking this task, if the current runtime is overloaded // - use the runtime that enables IO (current runtime may not). self.http_runtime - .spawn( - resp_future - .map_err(Into::into) - .and_then(MpcHelperClient::::resp_ok), - ) + .spawn(resp_future.map_err(Into::into).and_then(resp_ok)) .await?; Ok(()) } diff --git a/ipa-core/src/sharding.rs b/ipa-core/src/sharding.rs index 671a1675a..195573f45 100644 --- a/ipa-core/src/sharding.rs +++ b/ipa-core/src/sharding.rs @@ -3,46 +3,8 @@ use std::{ num::TryFromIntError, }; -use serde::{Deserialize, Serialize}; - -use crate::helpers::{HelperIdentity, TransportIdentity}; - -/// This simple trait is used to make aware on what transport dimnsion one is running. Structs like -/// [`crate::net::client::MpcHelperClient`] use it to know whether they are talking to other -/// servers as Shards inside a Helper or as a Helper talking to another Helper in a Ring. This -/// trait can be used to limit the functions exposed by a struct impl depending on the context that -/// it's being used. Continuing the previous example, the functions a -/// [`crate::net::client::MpcHelperClient`] provides are dependent on whether it's communicating -/// with another Shard or another Helper. -/// -/// This trait is a safety restriction so that structs or traits only expose an API that's -/// meaningful for their specific context. 
When used as a generic bound, it also spreads through -/// the types making it harder to be misused or combining incompatible types, e.g. Using a -/// [`ShardIndex`] with a [`Ring`]. -pub trait TransportRestriction: Debug + Send + Sync + Clone + 'static { - /// The meaningful identity used in this transport dimension. - type Identity: TransportIdentity; -} - -/// This marker is used to restrict communication inside a single Helper, with other shards. -#[derive(Debug, Copy, Clone)] -pub struct Sharding; - -/// This marker is used to restrict communication inter Helpers. This communication usually has -/// more restrictions. 3 Hosts with the same sharding index are conencted in a Ring. -#[derive(Debug, Copy, Clone)] -pub struct Ring; - -impl TransportRestriction for Sharding { - type Identity = ShardIndex; -} -impl TransportRestriction for Ring { - type Identity = HelperIdentity; -} - /// A unique zero-based index of the helper shard. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)] -#[serde(from = "u32")] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct ShardIndex(pub u32); impl From for u32 { From 5183c9f0438d7d9e61a7f30cc47252e5ec540728 Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Thu, 17 Oct 2024 17:48:05 -0700 Subject: [PATCH 150/191] fix typo in shuffle/sharded.rs (#1359) --- ipa-core/src/protocol/ipa_prf/shuffle/sharded.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/sharded.rs b/ipa-core/src/protocol/ipa_prf/shuffle/sharded.rs index 7bc766917..48c02c103 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/sharded.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/sharded.rs @@ -495,7 +495,7 @@ mod tests { let inputs = [1_u32, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] .map(BA8::truncate_from) .to_vec(); - let mut result = sharded_shuffle::<3, D>(inputs.clone()).await; + let mut result = sharded_shuffle::(inputs.clone()).await; assert_ne!(inputs, result); result.sort_by_key(U128Conversions::as_u128); From c3377eb28064b723c70b9a3ba7a6a9f064261d5c Mon Sep 17 00:00:00 2001 From: Christian Berkhoff Date: Thu, 17 Oct 2024 17:52:19 -0700 Subject: [PATCH 151/191] Rename generic param --- ipa-core/src/net/client/mod.rs | 18 +++++++++--------- ipa-core/src/net/mod.rs | 3 ++- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/ipa-core/src/net/client/mod.rs b/ipa-core/src/net/client/mod.rs index 8fed45a8b..d334c6a48 100644 --- a/ipa-core/src/net/client/mod.rs +++ b/ipa-core/src/net/client/mod.rs @@ -41,11 +41,11 @@ use crate::{ }; #[derive(Default)] -pub enum ClientIdentity { +pub enum ClientIdentity { /// Claim the specified helper identity without any additional authentication. /// /// This is only supported for HTTP clients. - Header(R::Identity), + Header(F::Identity), /// Authenticate with an X.509 certificate or a certificate chain. /// @@ -57,7 +57,7 @@ pub enum ClientIdentity { None, } -impl ClientIdentity { +impl ClientIdentity { /// Authenticates clients with an X.509 certificate using the provided certificate and private /// key. Certificate must be in PEM format, private key encoding must be [`PKCS8`]. /// @@ -82,7 +82,7 @@ impl ClientIdentity { /// to own a private key, and we need to create 3 with the same config, we provide Clone /// capabilities via this method to `ClientIdentity`. 
     #[must_use]
-    pub fn clone_with_key(&self) -> ClientIdentity<R> {
+    pub fn clone_with_key(&self) -> ClientIdentity<F> {
         match self {
             Self::Certificate((c, pk)) => Self::Certificate((c.clone(), pk.clone_key())),
             Self::Header(h) => Self::Header(*h),
@@ -175,15 +175,15 @@ async fn response_to_bytes(resp: ResponseFromEndpoint) -> Result {
 /// TODO: It probably isn't necessary to always use `[MpcHelperClient; 3]`. Instead, a single
 /// client can be configured to talk to all three helpers.
 #[derive(Debug, Clone)]
-pub struct MpcHelperClient<R: ConnectionFlavor = Helper> {
+pub struct MpcHelperClient<F: ConnectionFlavor = Helper> {
     client: Client<HttpsConnector<HttpConnector>, Body>,
     scheme: uri::Scheme,
     authority: uri::Authority,
     auth_header: Option<(HeaderName, HeaderValue)>,
-    _restriction: PhantomData<R>,
+    _restriction: PhantomData<F>,
 }
 
-impl<R: ConnectionFlavor> MpcHelperClient<R> {
+impl<F: ConnectionFlavor> MpcHelperClient<F> {
     /// Create a new client with the given configuration
     ///
     /// `identity`, if present, configures whether and how the client will authenticate to the server
@@ -196,7 +196,7 @@ impl MpcHelperClient {
         runtime: IpaRuntime,
         client_config: &ClientConfig,
         peer_config: PeerConfig,
-        identity: ClientIdentity<R>,
+        identity: ClientIdentity<F>,
     ) -> Self {
         let (connector, auth_header) = if peer_config.url.scheme() == Some(&Scheme::HTTP) {
             // This connector works for both http and https. A regular HttpConnector would suffice,
@@ -207,7 +207,7 @@ impl MpcHelperClient {
                 None
             }
             ClientIdentity::Header(id) => Some((
-                R::identity_header(),
+                F::identity_header(),
                 HeaderValue::from_str(id.as_str().as_ref()).unwrap(),
             )),
             ClientIdentity::None => None,
diff --git a/ipa-core/src/net/mod.rs b/ipa-core/src/net/mod.rs
index 104d33408..58981f7cc 100644
--- a/ipa-core/src/net/mod.rs
+++ b/ipa-core/src/net/mod.rs
@@ -50,7 +50,7 @@ static CRYPTO_PROVIDER: Lazy> =
 /// This simple trait is used to make code aware of what transport dimension it is running on. Structs like
 /// [`MpcHelperClient`] use it to know whether they are talking to other servers as Shards
 /// inside a Helper or as a Helper talking to another Helper in a Ring. This trait can be used to
-/// limit the functions exposed by a struct impl depending on the context that it's being used.
+/// limit the functions exposed by a struct impl, depending on the context that it's being used.
 /// Continuing the previous example, the functions a [`MpcHelperClient`] provides are dependent
 /// on whether it's communicating with another Shard or another Helper.
@@ -62,6 +62,7 @@ pub trait ConnectionFlavor: Debug + Send + Sync + Clone + 'static {
     /// The meaningful identity used in this transport dimension.
     type Identity: TransportIdentity;
 
+    /// The header to be used to identify an HTTP request
     fn identity_header() -> HeaderName;
 }
 
From dd8417b6eee2979c370d3e14a95f6fcc8db30f9d Mon Sep 17 00:00:00 2001
From: Alex Koshelev
Date: Thu, 17 Oct 2024 17:56:49 -0700
Subject: [PATCH 152/191] Fix flaky test

---
 ipa-metrics/src/collector.rs  | 16 +++++++++++++---
 ipa-metrics/src/controller.rs | 36 +++++++++++++++++++++++++++++++++++
 ipa-metrics/src/lib.rs        |  5 ++++-
 3 files changed, 53 insertions(+), 4 deletions(-)

diff --git a/ipa-metrics/src/collector.rs b/ipa-metrics/src/collector.rs
index 50f2b9b8f..4d5995af3 100644
--- a/ipa-metrics/src/collector.rs
+++ b/ipa-metrics/src/collector.rs
@@ -2,7 +2,10 @@ use std::cell::RefCell;
 
 use crossbeam_channel::{Receiver, Select};
 
-use crate::{ControllerCommand, MetricsStore};
+use crate::{
+    controller::{Command, Status},
+    ControllerCommand, MetricsStore,
+};
 
 thread_local! {
     /// Collector that is installed in a thread.
It is responsible for receiving metrics from
@@ -48,6 +51,7 @@ impl MetricsCollector {
         let mut select = Select::new();
         let data_idx = select.recv(&self.rx);
         let command_idx = select.recv(&self.command_rx);
+        let mut state = Status::Active;
 
         loop {
             let next_op = select.select();
@@ -55,12 +59,12 @@ impl MetricsCollector {
                 i if i == data_idx => match next_op.recv(&self.rx) {
                     Ok(store) => {
                         tracing::trace!("Collector received more data: {store:?}");
-                        println!("Collector received more data: {store:?}");
                         self.local_store.merge(store);
                     }
                     Err(e) => {
                         tracing::debug!("No more threads collecting metrics. Disconnected: {e}");
                         select.remove(data_idx);
+                        state = Status::Disconnected;
                     }
                 },
                 i if i == command_idx => match next_op.recv(&self.command_rx) {
@@ -69,9 +73,13 @@ impl MetricsCollector {
                     tx.send(self.local_store.clone()).unwrap();
                 }
                 Ok(ControllerCommand::Stop(tx)) => {
+                    tracing::trace!("Stop signal received");
                     tx.send(()).unwrap();
                     break;
                 }
+                Ok(Command::Status(tx)) => {
+                    tx.send(state).unwrap();
+                }
                 Err(e) => {
                     tracing::debug!("Metric controller is disconnected: {e}");
                     break;
@@ -103,7 +111,7 @@ mod tests {
         thread::{Scope, ScopedJoinHandle},
     };
 
-    use crate::{counter, install, install_new_thread, producer::Producer};
+    use crate::{controller::Status, counter, install, install_new_thread, producer::Producer};
 
     struct MeteredScope<'scope, 'env: 'scope>(&'scope Scope<'scope, 'env>, Producer);
 
@@ -147,6 +155,8 @@ mod tests {
             let s = s.metered(producer);
             s.spawn(|| counter!("foo", 3)).join().unwrap();
             s.spawn(|| counter!("foo", 5)).join().unwrap();
+            drop(s); // this causes collector to eventually stop receiving signals
+            while controller.status().unwrap() == Status::Active {}
             controller.stop().unwrap();
         });
 
diff --git a/ipa-metrics/src/controller.rs b/ipa-metrics/src/controller.rs
index 2f6e31194..265dacf45 100644
--- a/ipa-metrics/src/controller.rs
+++ b/ipa-metrics/src/controller.rs
@@ -2,9 +2,22 @@ use crossbeam_channel::Sender;
 
 use crate::MetricsStore;
 
+/// Indicates the current status of the collector thread.
+#[derive(Debug, Copy, Clone, Eq, PartialEq)]
+pub enum Status {
+    /// There is at least one active thread that can send
+    /// store snapshots to the collector. The collector is actively
+    /// listening for new snapshots.
+    Active,
+    /// All threads have been disconnected from this collector,
+    /// and it is currently awaiting shutdown via [`Command::Stop`].
+    Disconnected,
+}
+
 pub enum Command {
     Snapshot(Sender<MetricsStore>),
     Stop(Sender<()>),
+    Status(Sender<Status>),
 }
 
 /// Handle to communicate with the centralized metrics collection system.
@@ -59,4 +72,27 @@ impl Controller {
         rx.recv().map_err(|e| format!("Disconnected channel: {e}"))
     }
+
+    /// Request current collector status.
+    ///
+    /// ## Errors
+    /// If collector thread is disconnected or an error occurs while sending
+    /// or receiving data from the collector thread.
+ /// + /// ## Example + /// ```rust + /// use ipa_metrics::{install_new_thread, ControllerStatus}; + /// + /// let (_, controller, _handle) = install_new_thread().unwrap(); + /// let status = controller.status().unwrap(); + /// println!("Collector status: {status:?}"); + /// ``` + #[inline] + pub fn status(&self) -> Result { + let (tx, rx) = crossbeam_channel::bounded(0); + self.tx + .send(Command::Status(tx)) + .map_err(|e| format!("An error occurred while requesting status: {e}"))?; + rx.recv().map_err(|e| format!("Disconnected channel: {e}")) + } } diff --git a/ipa-metrics/src/lib.rs b/ipa-metrics/src/lib.rs index 552fac5a1..2ee9d5be0 100644 --- a/ipa-metrics/src/lib.rs +++ b/ipa-metrics/src/lib.rs @@ -17,7 +17,10 @@ use std::{io, thread::JoinHandle}; pub use collector::MetricsCollector; pub use context::{CurrentThreadContext as MetricsCurrentThreadContext, MetricsContext}; -pub use controller::{Command as ControllerCommand, Controller as MetricsCollectorController}; +pub use controller::{ + Command as ControllerCommand, Controller as MetricsCollectorController, + Status as ControllerStatus, +}; pub use key::{MetricName, OwnedName, UniqueElements}; pub use label::{Label, LabelValue}; #[cfg(feature = "partitions")] From 497220ef8e32e78d4e4b2a2f109f7cd61b75b98b Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Fri, 18 Oct 2024 09:56:54 -0700 Subject: [PATCH 153/191] Add coverage for partitions --- scripts/coverage-ci | 3 +++ 1 file changed, 3 insertions(+) diff --git a/scripts/coverage-ci b/scripts/coverage-ci index 2d448daa7..c652e9f65 100755 --- a/scripts/coverage-ci +++ b/scripts/coverage-ci @@ -11,6 +11,9 @@ cargo build --all-targets # Need to be kept in sync manually with tests we run inside check.yml. cargo test --features "cli test-fixture relaxed-dp" +# Provide code coverage stats for ipa-metrics crate with partitions enabled +cargo test -p ipa-metrics --features "partitions" + # descriptive-gate does not require a feature flag. 
for gate in "compact-gate" ""; do cargo test --no-default-features --features "cli web-app real-world-infra test-fixture $gate" From cf5a66d1bb97cc97edb15cf8401f6aa4452b1d8e Mon Sep 17 00:00:00 2001 From: Christian Berkhoff Date: Thu, 17 Oct 2024 16:09:44 -0700 Subject: [PATCH 154/191] Server and Config --- ipa-core/src/bin/report_collector.rs | 4 +- ipa-core/src/cli/clientconf.rs | 2 +- ipa-core/src/cli/crypto/encrypt.rs | 2 +- ipa-core/src/cli/playbook/mod.rs | 12 +- ipa-core/src/config.rs | 173 +++++++++++----- ipa-core/src/net/client/mod.rs | 2 +- ipa-core/src/net/mod.rs | 3 +- ipa-core/src/net/server/handlers/mod.rs | 2 +- ipa-core/src/net/server/handlers/query/mod.rs | 3 +- .../src/net/server/handlers/query/prepare.rs | 6 +- .../src/net/server/handlers/query/step.rs | 6 +- ipa-core/src/net/server/mod.rs | 187 ++++++++++-------- ipa-core/src/net/test.rs | 17 +- ipa-core/src/net/transport.rs | 11 +- 14 files changed, 261 insertions(+), 169 deletions(-) diff --git a/ipa-core/src/bin/report_collector.rs b/ipa-core/src/bin/report_collector.rs index a642b7bfc..38750e578 100644 --- a/ipa-core/src/bin/report_collector.rs +++ b/ipa-core/src/bin/report_collector.rs @@ -22,7 +22,7 @@ use ipa_core::{ config::{KeyRegistries, NetworkConfig}, ff::{boolean_array::BA32, FieldType}, helpers::query::{DpMechanism, IpaQueryConfig, QueryConfig, QuerySize, QueryType}, - net::MpcHelperClient, + net::{Helper, MpcHelperClient}, report::{EncryptedOprfReportStreams, DEFAULT_KEY_ID}, test_fixture::{ ipa::{ipa_in_the_clear, CappingOrder, IpaSecurityModel, TestRawDataRecord}, @@ -380,7 +380,7 @@ async fn ipa( async fn ipa_test( args: &Args, - network: &NetworkConfig, + network: &NetworkConfig, security_model: IpaSecurityModel, ipa_query_config: IpaQueryConfig, helper_clients: &[MpcHelperClient; 3], diff --git a/ipa-core/src/cli/clientconf.rs b/ipa-core/src/cli/clientconf.rs index 42835bdd0..341a4253a 100644 --- a/ipa-core/src/cli/clientconf.rs +++ b/ipa-core/src/cli/clientconf.rs @@ -186,7 +186,7 @@ fn assert_network_config(config_toml: &Map, config_str: &str) { else { panic!("peers section in toml config is not a table"); }; - for (i, peer_config_actual) in nw_config.peers.iter().enumerate() { + for (i, peer_config_actual) in nw_config.peers().iter().enumerate() { assert_peer_config(&peer_config_expected[i], peer_config_actual); } } diff --git a/ipa-core/src/cli/crypto/encrypt.rs b/ipa-core/src/cli/crypto/encrypt.rs index 6f174f89a..c2ee6ea84 100644 --- a/ipa-core/src/cli/crypto/encrypt.rs +++ b/ipa-core/src/cli/crypto/encrypt.rs @@ -244,7 +244,7 @@ this is not toml! 
} #[test] - #[should_panic = "invalid length 2, expected an array of length 3"] + #[should_panic = "Expected a Vec of length 3 but it was 2"] fn encrypt_incomplete_network_file() { let input_file = sample_data::write_csv(sample_data::test_ipa_data().take(10)).unwrap(); diff --git a/ipa-core/src/cli/playbook/mod.rs b/ipa-core/src/cli/playbook/mod.rs index 7d0acb1c8..fbf7843ec 100644 --- a/ipa-core/src/cli/playbook/mod.rs +++ b/ipa-core/src/cli/playbook/mod.rs @@ -20,7 +20,7 @@ use crate::{ executor::IpaRuntime, ff::boolean_array::{BA20, BA3, BA8}, helpers::query::DpMechanism, - net::{ClientIdentity, MpcHelperClient}, + net::{ClientIdentity, Helper, MpcHelperClient}, protocol::{dp::NoiseParams, ipa_prf::oprf_padding::insecure::OPRFPaddingDp}, }; @@ -194,19 +194,19 @@ pub async fn make_clients( network_path: Option<&Path>, scheme: Scheme, wait: usize, -) -> ([MpcHelperClient; 3], NetworkConfig) { +) -> ([MpcHelperClient; 3], NetworkConfig) { let mut wait = wait; let network = if let Some(path) = network_path { NetworkConfig::from_toml_str(&fs::read_to_string(path).unwrap()).unwrap() } else { - NetworkConfig { - peers: [ + NetworkConfig::::new_ring( + vec![ PeerConfig::new("localhost:3000".parse().unwrap(), None), PeerConfig::new("localhost:3001".parse().unwrap(), None), PeerConfig::new("localhost:3002".parse().unwrap(), None), ], - client: ClientConfig::default(), - } + ClientConfig::default(), + ) }; let network = network.override_scheme(&scheme); diff --git a/ipa-core/src/config.rs b/ipa-core/src/config.rs index 0486ad490..090c65b40 100644 --- a/ipa-core/src/config.rs +++ b/ipa-core/src/config.rs @@ -1,13 +1,13 @@ use std::{ - array, borrow::{Borrow, Cow}, + collections::HashMap, fmt::{Debug, Formatter}, - iter::Zip, + iter::zip, path::PathBuf, - slice, time::Duration, }; +use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _}; use hyper::{http::uri::Scheme, Uri}; use hyper_util::client::legacy::Builder; use rustls_pemfile::Item; @@ -22,6 +22,8 @@ use crate::{ Deserializable as _, IpaPrivateKey, IpaPublicKey, KeyRegistry, PrivateKeyOnly, PublicKeyOnly, Serializable as _, }, + net::{ConnectionFlavor, Helper, Shard}, + sharding::ShardIndex, }; pub type OwnedCertificate = CertificateDer<'static>; @@ -37,23 +39,123 @@ pub enum Error { IOError(#[from] std::io::Error), } -/// Configuration information describing a helper network. +/// Configuration describing either 3 peers in a Ring or N shard peers. In a non-sharded case a +/// single [`NetworkConfig`] represents the entire network. In a sharded case, each host should +/// have one Ring and one Sharded configuration to know how to reach its peers. /// /// The most important thing this contains is discovery information for each of the participating -/// helpers. +/// peers. #[derive(Clone, Debug, Deserialize)] -pub struct NetworkConfig { - /// Information about each helper participating in the network. The order that helpers are - /// listed here determines their assigned helper identities in the network. Note that while the - /// helper identities are stable, roles are assigned per query. - pub peers: [PeerConfig; 3], +pub struct NetworkConfig { + peers: Vec, /// HTTP client configuration. #[serde(default)] pub client: ClientConfig, + + /// The identities of the index-matching peers. Separating this from [`Self::peers`](field) so + /// that parsing is easy to implement. 
+ #[serde(skip)] + identities: Vec, +} + +impl NetworkConfig { + /// # Panics + /// If `PathAndQuery::from_str("")` fails + #[must_use] + pub fn override_scheme(self, scheme: &Scheme) -> Self { + Self { + peers: self + .peers + .into_iter() + .map(|mut peer| { + let mut parts = peer.url.into_parts(); + parts.scheme = Some(scheme.clone()); + // `http::uri::Uri::from_parts()` requires that a URI have a path if it has a + // scheme. If the URI does not have a scheme, it is not required to have a path. + if parts.path_and_query.is_none() { + parts.path_and_query = Some("".parse().unwrap()); + } + peer.url = Uri::try_from(parts).unwrap(); + peer + }) + .collect(), + ..self + } + } + + #[must_use] + pub fn vec_peers(&self) -> Vec { + self.peers.clone() + } + + #[must_use] + pub fn get_peer(&self, i: usize) -> Option<&PeerConfig> { + self.peers.get(i) + } + + pub fn peers_iter(&self) -> std::slice::Iter<'_, PeerConfig> { + self.peers.iter() + } + + /// We currently require an exact match with the peer cert (i.e. we don't support verifying + /// the certificate against a truststore and identifying the peer by the certificate + /// subject). This could be changed if the need arises. + #[must_use] + pub fn identify_cert(&self, cert: Option<&CertificateDer>) -> Option { + let cert = cert?; + for (id, p) in zip(self.identities.iter(), self.peers.iter()) { + if p.certificate.as_ref() == Some(cert) { + return Some(*id); + } + } + // It might be nice to log something here. We could log the certificate base64? + tracing::error!( + "A client certificate was presented that does not match a known helper. Certificate: {}", + BASE64.encode(cert), + ); + None + } +} + +impl NetworkConfig { + #[must_use] + pub fn new_shards(peers: Vec, client: ClientConfig) -> Self { + let mut identities = Vec::with_capacity(peers.len()); + for (i, _p) in zip(0u32.., peers.iter()) { + identities.push(ShardIndex(i)); + } + Self { + peers, + client, + identities, + } + } + + #[must_use] + pub fn peers_map(&self) -> HashMap { + let mut indexed_peers = HashMap::new(); + for (ix, p) in zip(self.identities.iter(), self.peers.iter()) { + indexed_peers.insert(*ix, p); + } + indexed_peers + } } -impl NetworkConfig { +impl NetworkConfig { + /// Creates a new ring configuration. + /// # Panics + /// If the vector doesn't contain exactly 3 items. + #[must_use] + pub fn new_ring(ring: Vec, client: ClientConfig) -> Self { + assert_eq!(3, ring.len()); + Self { + peers: ring, + client, + identities: HelperIdentity::make_three().to_vec(), + } + } + /// Reads config from string. Expects config to be toml format. /// To read file, use `fs::read_to_string` /// @@ -62,49 +164,25 @@ impl NetworkConfig { pub fn from_toml_str(input: &str) -> Result { use config::{Config, File, FileFormat}; - let conf: Self = Config::builder() + let mut conf: Self = Config::builder() .add_source(File::from_str(input, FileFormat::Toml)) .build()? .try_deserialize()?; - Ok(conf) - } - - pub fn new(peers: [PeerConfig; 3], client: ClientConfig) -> Self { - Self { peers, client } - } + conf.identities = HelperIdentity::make_three().to_vec(); - pub fn peers(&self) -> &[PeerConfig; 3] { - &self.peers - } - - // Can maybe be replaced with array::zip when stable? - pub fn enumerate_peers( - &self, - ) -> Zip, slice::Iter> { - HelperIdentity::make_three() - .into_iter() - .zip(self.peers.iter()) + Ok(conf) } + /// Clones the internal configs and returns them as an array. /// # Panics - /// If `PathAndQuery::from_str("")` fails + /// If the internal vector isn't of size 3. 
#[must_use] - pub fn override_scheme(self, scheme: &Scheme) -> NetworkConfig { - NetworkConfig { - peers: self.peers.map(|mut peer| { - let mut parts = peer.url.into_parts(); - parts.scheme = Some(scheme.clone()); - // `http::uri::Uri::from_parts()` requires that a URI have a path if it has a - // scheme. If the URI does not have a scheme, it is not required to have a path. - if parts.path_and_query.is_none() { - parts.path_and_query = Some("".parse().unwrap()); - } - peer.url = Uri::try_from(parts).unwrap(); - peer - }), - ..self - } + pub fn peers(&self) -> [PeerConfig; 3] { + self.peers + .clone() + .try_into() + .unwrap_or_else(|v: Vec<_>| panic!("Expected a Vec of length 3 but it was {}", v.len())) } } @@ -422,10 +500,11 @@ impl KeyRegistries { /// If network file is improperly formatted pub fn init_from( &mut self, - network: &NetworkConfig, + network: &NetworkConfig, ) -> Option<[&KeyRegistry; 3]> { // Get the configs, if all three peers have one - let configs = network.peers().iter().try_fold(Vec::new(), |acc, peer| { + let peers = network.peers(); + let configs = peers.iter().try_fold(Vec::new(), |acc, peer| { if let (mut vec, Some(hpke_config)) = (acc, peer.hpke_config.as_ref()) { vec.push(hpke_config); Some(vec) diff --git a/ipa-core/src/net/client/mod.rs b/ipa-core/src/net/client/mod.rs index d334c6a48..d602e5ce2 100644 --- a/ipa-core/src/net/client/mod.rs +++ b/ipa-core/src/net/client/mod.rs @@ -384,7 +384,7 @@ impl MpcHelperClient { #[allow(clippy::missing_panics_doc)] pub fn from_conf( runtime: &IpaRuntime, - conf: &NetworkConfig, + conf: &NetworkConfig, identity: &ClientIdentity, ) -> [Self; 3] { conf.peers().each_ref().map(|peer_conf| { diff --git a/ipa-core/src/net/mod.rs b/ipa-core/src/net/mod.rs index 58981f7cc..6f60116ca 100644 --- a/ipa-core/src/net/mod.rs +++ b/ipa-core/src/net/mod.rs @@ -31,8 +31,7 @@ pub use transport::{HttpShardTransport, HttpTransport}; const APPLICATION_JSON: &str = "application/json"; const APPLICATION_OCTET_STREAM: &str = "application/octet-stream"; static HTTP_HELPER_ID_HEADER: HeaderName = HeaderName::from_static("x-unverified-helper-identity"); -pub static HTTP_SHARD_INDEX_HEADER: HeaderName = - HeaderName::from_static("x-unverified-shard-index"); +static HTTP_SHARD_INDEX_HEADER: HeaderName = HeaderName::from_static("x-unverified-shard-index"); /// This has the same meaning as const defined in h2 crate, but we don't import it directly. /// According to the [`spec`] it cannot exceed 2^31 - 1. 
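(Since both identity headers are now private to the `net` module, the only way for generic code to name the correct header is through the flavor, which is the point of the abstraction. The sketch below illustrates the intended use; it is not part of the patch. It relies only on items shown in the hunks above — `ConnectionFlavor`, the `Helper` and `Shard` markers, and the `id.as_str().as_ref()` conversion already used by `ClientIdentity::Header` — while the function name itself is hypothetical.)

    use hyper::header::{HeaderName, HeaderValue};

    // Builds the identity header pair for any flavor F without knowing
    // whether the connection is helper-to-helper or shard-to-shard.
    fn identity_header_pair<F: ConnectionFlavor>(id: F::Identity) -> (HeaderName, HeaderValue) {
        (
            F::identity_header(),
            HeaderValue::from_str(id.as_str().as_ref()).unwrap(),
        )
    }

    // identity_header_pair::<Helper>(...) yields "x-unverified-helper-identity",
    // identity_header_pair::<Shard>(...)  yields "x-unverified-shard-index".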
diff --git a/ipa-core/src/net/server/handlers/mod.rs b/ipa-core/src/net/server/handlers/mod.rs index 14c9b4c49..a72b9e8f2 100644 --- a/ipa-core/src/net/server/handlers/mod.rs +++ b/ipa-core/src/net/server/handlers/mod.rs @@ -8,7 +8,7 @@ use crate::{ sync::Arc, }; -pub fn router(transport: Arc) -> Router { +pub fn ring_router(transport: Arc) -> Router { echo::router().nest( http_serde::query::BASE_AXUM_PATH, Router::new() diff --git a/ipa-core/src/net/server/handlers/query/mod.rs b/ipa-core/src/net/server/handlers/query/mod.rs index 616308eea..c15817a42 100644 --- a/ipa-core/src/net/server/handlers/query/mod.rs +++ b/ipa-core/src/net/server/handlers/query/mod.rs @@ -18,6 +18,7 @@ use hyper::{Request, StatusCode}; use tower::{layer::layer_fn, Service}; use crate::{ + helpers::HelperIdentity, net::{server::ClientIdentity, HttpTransport}, sync::Arc, }; @@ -88,7 +89,7 @@ impl, Response = Response>> Service> } fn call(&mut self, req: Request) -> Self::Future { - match req.extensions().get() { + match req.extensions().get::>() { Some(ClientIdentity(_)) => self.inner.call(req).left_future(), None => ready(Ok(( StatusCode::UNAUTHORIZED, diff --git a/ipa-core/src/net/server/handlers/query/prepare.rs b/ipa-core/src/net/server/handlers/query/prepare.rs index 5ad5431d1..add6c7a95 100644 --- a/ipa-core/src/net/server/handlers/query/prepare.rs +++ b/ipa-core/src/net/server/handlers/query/prepare.rs @@ -2,7 +2,7 @@ use axum::{extract::Path, response::IntoResponse, routing::post, Extension, Json use hyper::StatusCode; use crate::{ - helpers::{query::PrepareQuery, BodyStream, Transport}, + helpers::{query::PrepareQuery, BodyStream, HelperIdentity, Transport}, net::{ http_serde::{ self, @@ -20,7 +20,7 @@ use crate::{ /// processing of that query. async fn handler( transport: Extension>, - _: Extension, // require that client is an authenticated helper + _: Extension>, // require that client is an authenticated helper Path(query_id): Path, QueryConfigQueryParams(config): QueryConfigQueryParams, Json(RequestBody { roles }): Json, @@ -100,7 +100,7 @@ mod tests { // since we tested `QueryType` with `create`, skip it here // More lenient version of Request, specifically so to test failure scenarios struct OverrideReq { - client_id: Option, + client_id: Option>, query_id: String, field_type: String, size: Option, diff --git a/ipa-core/src/net/server/handlers/query/step.rs b/ipa-core/src/net/server/handlers/query/step.rs index 07e511c65..b0b5d06ba 100644 --- a/ipa-core/src/net/server/handlers/query/step.rs +++ b/ipa-core/src/net/server/handlers/query/step.rs @@ -1,7 +1,7 @@ use axum::{extract::Path, routing::post, Extension, Router}; use crate::{ - helpers::{BodyStream, Transport}, + helpers::{BodyStream, HelperIdentity, Transport}, net::{ http_serde, server::{ClientIdentity, Error}, @@ -15,7 +15,7 @@ use crate::{ #[tracing::instrument(level = "trace", "step", skip_all, fields(from = ?**from, gate = ?gate))] async fn handler( transport: Extension>, - from: Extension, + from: Extension>, Path((query_id, gate)): Path<(QueryId, Gate)>, body: BodyStream, ) -> Result<(), Error> { @@ -76,7 +76,7 @@ mod tests { } struct OverrideReq { - client_id: Option, + client_id: Option>, query_id: String, gate: Gate, payload: Vec, diff --git a/ipa-core/src/net/server/mod.rs b/ipa-core/src/net/server/mod.rs index 0aee2d832..3ca8277a4 100644 --- a/ipa-core/src/net/server/mod.rs +++ b/ipa-core/src/net/server/mod.rs @@ -4,6 +4,7 @@ mod handlers; use std::{ borrow::Cow, io, + marker::PhantomData, net::{Ipv4Addr, SocketAddr, TcpListener}, 
    ops::Deref,
     task::{Context, Poll},
@@ -26,7 +27,6 @@ use axum_server::{
     tls_rustls::{RustlsAcceptor, RustlsConfig},
     Handle, Server,
 };
-use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
 use futures::{
     future::{ready, BoxFuture, Either, Ready},
     FutureExt,
@@ -34,21 +34,21 @@ use hyper::{body::Incoming, Request};
 use metrics::increment_counter;
 use rustls::{server::WebPkiClientVerifier, RootCertStore};
-use rustls_pki_types::CertificateDer;
 use tokio_rustls::server::TlsStream;
 use tower::{layer::layer_fn, Service};
 use tower_http::trace::TraceLayer;
 use tracing::{error, Span};
 
-use super::HTTP_HELPER_ID_HEADER;
 use crate::{
-    config::{NetworkConfig, OwnedCertificate, OwnedPrivateKey, ServerConfig, TlsConfig},
+    config::{
+        NetworkConfig, OwnedCertificate, OwnedPrivateKey, PeerConfig, ServerConfig, TlsConfig,
+    },
     error::BoxError,
     executor::{IpaJoinHandle, IpaRuntime},
-    helpers::{HelperIdentity, TransportIdentity},
+    helpers::TransportIdentity,
     net::{
-        parse_certificate_and_private_key_bytes, server::config::HttpServerConfig, Error,
-        HttpTransport, CRYPTO_PROVIDER,
+        parse_certificate_and_private_key_bytes, server::config::HttpServerConfig,
+        ConnectionFlavor, Error, Helper, HttpTransport, CRYPTO_PROVIDER,
     },
     sync::Arc,
     telemetry::metrics::{web::RequestProtocolVersion, REQUESTS_RECEIVED},
@@ -76,34 +76,38 @@ impl TracingSpanMaker for () {
 
 /// IPA helper web service
 ///
-/// `MpcHelperServer` handles requests from both peer helpers and external clients.
-pub struct MpcHelperServer {
-    transport: Arc<HttpTransport>,
+/// `MpcHelperServer` handles requests from peer helpers, shards within the same helper, and
+/// external clients.
+///
+/// The connection flavor generic is used to make the server aware of whether it should offer an
+/// HTTP API for shards or for other Helpers. External clients can reach out to both APIs to push
+/// the input data among other things.
+pub struct MpcHelperServer<F = Helper> {
     config: ServerConfig,
-    network_config: NetworkConfig,
+    network_config: NetworkConfig<F>,
+    router: Router,
 }
 
-impl MpcHelperServer {
-    pub fn new(
+impl MpcHelperServer<Helper> {
+    pub fn new_ring(
         transport: Arc<HttpTransport>,
         config: ServerConfig,
-        network_config: NetworkConfig,
+        network_config: NetworkConfig<Helper>,
     ) -> Self {
+        let router = handlers::ring_router(transport);
         MpcHelperServer {
-            transport,
             config,
             network_config,
+            router,
         }
     }
+}
 
-    fn router(&self) -> Router {
-        handlers::router(Arc::clone(&self.transport))
-    }
-
+impl<F: ConnectionFlavor> MpcHelperServer<F> {
     #[cfg(all(test, unit_test))]
     async fn handle_req(&self, req: hyper::Request<Body>) -> axum::response::Response {
         use tower::ServiceExt;
-        self.router().oneshot(req).await.unwrap()
+        self.router.clone().oneshot(req).await.unwrap()
     }
 
     /// Starts the MPC helper service.
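(Because the router is now built once inside `new_ring` and stored on the struct, standing up a ring-facing server is just construction plus bind. A rough sketch of the call site follows; it is not part of the patch and assumes the existing `start_on` entry point — whose body the next hunks adjust — together with placeholder `transport`, `server_config`, and `network_config` values.)

    // Helper-flavored server for the ring API. new_ring already assembled
    // the router, so start_on only wires listeners, optional TLS, and the
    // identity-extraction layers for this flavor.
    let server = MpcHelperServer::new_ring(
        Arc::clone(&transport), // Arc<HttpTransport>
        server_config,          // ServerConfig
        network_config,         // NetworkConfig<Helper>
    );
    let (addr, handle) = server
        .start_on(&IpaRuntime::current(), /* listener */ None, ())
        .await;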
@@ -133,7 +137,7 @@ impl MpcHelperServer {
         #[cfg(not(test))]
         const BIND_ADDRESS: Ipv4Addr = Ipv4Addr::UNSPECIFIED;
 
-        let svc = self.router().layer(
+        let svc = self.router.clone().layer(
             TraceLayer::new_for_http()
                 .make_span_with(move |_request: &hyper::Request<_>| tracing.make_span())
                 .on_request(|request: &hyper::Request<_>, _: &Span| {
@@ -146,7 +150,7 @@ impl MpcHelperServer {
         let task_handle = match (self.config.disable_https, listener) {
             (true, Some(listener)) => {
                 let svc = svc
-                    .layer(layer_fn(SetClientIdentityFromHeader::new))
+                    .layer(layer_fn(SetClientIdentityFromHeader::<_, F>::new))
                     .into_make_service();
                 spawn_server(
                     runtime,
@@ -159,12 +163,12 @@ impl MpcHelperServer {
             (true, None) => {
                 let addr = SocketAddr::new(BIND_ADDRESS.into(), self.config.port.unwrap_or(0));
                 let svc = svc
-                    .layer(layer_fn(SetClientIdentityFromHeader::new))
+                    .layer(layer_fn(SetClientIdentityFromHeader::<_, F>::new))
                     .into_make_service();
                 spawn_server(runtime, axum_server::bind(addr), handle.clone(), svc).await
             }
             (false, Some(listener)) => {
-                let rustls_config = rustls_config(&self.config, &self.network_config)
+                let rustls_config = rustls_config(&self.config, self.network_config.vec_peers())
                     .await
                     .expect("invalid TLS configuration");
                 spawn_server(
@@ -179,7 +183,7 @@ impl MpcHelperServer {
             }
             (false, None) => {
                 let addr = SocketAddr::new(BIND_ADDRESS.into(), self.config.port.unwrap_or(0));
-                let rustls_config = rustls_config(&self.config, &self.network_config)
+                let rustls_config = rustls_config(&self.config, self.network_config.vec_peers())
                     .await
                     .expect("invalid TLS configuration");
                 spawn_server(
@@ -276,16 +280,12 @@ async fn certificate_and_key(
 /// If there is a problem with the TLS configuration.
 async fn rustls_config(
     config: &ServerConfig,
-    network: &NetworkConfig,
+    certs: Vec<PeerConfig>,
 ) -> Result<RustlsConfig, BoxError> {
     let (cert, key) = certificate_and_key(config).await?;
 
     let mut trusted_certs = RootCertStore::empty();
-    for cert in network
-        .peers()
-        .iter()
-        .filter_map(|peer| peer.certificate.clone())
-    {
+    for cert in certs.into_iter().filter_map(|peer| peer.certificate) {
         // Note that this uses `webpki::TrustAnchor::try_from_cert_der`, which *does not* validate
         // the certificate. That is not required for security, but might be desirable to flag
         // configuration errors.
@@ -315,73 +315,54 @@ async fn rustls_config(
 // at some inconvenience (e.g. `MaybeExtensionExt`), we avoid using `Option` within the extension,
 // to avoid possible confusion about how many times the return from `req.extensions().get()` must be
 // unwrapped to ensure valid authentication.
-#[derive(Clone, Copy, Debug)]
-struct ClientIdentity(pub HelperIdentity);
+#[derive(Clone, Copy, Debug, PartialEq)]
+struct ClientIdentity<I: TransportIdentity>(pub I);
 
-impl Deref for ClientIdentity {
-    type Target = HelperIdentity;
+impl<I: TransportIdentity> Deref for ClientIdentity<I> {
+    type Target = I;
 
     fn deref(&self) -> &Self::Target {
         &self.0
     }
 }
 
-impl TryFrom<HeaderValue> for ClientIdentity {
+impl<I: TransportIdentity> TryFrom<HeaderValue> for ClientIdentity<I> {
     type Error = Error;
 
     fn try_from(value: HeaderValue) -> Result<Self, Self::Error> {
         let header_str = value.to_str()?;
-        HelperIdentity::from_str(header_str)
+        I::from_str(header_str)
             .map_err(|e| Error::InvalidHeader(Box::new(e)))
             .map(ClientIdentity)
     }
 }
 
 /// `Accept`or that sets an axum `Extension` indicating the authenticated remote helper identity.
+/// Validating the certificate is something that happens earlier, at connection time; this just
+/// provides the identity to the inner server handlers.
#[derive(Clone)] -struct ClientCertRecognizingAcceptor { +struct ClientCertRecognizingAcceptor { inner: RustlsAcceptor, - network_config: Arc, + network_config: Arc>, } -impl ClientCertRecognizingAcceptor { - fn new(inner: RustlsAcceptor, network_config: NetworkConfig) -> Self { +impl ClientCertRecognizingAcceptor { + fn new(inner: RustlsAcceptor, network_config: NetworkConfig) -> Self { Self { inner, network_config: Arc::new(network_config), } } - - // This can't be a method (at least not that takes `&self`) because it needs to go in a 'static future. - fn identify_client( - network_config: &NetworkConfig, - cert_option: Option<&CertificateDer>, - ) -> Option { - let cert = cert_option?; - // We currently require an exact match with the peer cert (i.e. we don't support verifying - // the certificate against a truststore and identifying the peer by the certificate - // subject). This could be changed if the need arises. - for (id, peer) in network_config.enumerate_peers() { - if peer.certificate.as_ref() == Some(cert) { - return Some(ClientIdentity(id)); - } - } - // It might be nice to log something here. We could log the certificate base64? - error!( - "A client certificate was presented that does not match a known helper. Certificate: {}", - BASE64.encode(cert), - ); - None - } } -impl Accept for ClientCertRecognizingAcceptor +impl Accept for ClientCertRecognizingAcceptor where I: AsyncRead + AsyncWrite + Unpin + Send + 'static, S: Send + 'static, + F: ConnectionFlavor, { type Stream = TlsStream; - type Service = SetClientIdentityFromCertificate; + type Service = SetClientIdentityFromCertificate; type Future = BoxFuture<'static, io::Result<(Self::Stream, Self::Service)>>; fn accept(&self, stream: I, service: S) -> Self::Future { @@ -390,7 +371,7 @@ where Box::pin(async move { let (stream, service) = acceptor.accept(stream, service).await.map_err(|err| { - error!("[ClientCertRecognizingAcceptor] connection error: {err}"); + error!("[ClientCertRecognizingAcceptor] Internal acceptor error: {err}"); err })?; @@ -401,27 +382,33 @@ where // certificate here, because the certificate must have passed full verification at // connection time. But it's possible the certificate subject is not something we // recognize as a helper. - let id = Self::identify_client( - &network_config, - stream - .get_ref() - .1 - .peer_certificates() - .and_then(<[_]>::first), - ); - let service = SetClientIdentityFromCertificate { inner: service, id }; + let opt_cert = stream + .get_ref() + .1 + .peer_certificates() + .and_then(<[_]>::first); + let option_id: Option = network_config.identify_cert(opt_cert); + let client_id = option_id.map(ClientIdentity); + let service = SetClientIdentityFromCertificate { + inner: service, + id: client_id, + }; Ok((stream, service)) }) } } #[derive(Clone)] -struct SetClientIdentityFromCertificate { +struct SetClientIdentityFromCertificate { inner: S, - id: Option, + id: Option>, } -impl>> Service> for SetClientIdentityFromCertificate { +impl Service> for SetClientIdentityFromCertificate +where + S: Service>, + F: ConnectionFlavor, +{ type Response = S::Response; type Error = S::Error; type Future = S::Future; @@ -444,18 +431,24 @@ impl>> Service> for SetClientIdentityFromCer /// Since this allows a client to claim any identity, it is completely /// insecure. It must only be used in contexts where that is acceptable. 
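The header path converts the header value with `I::from_str`, as the unit tests added below exercise: "A" parses to `HelperIdentity::ONE` and an unknown string such as "H1" is rejected. A standalone sketch of that mapping (stand-in types; "B" and "C" are assumed to map analogously):

// Parse a helper identity from its one-letter header value.
fn helper_from_header(value: &str) -> Result<u8, String> {
    match value {
        "A" => Ok(1),
        "B" => Ok(2),
        "C" => Ok(3),
        other => Err(format!("The string {other} is an invalid Helper Identity")),
    }
}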
#[derive(Clone)] -struct SetClientIdentityFromHeader { +struct SetClientIdentityFromHeader { inner: S, + _restriction: PhantomData, } -impl SetClientIdentityFromHeader { +impl SetClientIdentityFromHeader { fn new(inner: S) -> Self { - Self { inner } + Self { + inner, + _restriction: PhantomData, + } } } -impl, Response = Response>> Service> - for SetClientIdentityFromHeader +impl Service> for SetClientIdentityFromHeader +where + S: Service, Response = Response>, + F: ConnectionFlavor, { type Response = Response; type Error = S::Error; @@ -467,9 +460,8 @@ impl, Response = Response>> Service> } fn call(&mut self, mut req: Request) -> Self::Future { - if let Some(header_value) = req.headers().get(&HTTP_HELPER_ID_HEADER) { - let id_result = ClientIdentity::try_from(header_value.clone()) - .map_err(|e| Error::InvalidHeader(format!("{HTTP_HELPER_ID_HEADER}: {e}").into())); + if let Some(header_value) = req.headers().get(F::identity_header()) { + let id_result = ClientIdentity::::try_from(header_value.clone()); match id_result { Ok(id) => req.extensions_mut().insert(id), Err(err) => return ready(Ok(err.into_response())).right_future(), @@ -479,6 +471,28 @@ impl, Response = Response>> Service> } } +#[cfg(all(test, unit_test))] +mod tests { + use axum::http::HeaderValue; + + use crate::{helpers::HelperIdentity, net::server::ClientIdentity}; + + #[test] + fn identify_from_header_happy_case() { + let h = HeaderValue::from_static("A"); + let id = ClientIdentity::::try_from(h); + assert_eq!(id.unwrap(), ClientIdentity(HelperIdentity::ONE)); + } + + #[test] + #[should_panic = "The string H1 is an invalid Helper Identity"] + fn identify_from_header_wrong_header() { + let h = HeaderValue::from_static("H1"); + let id = ClientIdentity::::try_from(h); + id.unwrap(); + } +} + #[cfg(all(test, unit_test))] mod e2e_tests { use std::collections::HashMap; @@ -499,6 +513,7 @@ mod e2e_tests { client::danger::{ServerCertVerified, ServerCertVerifier}, pki_types::ServerName, }; + use rustls_pki_types::CertificateDer; use tracing::Level; use super::*; diff --git a/ipa-core/src/net/test.rs b/ipa-core/src/net/test.rs index d56dbf387..9fdcf5877 100644 --- a/ipa-core/src/net/test.rs +++ b/ipa-core/src/net/test.rs @@ -24,7 +24,7 @@ use crate::{ executor::{IpaJoinHandle, IpaRuntime}, helpers::{HandlerBox, HelperIdentity, RequestHandler}, hpke::{Deserializable as _, IpaPublicKey}, - net::{ClientIdentity, HttpTransport, MpcHelperClient, MpcHelperServer}, + net::{ClientIdentity, Helper, HttpTransport, MpcHelperClient, MpcHelperServer}, sync::Arc, test_fixture::metrics::MetricsHandle, }; @@ -33,7 +33,7 @@ pub const DEFAULT_TEST_PORTS: [u16; 3] = [3000, 3001, 3002]; pub struct TestConfig { pub disable_https: bool, - pub network: NetworkConfig, + pub network: NetworkConfig, pub servers: [ServerConfig; 3], pub sockets: Option<[TcpListener; 3]>, } @@ -174,16 +174,13 @@ impl TestConfigBuilder { )) }, }) - .collect::>() - .try_into() - .unwrap(); - let network = NetworkConfig { + .collect::>(); + let network = NetworkConfig::::new_ring( peers, - client: self - .use_http1 + self.use_http1 .then(ClientConfig::use_http1) .unwrap_or_default(), - }; + ); let servers = if self.disable_https { ports.map(|ports| server_config_insecure_http(ports, !self.disable_matchkey_encryption)) } else { @@ -203,7 +200,7 @@ pub struct TestServer { pub addr: SocketAddr, pub handle: IpaJoinHandle<()>, pub transport: Arc, - pub server: MpcHelperServer, + pub server: MpcHelperServer, pub client: MpcHelperClient, pub request_handler: Option>>, } diff --git 
a/ipa-core/src/net/transport.rs b/ipa-core/src/net/transport.rs index 9fbdc8103..7ada0da2f 100644 --- a/ipa-core/src/net/transport.rs +++ b/ipa-core/src/net/transport.rs @@ -9,7 +9,7 @@ use async_trait::async_trait; use futures::{Stream, TryFutureExt}; use pin_project::{pin_project, pinned_drop}; -use super::client::resp_ok; +use super::{client::resp_ok, Helper}; use crate::{ config::{NetworkConfig, ServerConfig}, executor::IpaRuntime, @@ -68,12 +68,13 @@ impl HttpTransport { runtime: IpaRuntime, identity: HelperIdentity, server_config: ServerConfig, - network_config: NetworkConfig, + network_config: NetworkConfig, clients: [MpcHelperClient; 3], handler: Option, - ) -> (Arc, MpcHelperServer) { + ) -> (Arc, MpcHelperServer) { let transport = Self::new_internal(runtime, identity, clients, handler); - let server = MpcHelperServer::new(Arc::clone(&transport), server_config, network_config); + let server = + MpcHelperServer::new_ring(Arc::clone(&transport), server_config, network_config); (transport, server) } @@ -378,7 +379,7 @@ mod tests { async fn make_helpers( sockets: [TcpListener; 3], server_config: [ServerConfig; 3], - network_config: &NetworkConfig, + network_config: &NetworkConfig, disable_https: bool, ) -> [HelperApp; 3] { join_all( From 6dd937c599b2b906756f321e7004b8b922c35d47 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Fri, 18 Oct 2024 10:04:19 -0700 Subject: [PATCH 155/191] Clippy --- ipa-metrics/src/partitioned.rs | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/ipa-metrics/src/partitioned.rs b/ipa-metrics/src/partitioned.rs index 9e4653992..f723759c3 100644 --- a/ipa-metrics/src/partitioned.rs +++ b/ipa-metrics/src/partitioned.rs @@ -26,7 +26,7 @@ use crate::{ }; thread_local! { - static PARTITION: Cell> = Cell::new(None); + static PARTITION: Cell> = const { Cell::new(None) } } /// Each partition is a unique 8 byte value, meaning roughly 1B partitions @@ -37,13 +37,14 @@ pub struct CurrentThreadContext; impl CurrentThreadContext { pub fn set(new: Partition) { - Self::toggle(Some(new)) + Self::toggle(Some(new)); } pub fn toggle(new: Option) { PARTITION.set(new); } + #[must_use] pub fn get() -> Option { PARTITION.get() } @@ -73,6 +74,7 @@ impl Default for PartitionedStore { } impl PartitionedStore { + #[must_use] pub const fn new() -> Self { Self { inner: hashbrown::HashMap::with_hasher(FxBuildHasher), @@ -118,28 +120,31 @@ impl PartitionedStore { if let Some(partition) = CurrentThreadContext::get() { self.inner .entry(partition) - .or_insert_with(|| Store::default()) + .or_insert_with(Store::default) .counter(key) } else { self.default_store.counter(key) } } + #[must_use] pub fn len(&self) -> usize { self.inner.len() + self.default_store.len() } + #[must_use] pub fn is_empty(&self) -> bool { self.len() == 0 } + #[allow(dead_code)] fn with_partition_mut T, T>( &mut self, partition: Partition, f: F, ) -> T { - let mut store = self.get_mut(Some(partition)); - f(&mut store) + let store = self.get_mut(Some(partition)); + f(store) } fn get_mut(&mut self, partition: Option) -> &mut Store { From 245388231e523e8809e4f4aef0f2d77d5f5169f2 Mon Sep 17 00:00:00 2001 From: Thomas James Yurek Date: Fri, 18 Oct 2024 11:02:35 -0700 Subject: [PATCH 156/191] deserialize_infallible, change open_in_place/seal_in_place to not use KeyRegistry --- ipa-core/src/hpke/mod.rs | 34 +++++++++++++++------------------- ipa-core/src/report/hybrid.rs | 27 ++++++++++++++++----------- ipa-core/src/report/ipa.rs | 20 ++++++++++++-------- 3 files changed, 43 insertions(+), 38 
deletions(-) diff --git a/ipa-core/src/hpke/mod.rs b/ipa-core/src/hpke/mod.rs index 2b7f2bb80..e545efa54 100644 --- a/ipa-core/src/hpke/mod.rs +++ b/ipa-core/src/hpke/mod.rs @@ -96,19 +96,15 @@ impl From for CryptError { /// If ciphertext cannot be opened for any reason. /// /// [`HPKE decryption`]: https://datatracker.ietf.org/doc/html/rfc9180#name-encryption-and-decryption -pub fn open_in_place<'a, R: PrivateKeyRegistry>( - key_registry: &R, +pub fn open_in_place<'a>( + sk: &IpaPrivateKey, enc: &[u8], ciphertext: &'a mut [u8], - key_id: u8, info: &[u8], ) -> Result<&'a [u8], CryptError> { let encap_key = ::EncappedKey::from_bytes(enc)?; let (ct, tag) = ciphertext.split_at_mut(ciphertext.len() - AeadTag::::size()); let tag = AeadTag::::from_bytes(tag)?; - let sk = key_registry - .private_key(key_id) - .ok_or(CryptError::NoSuchKey(key_id))?; single_shot_open_in_place_detached::<_, IpaKdf, IpaKem>( &OpModeR::Base, @@ -135,20 +131,15 @@ pub(crate) type Ciphertext<'a> = ( /// ## Errors /// If the match key cannot be sealed for any reason. -pub(crate) fn seal_in_place<'a, R: CryptoRng + RngCore, K: PublicKeyRegistry>( - key_registry: &K, +pub(crate) fn seal_in_place<'a, R: CryptoRng + RngCore>( + pk: &IpaPublicKey, plaintext: &'a mut [u8], - key_id: u8, info: &[u8], rng: &mut R, ) -> Result, CryptError> { - let pk_r = key_registry - .public_key(key_id) - .ok_or(CryptError::NoSuchKey(key_id))?; - let (encap_key, tag) = single_shot_seal_in_place_detached::( &OpModeS::Base, - pk_r, + pk, info, plaintext, &[], @@ -167,6 +158,7 @@ mod tests { use rand_core::{CryptoRng, RngCore, SeedableRng}; use typenum::Unsigned; + use super::{PrivateKeyRegistry, PublicKeyRegistry}; use crate::{ ff::{Gf40Bit, Serializable as IpaSerializable}, hpke::{open_in_place, seal_in_place, CryptError, Info, IpaAead, KeyPair, KeyRegistry}, @@ -229,9 +221,11 @@ mod tests { match_key.serialize(&mut plaintext); let (encap_key, ciphertext, tag) = seal_in_place( - &self.registry, + self.registry + .public_key(info.key_id) + .ok_or(CryptError::NoSuchKey(info.key_id)) + .unwrap(), plaintext.as_mut_slice(), - info.key_id, &info.to_bytes(), &mut self.rng, ) @@ -282,10 +276,11 @@ mod tests { ) .unwrap(); open_in_place( - &self.registry, + self.registry + .private_key(info.key_id) + .ok_or(CryptError::NoSuchKey(info.key_id))?, &enc.enc, enc.ct.as_mut(), - info.key_id, &info.to_bytes(), )?; @@ -472,7 +467,8 @@ mod tests { _ => panic!("bad test setup: only 5 fields can be corrupted, asked to corrupt: {corrupted_info_field}") }; - open_in_place(&suite.registry, &encryption.enc, &mut encryption.ct, info.key_id, &info.to_bytes()).unwrap_err(); + open_in_place(suite.registry.private_key(info.key_id) + .ok_or(CryptError::NoSuchKey(info.key_id))?, &encryption.enc, &mut encryption.ct, &info.to_bytes()).unwrap_err(); } } } diff --git a/ipa-core/src/report/hybrid.rs b/ipa-core/src/report/hybrid.rs index edb614714..d96346bf4 100644 --- a/ipa-core/src/report/hybrid.rs +++ b/ipa-core/src/report/hybrid.rs @@ -12,7 +12,7 @@ use rand_core::{CryptoRng, RngCore}; use typenum::{Sum, Unsigned, U16}; use crate::{ - error::{BoxError, Error, UnwrapInfallible}, + error::{BoxError, Error}, ff::{boolean_array::BA64, Serializable}, hpke::{ open_in_place, seal_in_place, CryptError, EncapsulationSize, PrivateKeyRegistry, @@ -151,17 +151,17 @@ where .serialize(GenericArray::from_mut_slice(&mut plaintext_btt[..])); let (encap_key_mk, ciphertext_mk, tag_mk) = seal_in_place( - key_registry, + key_registry.public_key(key_id) + .ok_or(CryptError::NoSuchKey(key_id))?, 
plaintext_mk.as_mut(), - key_id, &info.to_bytes(), rng, )?; let (encap_key_btt, ciphertext_btt, tag_btt) = seal_in_place( - key_registry, + key_registry.public_key(key_id) + .ok_or(CryptError::NoSuchKey(key_id))?, plaintext_btt.as_mut(), - key_id, &info.to_bytes(), rng, )?; @@ -297,26 +297,31 @@ where let mut ct_mk: GenericArray = *GenericArray::from_slice(self.mk_ciphertext()); let plaintext_mk = open_in_place( - key_registry, + key_registry + .private_key(self.key_id()) + .ok_or(CryptError::NoSuchKey(self.key_id()))?, self.encap_key_mk(), &mut ct_mk, - self.key_id(), &info.to_bytes(), )?; let mut ct_btt: GenericArray> = GenericArray::from_slice(self.btt_ciphertext()).clone(); let plaintext_btt = open_in_place( - key_registry, + key_registry + .private_key(self.key_id()) + .ok_or(CryptError::NoSuchKey(self.key_id()))?, self.encap_key_btt(), &mut ct_btt, - self.key_id(), &info.to_bytes(), )?; Ok(HybridImpressionReport:: { - match_key: Replicated::::deserialize(GenericArray::from_slice(plaintext_mk)) - .unwrap_infallible(), + //match_key: Replicated::::deserialize(GenericArray::from_slice(plaintext_mk)) + // .unwrap_infallible(), + match_key: Replicated::::deserialize_infallible(GenericArray::from_slice( + plaintext_mk, + )), breakdown_key: Replicated::::deserialize(GenericArray::from_slice(plaintext_btt)) .map_err(|e| { InvalidHybridReportError::DeserializationError("is_trigger", e.into()) diff --git a/ipa-core/src/report/ipa.rs b/ipa-core/src/report/ipa.rs index a71358f64..44a014db1 100644 --- a/ipa-core/src/report/ipa.rs +++ b/ipa-core/src/report/ipa.rs @@ -408,20 +408,22 @@ where let mut ct_mk: GenericArray = *GenericArray::from_slice(self.mk_ciphertext()); let plaintext_mk = open_in_place( - key_registry, + key_registry + .private_key(self.key_id()) + .ok_or(CryptError::NoSuchKey(self.key_id()))?, self.encap_key_mk(), &mut ct_mk, - self.key_id(), &info.to_bytes(), )?; let mut ct_btt: GenericArray> = GenericArray::from_slice(self.btt_ciphertext()).clone(); let plaintext_btt = open_in_place( - key_registry, + key_registry + .private_key(self.key_id()) + .ok_or(CryptError::NoSuchKey(self.key_id()))?, self.encap_key_btt(), &mut ct_btt, - self.key_id(), &info.to_bytes(), )?; @@ -590,17 +592,19 @@ where )); let (encap_key_mk, ciphertext_mk, tag_mk) = seal_in_place( - key_registry, + key_registry + .public_key(key_id) + .ok_or(CryptError::NoSuchKey(key_id))?, plaintext_mk.as_mut(), - key_id, &info.to_bytes(), rng, )?; let (encap_key_btt, ciphertext_btt, tag_btt) = seal_in_place( - key_registry, + key_registry + .public_key(key_id) + .ok_or(CryptError::NoSuchKey(key_id))?, plaintext_btt.as_mut(), - key_id, &info.to_bytes(), rng, )?; From 0c276b914036c8a9877f182a29121e1501d48a02 Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Fri, 18 Oct 2024 15:04:29 -0700 Subject: [PATCH 157/191] Sharded decrypt and uniqueness verification (#1358) * refactor UniqueBytesValidator to UniqueTagValidator with fixed width array * add implementation of Serializable for UniqueTag * tmp working commit * working commit, all tests passing * remove type aliases because clippy doesn't like them * Update ipa-core/src/query/runner/hybrid.rs Co-authored-by: Alex Koshelev * fix formatting error * add a ReshardByTag step and narrow to it * do modulo on u128, update compile check for TAG_SIZE * update type signature and name of flatten3v * update test to generate_random_tag directly * reorganize ShardIndex implementations and remove conflicting implementation --------- Co-authored-by: Alex Koshelev --- ipa-core/Cargo.toml 
| 1 + ipa-core/src/protocol/hybrid/step.rs | 4 +- ipa-core/src/query/executor.rs | 17 +- ipa-core/src/query/runner/hybrid.rs | 316 +++++++++++++++++---------- ipa-core/src/query/runner/mod.rs | 1 - ipa-core/src/report/hybrid.rs | 126 +++++++---- ipa-core/src/report/ipa.rs | 6 +- ipa-core/src/sharding.rs | 104 +++++---- ipa-core/src/test_fixture/mod.rs | 31 +++ ipa-core/src/test_fixture/world.rs | 36 +++ 10 files changed, 419 insertions(+), 223 deletions(-) diff --git a/ipa-core/Cargo.toml b/ipa-core/Cargo.toml index 49348843d..0081a0a50 100644 --- a/ipa-core/Cargo.toml +++ b/ipa-core/Cargo.toml @@ -82,6 +82,7 @@ ipa-step = { version = "*", path = "../ipa-step" } ipa-step-derive = { version = "*", path = "../ipa-step-derive" } aes = "0.8.3" +assertions = "0.1.0" async-trait = "0.1.79" async-scoped = { version = "0.9.0", features = ["use-tokio"], optional = true } axum = { version = "0.7.5", optional = true, features = ["http2", "macros"] } diff --git a/ipa-core/src/protocol/hybrid/step.rs b/ipa-core/src/protocol/hybrid/step.rs index 6fd7406c6..5de0051be 100644 --- a/ipa-core/src/protocol/hybrid/step.rs +++ b/ipa-core/src/protocol/hybrid/step.rs @@ -1,4 +1,6 @@ use ipa_step_derive::CompactStep; #[derive(CompactStep)] -pub(crate) enum HybridStep {} +pub(crate) enum HybridStep { + ReshardByTag, +} diff --git a/ipa-core/src/query/executor.rs b/ipa-core/src/query/executor.rs index a6b1dae74..edd1662e4 100644 --- a/ipa-core/src/query/executor.rs +++ b/ipa-core/src/query/executor.rs @@ -39,7 +39,7 @@ use crate::{ Gate, }, query::{ - runner::{HybridQuery, OprfIpaQuery, QueryResult}, + runner::{OprfIpaQuery, QueryResult}, state::RunningQuery, }, sync::Arc, @@ -164,20 +164,7 @@ pub fn execute( ) }, ), - (QueryType::SemiHonestHybrid(query_params), _) => do_query( - runtime, - config, - gateway, - input, - move |prss, gateway, config, input| { - let ctx = SemiHonestContext::new(prss, gateway); - Box::pin( - HybridQuery::<_, BA32, R>::new(query_params, key_registry) - .execute(ctx, config.size, input) - .then(|res| ready(res.map(|out| Box::new(out) as Box))), - ) - }, - ), + (QueryType::SemiHonestHybrid(_), _) => todo!(), } } diff --git a/ipa-core/src/query/runner/hybrid.rs b/ipa-core/src/query/runner/hybrid.rs index 6ff7ed7f2..cc3b861c7 100644 --- a/ipa-core/src/query/runner/hybrid.rs +++ b/ipa-core/src/query/runner/hybrid.rs @@ -1,6 +1,6 @@ use std::{marker::PhantomData, sync::Arc}; -use futures::{future, stream::iter, StreamExt, TryStreamExt}; +use futures::{stream::iter, StreamExt, TryStreamExt}; use crate::{ error::Error, @@ -10,16 +10,16 @@ use crate::{ BodyStream, LengthDelimitedStream, }, hpke::PrivateKeyRegistry, - protocol::{context::UpgradableContext, ipa_prf::shuffle::Shuffle, step::ProtocolStep::Hybrid}, - report::hybrid::{EncryptedHybridReport, UniqueBytesValidator}, + protocol::{ + context::{reshard_iter, ShardedContext}, + hybrid::step::HybridStep, + step::ProtocolStep::Hybrid, + }, + report::hybrid::{EncryptedHybridReport, HybridReport, UniqueTag, UniqueTagValidator}, secret_sharing::{replicated::semi_honest::AdditiveShare as ReplicatedShare, SharedValue}, }; -pub type BreakdownKey = BA8; -pub type Value = BA3; -// TODO: remove this when encryption/decryption works for HybridReports -pub type Timestamp = BA20; - +#[allow(dead_code)] pub struct Query { config: HybridQueryParams, key_registry: Arc, @@ -28,7 +28,7 @@ pub struct Query { impl Query where - C: UpgradableContext + Shuffle, + C: ShardedContext, { pub fn new(query_params: HybridQueryParams, key_registry: Arc) -> Self { Self { @@ 
-50,41 +50,56 @@ where key_registry, phantom_data: _, } = self; + tracing::info!("New hybrid query: {config:?}"); - let _ctx = ctx.narrow(&Hybrid); + let ctx = ctx.narrow(&Hybrid); let sz = usize::from(query_size); - let mut unique_encrypted_hybrid_reports = UniqueBytesValidator::new(sz); - if config.plaintext_match_keys { return Err(Error::Unsupported( "Hybrid queries do not currently support plaintext match keys".to_string(), )); } - let _input = LengthDelimitedStream::::new(input_stream) - .map_err(Into::::into) - .and_then(|enc_reports| { - future::ready( - unique_encrypted_hybrid_reports - .check_duplicates(&enc_reports) - .map(|()| enc_reports) - .map_err(Into::::into), + let (_decrypted_reports, tags): (Vec>, Vec) = + LengthDelimitedStream::::new(input_stream) + .map_err(Into::::into) + .map_ok(|enc_reports| { + iter(enc_reports.into_iter().map({ + |enc_report| { + let dec_report = enc_report + .decrypt::(key_registry.as_ref()) + .map_err(Into::::into); + let unique_tag = UniqueTag::from_unique_bytes(&enc_report); + dec_report.map(|dec_report1| (dec_report1, unique_tag)) + } + })) + }) + .try_flatten() + .take(sz) + .try_fold( + (Vec::with_capacity(sz), Vec::with_capacity(sz)), + |mut acc, result| async move { + acc.0.push(result.0); + acc.1.push(result.1); + Ok(acc) + }, ) - }) - .map_ok(|enc_reports| { - iter(enc_reports.into_iter().map({ - |enc_report| { - enc_report - .decrypt::(key_registry.as_ref()) - .map_err(Into::::into) - } - })) - }) - .try_flatten() - .take(sz) - .try_collect::>() - .await?; + .await?; + + let resharded_tags = reshard_iter( + ctx.narrow(&HybridStep::ReshardByTag), + tags, + |ctx, _, tag| tag.shard_picker(ctx.shard_count()), + ) + .await?; + + // this should use ? but until this returns a result, + //we want to capture the panic for the test + let mut unique_encrypted_hybrid_reports = UniqueTagValidator::new(resharded_tags.len()); + unique_encrypted_hybrid_reports + .check_duplicates(&resharded_tags) + .unwrap(); unimplemented!("query::runnner::HybridQuery.execute is not fully implemented") } @@ -107,10 +122,13 @@ mod tests { BodyStream, }, hpke::{KeyPair, KeyRegistry}, - query::runner::HybridQuery, + query::runner::hybrid::Query as HybridQuery, report::{OprfReport, DEFAULT_KEY_ID}, - secret_sharing::IntoShares, - test_fixture::{ipa::TestRawDataRecord, join3v, Reconstruct, TestWorld}, + secret_sharing::{replicated::semi_honest::AdditiveShare, IntoShares}, + test_fixture::{ + flatten3v, ipa::TestRawDataRecord, Reconstruct, RoundRobinInputDistribution, TestWorld, + TestWorldConfig, WithShards, + }, }; const EXPECTED: &[u128] = &[0, 8, 5]; @@ -165,28 +183,44 @@ mod tests { } struct BufferAndKeyRegistry { - buffers: [Vec; 3], + buffers: [Vec>; 3], key_registry: Arc>, + query_sizes: Vec, } - fn build_buffers_from_records(records: &[TestRawDataRecord]) -> BufferAndKeyRegistry { + fn build_buffers_from_records(records: &[TestRawDataRecord], s: usize) -> BufferAndKeyRegistry { let mut rng = StdRng::seed_from_u64(42); let key_id = DEFAULT_KEY_ID; let key_registry = Arc::new(KeyRegistry::::random(1, &mut rng)); - let mut buffers: [_; 3] = std::array::from_fn(|_| Vec::new()); - + let mut buffers: [_; 3] = std::array::from_fn(|_| vec![Vec::new(); s]); let shares: [Vec>; 3] = records.iter().cloned().share(); for (buf, shares) in zip(&mut buffers, shares) { - for share in shares { + for (i, share) in shares.into_iter().enumerate() { share - .delimited_encrypt_to(key_id, key_registry.as_ref(), &mut rng, buf) + .delimited_encrypt_to(key_id, key_registry.as_ref(), &mut 
rng, &mut buf[i % s]) .unwrap(); } } + + let total_query_size = records.len(); + let base_size = total_query_size / s; + let remainder = total_query_size % s; + let query_sizes: Vec<_> = (0..s) + .map(|i| { + if i < remainder { + base_size + 1 + } else { + base_size + } + }) + .map(|size| QuerySize::try_from(size).unwrap()) + .collect(); + BufferAndKeyRegistry { buffers, key_registry, + query_sizes, } } @@ -200,37 +234,59 @@ mod tests { // While this test currently checks for an unimplemented panic it is // designed to test for a correct result for a complete implementation. + const SHARDS: usize = 2; let records = build_records(); - let query_size = QuerySize::try_from(records.len()).unwrap(); let BufferAndKeyRegistry { buffers, key_registry, - } = build_buffers_from_records(&records); + query_sizes, + } = build_buffers_from_records(&records, SHARDS); - let world = TestWorld::default(); + let world: TestWorld> = + TestWorld::with_shards(TestWorldConfig::default()); let contexts = world.contexts(); + #[allow(clippy::large_futures)] - let results = join3v(buffers.into_iter().zip(contexts).map(|(buffer, ctx)| { - let query_params = HybridQueryParams { - per_user_credit_cap: 8, - max_breakdown_key: 3, - with_dp: 0, - epsilon: 5.0, - plaintext_match_keys: false, - }; - let input = BodyStream::from(buffer); - - HybridQuery::<_, BA16, KeyRegistry>::new( - query_params, - Arc::clone(&key_registry), - ) - .execute(ctx, query_size, input) - })) + let results = flatten3v(buffers.into_iter().zip(contexts).map( + |(helper_buffers, helper_ctxs)| { + helper_buffers + .into_iter() + .zip(helper_ctxs) + .zip(query_sizes.clone()) + .map(|((buffer, ctx), query_size)| { + let query_params = HybridQueryParams { + per_user_credit_cap: 8, + max_breakdown_key: 3, + with_dp: 0, + epsilon: 5.0, + plaintext_match_keys: false, + }; + let input = BodyStream::from(buffer); + + HybridQuery::<_, BA16, KeyRegistry>::new( + query_params, + Arc::clone(&key_registry), + ) + .execute(ctx, query_size, input) + }) + }, + )) .await; + let results: Vec<[Vec>; 3]> = results + .chunks(3) + .map(|chunk| { + [ + chunk[0].as_ref().unwrap().clone(), + chunk[1].as_ref().unwrap().clone(), + chunk[2].as_ref().unwrap().clone(), + ] + }) + .collect(); + assert_eq!( - results.reconstruct()[0..3] + results.into_iter().next().unwrap().reconstruct()[0..3] .iter() .map(U128Conversions::as_u128) .collect::>(), @@ -240,45 +296,71 @@ mod tests { // cannot test for Err directly because join3v calls unwrap. This should be sufficient. 
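The `query_sizes` computation in the test setup above splits the total number of records across shards so that the first `total % shards` shards take one extra record. As a standalone sketch:

// Round-robin split of a query across shards: earlier shards absorb the remainder.
fn shard_sizes(total: usize, shards: usize) -> Vec<usize> {
    let base = total / shards;
    let remainder = total % shards;
    (0..shards)
        .map(|i| if i < remainder { base + 1 } else { base })
        .collect()
}

// For example, shard_sizes(7, 3) returns [3, 2, 2].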
#[tokio::test] - #[should_panic(expected = "DuplicateBytes(3)")] + #[should_panic(expected = "DuplicateBytes")] async fn duplicate_encrypted_hybrid_reports() { - let all_records = build_records(); - let records = &all_records[..2].to_vec(); + const SHARDS: usize = 2; + let records = build_records(); let BufferAndKeyRegistry { mut buffers, key_registry, - } = build_buffers_from_records(records); + query_sizes, + } = build_buffers_from_records(&records, SHARDS); // this is double, since we duplicate the data below - let query_size = QuerySize::try_from(records.len() * 2).unwrap(); - - // duplicate all the data - for buffer in &mut buffers { - let original = buffer.clone(); - buffer.extend(original); + let query_sizes = query_sizes + .into_iter() + .map(|query_size| QuerySize::try_from(usize::from(query_size) * 2).unwrap()) + .collect::>(); + + // duplicate all the data across shards + + for helper_buffers in &mut buffers { + // Get the last shard buffer to use for the first shard buffer extension + let last_shard_buffer = helper_buffers.last().unwrap().clone(); + let len = helper_buffers.len(); + for i in 0..len { + if i > 0 { + let previous = &helper_buffers[i - 1].clone(); + helper_buffers[i].extend_from_slice(previous); + } else { + helper_buffers[i].extend_from_slice(&last_shard_buffer); + } + } } - let world = TestWorld::default(); + let world: TestWorld> = + TestWorld::with_shards(TestWorldConfig::default()); let contexts = world.contexts(); + #[allow(clippy::large_futures)] - let _results = join3v(buffers.into_iter().zip(contexts).map(|(buffer, ctx)| { - let query_params = HybridQueryParams { - per_user_credit_cap: 8, - max_breakdown_key: 3, - with_dp: 0, - epsilon: 5.0, - plaintext_match_keys: false, - }; - let input = BodyStream::from(buffer); - - HybridQuery::<_, BA16, KeyRegistry>::new( - query_params, - Arc::clone(&key_registry), - ) - .execute(ctx, query_size, input) - })) + let results = flatten3v(buffers.into_iter().zip(contexts).map( + |(helper_buffers, helper_ctxs)| { + helper_buffers + .into_iter() + .zip(helper_ctxs) + .zip(query_sizes.clone()) + .map(|((buffer, ctx), query_size)| { + let query_params = HybridQueryParams { + per_user_credit_cap: 8, + max_breakdown_key: 3, + with_dp: 0, + epsilon: 5.0, + plaintext_match_keys: false, + }; + let input = BodyStream::from(buffer); + + HybridQuery::<_, BA16, KeyRegistry>::new( + query_params, + Arc::clone(&key_registry), + ) + .execute(ctx, query_size, input) + }) + }, + )) .await; + + results.into_iter().map(|r| r.unwrap()).for_each(drop); } // cannot test for Err directly because join3v calls unwrap. This should be sufficient. 
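The duplicate test above only works because resharding routes equal tags to the same shard: `UniqueTag::shard_picker`, defined further down, reads the 16-byte tag as a little-endian u128 and reduces it modulo the shard count, so a duplicated ciphertext always lands where its twin is and the validator can catch it. The mapping, sketched standalone:

// Deterministic tag-to-shard mapping (shard_count must be nonzero).
fn pick_shard(tag: [u8; 16], shard_count: u32) -> u32 {
    let num = u128::from_le_bytes(tag);
    u32::try_from(num % u128::from(shard_count)).expect("modulo a u32 fits in u32")
}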
@@ -287,34 +369,46 @@ mod tests { expected = "Unsupported(\"Hybrid queries do not currently support plaintext match keys\")" )] async fn unsupported_plaintext_match_keys_hybrid_query() { - let all_records = build_records(); - let records = &all_records[..2].to_vec(); - let query_size = QuerySize::try_from(records.len()).unwrap(); + const SHARDS: usize = 2; + let records = build_records(); let BufferAndKeyRegistry { buffers, key_registry, - } = build_buffers_from_records(records); + query_sizes, + } = build_buffers_from_records(&records, SHARDS); - let world = TestWorld::default(); + let world: TestWorld> = + TestWorld::with_shards(TestWorldConfig::default()); let contexts = world.contexts(); + #[allow(clippy::large_futures)] - let _results = join3v(buffers.into_iter().zip(contexts).map(|(buffer, ctx)| { - let query_params = HybridQueryParams { - per_user_credit_cap: 8, - max_breakdown_key: 3, - with_dp: 0, - epsilon: 5.0, - plaintext_match_keys: true, - }; - let input = BodyStream::from(buffer); - - HybridQuery::<_, BA16, KeyRegistry>::new( - query_params, - Arc::clone(&key_registry), - ) - .execute(ctx, query_size, input) - })) + let results = flatten3v(buffers.into_iter().zip(contexts).map( + |(helper_buffers, helper_ctxs)| { + helper_buffers + .into_iter() + .zip(helper_ctxs) + .zip(query_sizes.clone()) + .map(|((buffer, ctx), query_size)| { + let query_params = HybridQueryParams { + per_user_credit_cap: 8, + max_breakdown_key: 3, + with_dp: 0, + epsilon: 5.0, + plaintext_match_keys: true, + }; + let input = BodyStream::from(buffer); + + HybridQuery::<_, BA16, KeyRegistry>::new( + query_params, + Arc::clone(&key_registry), + ) + .execute(ctx, query_size, input) + }) + }, + )) .await; + + results.into_iter().map(|r| r.unwrap()).for_each(drop); } } diff --git a/ipa-core/src/query/runner/mod.rs b/ipa-core/src/query/runner/mod.rs index 9e5935c20..9bd739db9 100644 --- a/ipa-core/src/query/runner/mod.rs +++ b/ipa-core/src/query/runner/mod.rs @@ -7,7 +7,6 @@ mod test_multiply; #[cfg(any(test, feature = "cli", feature = "test-fixture"))] pub(super) use add_in_prime_field::execute as test_add_in_prime_field; -pub use hybrid::Query as HybridQuery; #[cfg(any(test, feature = "cli", feature = "test-fixture"))] pub(super) use test_multiply::execute_test_multiply; diff --git a/ipa-core/src/report/hybrid.rs b/ipa-core/src/report/hybrid.rs index d322f92ae..593c6d0e8 100644 --- a/ipa-core/src/report/hybrid.rs +++ b/ipa-core/src/report/hybrid.rs @@ -1,7 +1,8 @@ -use std::{collections::HashSet, ops::Add}; +use std::{collections::HashSet, convert::Infallible, ops::Add}; +use assertions::const_assert; use bytes::Bytes; -use generic_array::ArrayLength; +use generic_array::{ArrayLength, GenericArray}; use rand_core::{CryptoRng, RngCore}; use typenum::{Sum, Unsigned, U16}; @@ -11,6 +12,7 @@ use crate::{ hpke::{EncapsulationSize, PrivateKeyRegistry, PublicKeyRegistry, TagSize}, report::{EncryptedOprfReport, EventType, InvalidReportError, KeyIdentifier}, secret_sharing::{replicated::semi_honest::AdditiveShare as Replicated, SharedValue}, + sharding::ShardIndex, }; #[derive(Clone, Debug, Eq, PartialEq)] @@ -150,37 +152,92 @@ impl TryFrom for EncryptedHybridReport { } } +const TAG_SIZE: usize = TagSize::USIZE; + +#[derive(Clone, Debug)] +pub struct UniqueTag { + bytes: [u8; TAG_SIZE], +} + pub trait UniqueBytes { - fn unique_bytes(&self) -> Vec; + fn unique_bytes(&self) -> [u8; TAG_SIZE]; +} + +impl UniqueBytes for UniqueTag { + fn unique_bytes(&self) -> [u8; TAG_SIZE] { + self.bytes + } } impl UniqueBytes for 
EncryptedHybridReport { /// We use the `TagSize` (the first 16 bytes of the ciphertext) for collision-detection /// See [analysis here for uniqueness](https://eprint.iacr.org/2019/624) - fn unique_bytes(&self) -> Vec { - self.mk_ciphertext()[0..TagSize::USIZE].to_vec() + fn unique_bytes(&self) -> [u8; TAG_SIZE] { + let slice = &self.mk_ciphertext()[0..TAG_SIZE]; + let mut array = [0u8; TAG_SIZE]; + array.copy_from_slice(slice); + array + } +} + +impl UniqueTag { + fn _compile_check() { + // This will fail at compile time if TAG_SIZE doesn't match 16 (and hence U16) + // the macro expansion needs to be wrapped in a function + const_assert!(TAG_SIZE == 16); + } + + // Function to attempt to create a UniqueTag from a UniqueBytes implementor + pub fn from_unique_bytes(item: &T) -> Self { + UniqueTag { + bytes: item.unique_bytes(), + } + } + + /// Maps the tag into a consistent shard. + /// + /// ## Panics + /// if `TAG_SIZE != 16` + /// note: `_compile_check` above verifies at compile time that `TAG_SIZE == 16` + #[must_use] + pub fn shard_picker(&self, shard_count: ShardIndex) -> ShardIndex { + let num = u128::from_le_bytes(self.bytes); + let shard_count = u128::from(shard_count); + ShardIndex::try_from(num % shard_count).expect("Modulo a u32 will fit in u32") + } +} + +impl Serializable for UniqueTag { + type Size = U16; // This must match TAG_SIZE + type DeserializationError = Infallible; + + fn serialize(&self, buf: &mut GenericArray) { + buf.copy_from_slice(&self.bytes); + } + fn deserialize(buf: &GenericArray) -> Result { + let mut bytes = [0u8; TAG_SIZE]; + bytes.copy_from_slice(buf.as_slice()); + Ok(UniqueTag { bytes }) } } #[derive(Debug)] -pub struct UniqueBytesValidator { - hash_set: HashSet>, +pub struct UniqueTagValidator { + hash_set: HashSet<[u8; TAG_SIZE]>, check_counter: usize, } -impl UniqueBytesValidator { +impl UniqueTagValidator { #[must_use] pub fn new(size: usize) -> Self { - UniqueBytesValidator { + UniqueTagValidator { hash_set: HashSet::with_capacity(size), check_counter: 0, } } - - fn insert(&mut self, value: Vec) -> bool { + fn insert(&mut self, value: [u8; TAG_SIZE]) -> bool { self.hash_set.insert(value) } - /// Checks that item is unique among all checked thus far /// /// ## Errors @@ -193,7 +250,6 @@ impl UniqueBytesValidator { Err(Error::DuplicateBytes(self.check_counter)) } } - /// Checks that an iter of items is unique among the iter and any other items checked thus far /// /// ## Errors @@ -213,7 +269,7 @@ mod test { use super::{ EncryptedHybridReport, HybridConversionReport, HybridImpressionReport, HybridReport, - UniqueBytes, UniqueBytesValidator, + UniqueTag, UniqueTagValidator, }; use crate::{ error::Error, @@ -239,11 +295,11 @@ mod test { } } - fn generate_random_bytes(size: usize) -> Vec { + fn generate_random_tag() -> UniqueTag { let mut rng = thread_rng(); - let mut bytes = vec![0u8; size]; + let mut bytes = [0u8; 16]; rng.fill(&mut bytes[..]); - bytes + UniqueTag { bytes } } #[test] @@ -305,40 +361,22 @@ mod test { #[test] fn unique_encrypted_hybrid_reports() { - #[derive(Clone)] - pub struct UniqueByteHolder { - bytes: Vec, - } - - impl UniqueByteHolder { - pub fn new(size: usize) -> Self { - let bytes = generate_random_bytes(size); - UniqueByteHolder { bytes } - } - } - - impl UniqueBytes for UniqueByteHolder { - fn unique_bytes(&self) -> Vec { - self.bytes.clone() - } - } - - let bytes1 = UniqueByteHolder::new(4); - let bytes2 = UniqueByteHolder::new(4); - let bytes3 = UniqueByteHolder::new(4); - let bytes4 = UniqueByteHolder::new(4); + let tag1 =
generate_random_tag(); + let tag2 = generate_random_tag(); + let tag3 = generate_random_tag(); + let tag4 = generate_random_tag(); - let mut unique_bytes = UniqueBytesValidator::new(4); + let mut unique_bytes = UniqueTagValidator::new(4); - unique_bytes.check_duplicate(&bytes1).unwrap(); + unique_bytes.check_duplicate(&tag1).unwrap(); unique_bytes - .check_duplicates(&[bytes2.clone(), bytes3.clone()]) + .check_duplicates(&[tag2.clone(), tag3.clone()]) .unwrap(); - let expected_err = unique_bytes.check_duplicate(&bytes2); + let expected_err = unique_bytes.check_duplicate(&tag2); assert!(matches!(expected_err, Err(Error::DuplicateBytes(4)))); - let expected_err = unique_bytes.check_duplicates(&[bytes4, bytes3]); + let expected_err = unique_bytes.check_duplicates(&[tag4, tag3]); assert!(matches!(expected_err, Err(Error::DuplicateBytes(6)))); } } diff --git a/ipa-core/src/report/ipa.rs b/ipa-core/src/report/ipa.rs index a9da93454..bceba37e3 100644 --- a/ipa-core/src/report/ipa.rs +++ b/ipa-core/src/report/ipa.rs @@ -827,9 +827,9 @@ mod test { fn check_compatibility_impressionmk_with_ios_encryption() { let enc_report_bytes1 = hex::decode( "12854879d86ef277cd70806a7f6bad269877adc95ee107380381caf15b841a7e995e41\ - 4c63a9d82f834796cdd6c40529189fca82720714d24200d8a916a1e090b123f27eaf24\ - f047f3930a77e5bcd33eeb823b73b0e9546c59d3d6e69383c74ae72b79645698fe1422\ - f83886bd3cbca9fbb63f7019e2139191dd000000007777772e6d6574612e636f6d", + 4c63a9d82f834796cdd6c40529189fca82720714d24200d8a916a1e090b123f27eaf24\ + f047f3930a77e5bcd33eeb823b73b0e9546c59d3d6e69383c74ae72b79645698fe1422\ + f83886bd3cbca9fbb63f7019e2139191dd000000007777772e6d6574612e636f6d", ) .unwrap(); let enc_report_bytes2 = hex::decode( diff --git a/ipa-core/src/sharding.rs b/ipa-core/src/sharding.rs index 195573f45..64ed9b6d3 100644 --- a/ipa-core/src/sharding.rs +++ b/ipa-core/src/sharding.rs @@ -7,12 +7,68 @@ use std::{ #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct ShardIndex(pub u32); +impl ShardIndex { + pub const FIRST: Self = Self(0); + + /// Returns an iterator over all shard indices that precede this one, excluding this one. + pub fn iter(self) -> impl Iterator { + (0..self.0).map(Self) + } +} + +impl From for ShardIndex { + fn from(value: u32) -> Self { + Self(value) + } +} + +impl From for u64 { + fn from(value: ShardIndex) -> Self { + u64::from(value.0) + } +} + +impl From for u128 { + fn from(value: ShardIndex) -> Self { + Self::from(value.0) + } +} + +#[cfg(target_pointer_width = "64")] +impl From for usize { + fn from(value: ShardIndex) -> Self { + usize::try_from(value.0).unwrap() + } +} + impl From for u32 { fn from(value: ShardIndex) -> Self { value.0 } } +impl TryFrom for ShardIndex { + type Error = TryFromIntError; + + fn try_from(value: usize) -> Result { + u32::try_from(value).map(Self) + } +} + +impl TryFrom for ShardIndex { + type Error = TryFromIntError; + + fn try_from(value: u128) -> Result { + u32::try_from(value).map(Self) + } +} + +impl Display for ShardIndex { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + Display::fmt(&self.0, f) + } +} + #[derive(Debug, Copy, Clone)] pub struct Sharded { pub shard_id: ShardIndex, @@ -29,12 +85,6 @@ impl ShardConfiguration for Sharded { } } -impl Display for ShardIndex { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - Display::fmt(&self.0, f) - } -} - /// Shard-specific configuration required by sharding API. Each shard must know its own index and /// the total number of shards in the system. 
pub trait ShardConfiguration { @@ -70,48 +120,6 @@ pub struct NotSharded; impl ShardBinding for NotSharded {} impl ShardBinding for Sharded {} -impl ShardIndex { - pub const FIRST: Self = Self(0); - - /// Returns an iterator over all shard indices that precede this one, excluding this one. - pub fn iter(self) -> impl Iterator { - (0..self.0).map(Self) - } -} - -impl From for ShardIndex { - fn from(value: u32) -> Self { - Self(value) - } -} - -impl From for u64 { - fn from(value: ShardIndex) -> Self { - u64::from(value.0) - } -} - -impl From for u128 { - fn from(value: ShardIndex) -> Self { - Self::from(value.0) - } -} - -#[cfg(target_pointer_width = "64")] -impl From for usize { - fn from(value: ShardIndex) -> Self { - usize::try_from(value.0).unwrap() - } -} - -impl TryFrom for ShardIndex { - type Error = TryFromIntError; - - fn try_from(value: usize) -> Result { - u32::try_from(value).map(Self) - } -} - #[cfg(all(test, unit_test))] mod tests { use std::iter::empty; diff --git a/ipa-core/src/test_fixture/mod.rs b/ipa-core/src/test_fixture/mod.rs index 9c6e7995f..38d12eb21 100644 --- a/ipa-core/src/test_fixture/mod.rs +++ b/ipa-core/src/test_fixture/mod.rs @@ -147,6 +147,37 @@ where join3(fut0, fut1, fut2) } +/// Wrapper for flattening 3 vecs of vecs into a single future +/// # Panics +/// If the tasks return `Err` or if `a` is the wrong length. +pub fn flatten3v(a: V) -> impl Future::Output>> +where + V: IntoIterator, + I: IntoIterator, + T: TryFuture, + T::Output: Debug, + T::Ok: Debug, + T::Error: Debug, +{ + let mut it = a.into_iter(); + + let outer0 = it.next().unwrap().into_iter(); + let outer1 = it.next().unwrap().into_iter(); + let outer2 = it.next().unwrap().into_iter(); + + assert!(it.next().is_none()); + + // only used for tests + #[allow(clippy::disallowed_methods)] + futures::future::join_all( + outer0 + .zip(outer1) + .zip(outer2) + .flat_map(|((fut0, fut1), fut2)| vec![fut0, fut1, fut2]) + .collect::>(), + ) +} + /// Take a slice of bits in `{0,1} ⊆ F_p`, and reconstruct the integer in `Z` pub fn bits_to_value(x: &[F]) -> u128 { #[allow(clippy::cast_possible_truncation)] diff --git a/ipa-core/src/test_fixture/world.rs b/ipa-core/src/test_fixture/world.rs index 1290e9f98..5c23cb2d8 100644 --- a/ipa-core/src/test_fixture/world.rs +++ b/ipa-core/src/test_fixture/world.rs @@ -207,6 +207,42 @@ impl TestWorld> { .ok() .unwrap() } + + /// Creates protocol contexts for 3 helpers across all shards + /// + /// # Panics + /// Panics if world has more or less than 3 gateways/participants + #[must_use] + pub fn contexts(&self) -> [Vec>; 3] { + let gate = &self.next_gate(); + self.shards().iter().map(|shard| shard.contexts(gate)).fold( + [Vec::new(), Vec::new(), Vec::new()], + |mut acc, contexts| { + // Distribute contexts into the respective vectors. + for (vec, context) in acc.iter_mut().zip(contexts.iter()) { + vec.push(context.clone()); + } + acc + }, + ) + } + /// Creates malicious protocol contexts for 3 helpers across all shards + /// + /// # Panics + /// Panics if world has more or less than 3 gateways/participants + #[must_use] + pub fn malicious_contexts(&self) -> [Vec>; 3] { + self.shards() + .iter() + .map(|shard| shard.malicious_contexts(&self.next_gate())) + .fold([Vec::new(), Vec::new(), Vec::new()], |mut acc, contexts| { + // Distribute contexts into the respective vectors. + for (vec, context) in acc.iter_mut().zip(contexts.iter()) { + vec.push(context.clone()); + } + acc + }) + } } /// Backward-compatible API for tests that don't use sharding. 
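The `flatten3v` helper added above zips the three per-helper collections of futures and flattens them shard by shard, so results come back in (helper 1, helper 2, helper 3) order for each shard, which is what the `chunks(3)` reassembly in the hybrid tests depends on. The interleaving itself, sketched standalone on plain values:

fn interleave3<T>(a: Vec<T>, b: Vec<T>, c: Vec<T>) -> Vec<T> {
    a.into_iter()
        .zip(b)
        .zip(c)
        .flat_map(|((x, y), z)| [x, y, z])
        .collect()
}

// interleave3(vec![1, 2], vec![10, 20], vec![100, 200]) yields [1, 10, 100, 2, 20, 200].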
From 2f431f81cfea1764f26b65a6a14504576e41eb4f Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Fri, 18 Oct 2024 15:58:15 -0700 Subject: [PATCH 158/191] Improve coverage a bit --- ipa-metrics/src/partitioned.rs | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/ipa-metrics/src/partitioned.rs b/ipa-metrics/src/partitioned.rs index f723759c3..0f71d0e28 100644 --- a/ipa-metrics/src/partitioned.rs +++ b/ipa-metrics/src/partitioned.rs @@ -117,14 +117,7 @@ impl PartitionedStore { &'a mut self, key: B, ) -> CounterHandle<'a, LABELS> { - if let Some(partition) = CurrentThreadContext::get() { - self.inner - .entry(partition) - .or_insert_with(Store::default) - .counter(key) - } else { - self.default_store.counter(key) - } + self.get_mut(CurrentThreadContext::get()).counter(key) } #[must_use] From 38682b530756033319f565bca90971f45a1de13c Mon Sep 17 00:00:00 2001 From: Thomas James Yurek Date: Fri, 18 Oct 2024 16:51:01 -0700 Subject: [PATCH 159/191] fix merge conflict --- ipa-core/src/report/hybrid.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/ipa-core/src/report/hybrid.rs b/ipa-core/src/report/hybrid.rs index f704d109a..a0d565416 100644 --- a/ipa-core/src/report/hybrid.rs +++ b/ipa-core/src/report/hybrid.rs @@ -495,7 +495,7 @@ impl UniqueTagValidator { /// if the item inserted is not unique among all checked thus far pub fn check_duplicate(&mut self, item: &U) -> Result<(), Error> { self.check_counter += 1; - if self.insert(item.unique_bytes().to_vec()) { + if self.insert(item.unique_bytes()) { Ok(()) } else { Err(Error::DuplicateBytes(self.check_counter)) @@ -521,9 +521,8 @@ mod test { use super::{ EncryptedHybridImpressionReport, EncryptedHybridReport, GenericArray, - HybridConversionReport, HybridImpressionReport, HybridReport, UniqueTag, UniqueTagValidator, - EncryptedHybridReport, HybridConversionReport, HybridImpressionReport, HybridReport, - UniqueTag, UniqueTagValidator, + HybridConversionReport, HybridImpressionReport, HybridReport, UniqueTag, + UniqueTagValidator, }; use crate::{ error::Error, From 99dbe22a56bce06783d1ba9586d525dc89e2b957 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Sun, 20 Oct 2024 14:16:04 -0700 Subject: [PATCH 160/191] Add reshard_try_stream function This adds support for fallible streams in resharding, so it can operate on `Result` rather than just `T`. Any error that occurs in the input stream is propagated and resharding terminates after that. --- ipa-core/src/protocol/context/mod.rs | 322 +++++++++++++++++++++------ 1 file changed, 250 insertions(+), 72 deletions(-) diff --git a/ipa-core/src/protocol/context/mod.rs b/ipa-core/src/protocol/context/mod.rs index 5b1de3505..bafc17be9 100644 --- a/ipa-core/src/protocol/context/mod.rs +++ b/ipa-core/src/protocol/context/mod.rs @@ -11,12 +11,12 @@ pub mod upgrade; mod batcher; pub mod validator; -use std::{collections::HashMap, iter, num::NonZeroUsize, pin::pin}; +use std::{collections::HashMap, num::NonZeroUsize, pin::pin}; use async_trait::async_trait; pub use dzkp_malicious::DZKPUpgraded as DZKPUpgradedMaliciousContext; pub use dzkp_semi_honest::DZKPUpgraded as DZKPUpgradedSemiHonestContext; -use futures::{stream, Stream, StreamExt}; +use futures::{stream, Stream, StreamExt, TryStreamExt}; use ipa_step::{Step, StepNarrow}; pub use malicious::MaliciousProtocolSteps; use prss::{InstrumentedIndexedSharedRandomness, InstrumentedSequentialSharedRandomness}; @@ -365,6 +365,12 @@ impl<'a> Inner<'a> { /// N per shard). 
Each channel stays open until the very last row is processed, then they are explicitly /// closed, even if nothing has been communicated between that pair. /// +/// ## Stream size +/// [`reshard_try_stream`] takes a regular stream, but will panic at runtime if the stream's +/// upper bound size is not known. Opting for a runtime check instead of a compile-time size +/// bound is necessary for it to work with query inputs, where the submitter stream is truncated +/// to take at most `sz` elements. This means the stream may have fewer than `sz` elements, and +/// resharding must still work. +/// /// ## Shard picking considerations /// It is expected for `shard_picker` to select shards uniformly, by either using [`prss`] or sampling /// random values with enough entropy. Failure to do so may lead to extra memory overhead - this @@ -374,45 +380,29 @@ impl<'a> Inner<'a> { /// /// [`calculations`]: https://docs.google.com/document/d/1vej6tYgNV3GWcldD4tl7a4Z9EeZwda3F5u7roPGArlU/ /// -/// ## Stream size -/// Note that it currently works for streams where size is known in advance. Mainly because -/// we want to set up send buffer sizes and avoid sending records one-by-one to each shard. -/// Other than that, there are no technical limitation here, and it could be possible to make it -/// work with regular streams if the batching problem is somehow addressed. -/// -/// -/// ```compile_fail -/// use futures::stream::{self, StreamExt}; -/// use ipa_core::protocol::context::reshard_stream; -/// use ipa_core::ff::boolean::Boolean; -/// use ipa_core::secret_sharing::SharedValue; -/// async { -/// let a = [Boolean::ZERO]; -/// let mut s = stream::iter(a.into_iter()).cycle(); -/// // this should fail to compile: -/// // the trait bound `futures::stream::Cycle<...>: ExactSizeStream` is not satisfied -/// reshard_stream(todo!(), s, todo!()).await; -/// }; -/// ``` /// /// ## Panics -/// When `shard_picker` returns an out-of-bounds index. +/// When `shard_picker` returns an out-of-bounds index or if the input stream size +/// upper bound is not known. The latter may be the case for infinite streams. /// /// ## Errors -/// If cross-shard communication fails +/// If cross-shard communication fails or if an input stream +/// yields an `Err` element. /// -pub async fn reshard_stream( +pub async fn reshard_try_stream( ctx: C, input: L, shard_picker: S, ) -> Result, crate::error::Error> where - L: ExactSizeStream, + L: Stream>, S: Fn(C, RecordId, &K) -> ShardIndex, K: Message + Clone, C: ShardedContext, { - let input_len = input.len(); + let (_, Some(input_len)) = input.size_hint() else { + panic!("input stream must have size upper bound for resharding to work") + }; // We set channels capacity to be at least 1 to be able to open send channels to all peers. // It is prohibited to create them if total records is not set. We also over-provision here @@ -438,15 +428,17 @@ where // Request data from all shards. let rcv_stream = ctx .recv_from_shards::() - .map(|(shard_id, v)| { - ( - shard_id, - v.map(Option::Some).map_err(crate::error::Error::from), - ) + .map(|(shard_id, v)| match v { + Ok(v) => Ok((shard_id, Some(v))), + Err(e) => Err(e), }) .fuse(); let input = pin!(input); + // Annoying consequence of not having async closures stable. async blocks + // cannot capture `Copy` values and there is no way to express that + // only some things need to be moved in Rust + let mut counter = 0_u32; // This produces a stream of outcomes of send requests.
// In order to make it compatible with receive stream, it also returns records that must @@ -456,38 +448,36 @@ where // whole resharding process. // If send was successful, we set the argument to Ok(None). Only records assigned to this shard // by the `shard_picker` will have the value of Ok(Some(Value)) - let send_stream = futures::stream::unfold( + let send_stream = futures::stream::try_unfold( // it is crucial that the following execution is completed sequentially, in order for record id // tracking per shard to work correctly. If tasks complete out of order, this will cause share // misplacement on the recipient side. - ( - input - .enumerate() - .zip(stream::iter(iter::repeat(ctx.clone()))), - &mut send_channels, - ), - |(mut input, send_channels)| async { - // Process more data as it comes in, or close the sending channels, if there is nothing - // left. - if let Some(((i, val), ctx)) = input.next().await { - let dest_shard = shard_picker(ctx, RecordId::from(i), &val); - if dest_shard == my_shard { - Some(((my_shard, Ok(Some(val.clone()))), (input, send_channels))) + (input, &mut send_channels, &mut counter), + |(mut input, send_channels, i)| { + let ctx = ctx.clone(); + + async { + // Process more data as it comes in, or close the sending channels, if there is nothing + // left. + if let Some(val) = input.try_next().await? { + let dest_shard = shard_picker(ctx, RecordId::from(*i), &val); + *i += 1; + if dest_shard == my_shard { + Ok(Some(((my_shard, Some(val)), (input, send_channels, i)))) + } else { + let (record_id, se) = send_channels.get_mut(&dest_shard).unwrap(); + se.send(*record_id, val) + .await + .map_err(crate::error::Error::from)?; + *record_id += 1; + Ok(Some(((my_shard, None), (input, send_channels, i)))) + } } else { - let (record_id, se) = send_channels.get_mut(&dest_shard).unwrap(); - let send_result = se - .send(*record_id, val) - .await - .map_err(crate::error::Error::from) - .map(|()| None); - *record_id += 1; - Some(((my_shard, send_result), (input, send_channels))) - } - } else { - for (last_record, send_channel) in send_channels.values() { - send_channel.close(*last_record).await; + for (last_record, send_channel) in send_channels.values() { + send_channel.close(*last_record).await; + } + Ok(None) } - None } }, ) @@ -519,8 +509,8 @@ where // This approach makes sure we do what we can - send or receive. let mut send_recv = pin!(futures::stream::select(send_stream, rcv_stream)); - while let Some((shard_id, v)) = send_recv.next().await { - if let Some(m) = v? { + while let Some((shard_id, v)) = send_recv.try_next().await? { + if let Some(m) = v { r[usize::from(shard_id)].push(m); } } @@ -528,12 +518,56 @@ where Ok(r.into_iter().flatten().collect()) } +/// Provides the same functionality as [`reshard_try_stream`] on +/// infallible streams +/// +/// ## Stream size +/// Note that it currently works for streams where size is known in advance. Mainly because +/// we want to set up send buffer sizes and avoid sending records one-by-one to each shard. +/// Other than that, there are no technical limitation here, and it could be possible to make it +/// work with regular streams or opt-out to runtime checks as [`reshard_try_stream`] does. 
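That runtime opt-out is the `size_hint` destructure at the top of `reshard_try_stream` above: only the advertised upper bound is required, and the stream may legitimately yield fewer items. The check in isolation, assuming only the `futures::Stream` trait:

use futures::Stream;

// Extract the advertised upper bound, panicking for unbounded (e.g. infinite) streams.
fn upper_bound<S: Stream>(s: &S) -> usize {
    let (_, Some(n)) = s.size_hint() else {
        panic!("input stream must have size upper bound for resharding to work")
    };
    n
}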
+/// +/// +/// ```compile_fail +/// use futures::stream::{self, StreamExt}; +/// use ipa_core::protocol::context::reshard_stream; +/// use ipa_core::ff::boolean::Boolean; +/// use ipa_core::secret_sharing::SharedValue; +/// async { +/// let a = [Boolean::ZERO]; +/// let mut s = stream::iter(a.into_iter()).cycle(); +/// // this should fail to compile: +/// // the trait bound `futures::stream::Cycle<...>: ExactSizeStream` is not satisfied +/// reshard_stream(todo!(), s, todo!()).await; +/// }; +/// ``` +/// ## Panics +/// When `shard_picker` returns an out-of-bounds index. +/// +/// ## Errors +/// If cross-shard communication fails +pub async fn reshard_stream( + ctx: C, + input: L, + shard_picker: S, +) -> Result, crate::error::Error> +where + L: ExactSizeStream, + S: Fn(C, RecordId, &K) -> ShardIndex, + K: Message + Clone, + C: ShardedContext, +{ + reshard_try_stream(ctx, input.map(Ok), shard_picker).await +} + /// Same as [`reshard_stream`] but takes an iterator with the known size /// as input. /// -/// # Errors +/// ## Panics +/// When `shard_picker` returns an out-of-bounds index. /// -/// # Panics +/// ## Errors +/// If cross-shard communication fails pub async fn reshard_iter( ctx: C, input: L, @@ -567,12 +601,13 @@ pub trait DZKPContext: Context { async fn validate_record(&self, record_id: RecordId) -> Result<(), Error>; } -#[cfg(all(test, unit_test))] +#[cfg(test)] mod tests { - use std::{iter, iter::repeat}; + use std::{iter, iter::repeat, pin::Pin, task::Poll}; - use futures::{future::join_all, stream, stream::StreamExt, try_join}; + use futures::{future::join_all, ready, stream, stream::StreamExt, try_join, Stream}; use ipa_step::StepNarrow; + use pin_project::pin_project; use rand::{ distributions::{Distribution, Standard}, Rng, @@ -588,16 +623,20 @@ mod tests { protocol::{ basics::ShareKnownValue, context::{ - reshard_iter, reshard_stream, step::MaliciousProtocolStep::MaliciousProtocol, - upgrade::Upgradable, Context, ShardedContext, UpgradableContext, Validator, + reshard_iter, reshard_stream, reshard_try_stream, + step::MaliciousProtocolStep::MaliciousProtocol, upgrade::Upgradable, Context, + ShardedContext, UpgradableContext, Validator, }, prss::SharedRandomness, RecordId, }, - secret_sharing::replicated::{ - malicious::{AdditiveShare as MaliciousReplicated, ExtendableField}, - semi_honest::AdditiveShare as Replicated, - ReplicatedSecretSharing, + secret_sharing::{ + replicated::{ + malicious::{AdditiveShare as MaliciousReplicated, ExtendableField}, + semi_honest::AdditiveShare as Replicated, + ReplicatedSecretSharing, + }, + SharedValue, }, sharding::{ShardConfiguration, ShardIndex}, telemetry::metrics::{ @@ -917,6 +956,145 @@ mod tests { }); } + #[test] + fn reshard_try_stream_basic() { + run(|| async move { + const SHARDS: u32 = 5; + let input: Vec<_> = (0..SHARDS).map(BA8::truncate_from).collect(); + let world: TestWorld> = + TestWorld::with_shards(TestWorldConfig::default()); + let r = world + .semi_honest(input.clone().into_iter(), |ctx, shard_input| async move { + reshard_try_stream(ctx, stream::iter(shard_input).map(Ok), |_, record_id, _| { + ShardIndex::from(u32::from(record_id) % SHARDS) + }) + .await + .unwrap() + }) + .await + .into_iter() + .flat_map(|v| v.reconstruct()) + .collect::>(); + + assert_eq!(input, r); + }); + } + + #[test] + fn reshard_try_stream_less_items_than_expected() { + /// This allows advertising higher upper bound limit + /// that actual number of elements in the stream. 
+ /// reshard should be able to tolerate that + #[pin_project] + struct Wrapper { + #[pin] + inner: S, + expected_len: usize, + } + + impl Wrapper { + fn new(inner: S, expected_len: usize) -> Self { + assert!(expected_len > 0); + Self { + inner, + expected_len, + } + } + } + + impl Stream for Wrapper { + type Item = S::Item; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let this = self.project(); + let r = match ready!(this.inner.poll_next(cx)) { + Some(val) => { + *this.expected_len -= 1; + Poll::Ready(Some(val)) + } + None => Poll::Ready(None), + }; + + assert!( + *this.expected_len > 0, + "Stream should have less elements than expected" + ); + r + } + + fn size_hint(&self) -> (usize, Option) { + (0, Some(self.expected_len)) + } + } + + run(|| async move { + const SHARDS: u32 = 5; + let world: TestWorld> = + TestWorld::with_shards(TestWorldConfig::default()); + let input: Vec<_> = (0..5 * SHARDS).map(BA8::truncate_from).collect(); + let r = world + .semi_honest(input.clone().into_iter(), |ctx, shard_input| async move { + reshard_try_stream( + ctx, + Wrapper::new(stream::iter(shard_input).map(Ok), 25), + |_, record_id, _| ShardIndex::from(u32::from(record_id) % SHARDS), + ) + .await + .unwrap() + }) + .await + .into_iter() + .flat_map(|v| v.reconstruct()) + .collect::>(); + + assert_eq!(input, r); + }); + } + + #[test] + #[should_panic(expected = "input stream must have size upper bound for resharding to work")] + fn reshard_try_stream_infinite() { + run(|| async move { + let world: TestWorld> = + TestWorld::with_shards(TestWorldConfig::default()); + world + .semi_honest(Vec::::new().into_iter(), |ctx, _| async move { + reshard_try_stream(ctx, stream::repeat(BA8::ZERO).map(Ok), |_, _, _| { + ShardIndex::FIRST + }) + .await + .unwrap() + }) + .await; + }); + } + + #[test] + fn reshard_try_stream_err() { + run(|| async move { + let world: TestWorld> = + TestWorld::with_shards(TestWorldConfig::default()); + world + .semi_honest(Vec::::new().into_iter(), |ctx, _| async move { + let err = reshard_try_stream( + ctx, + stream::iter(vec![ + Ok(BA8::ZERO), + Err(crate::error::Error::InconsistentShares), + ]), + |_, _, _| ShardIndex::FIRST, + ) + .await + .unwrap_err(); + assert!(matches!(err, crate::error::Error::InconsistentShares)); + }) + .await; + }); + } + #[test] fn prss_one_side() { run(|| async { From 659561d4434c3ff7982befe652cdbc03e6c53054 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Sun, 20 Oct 2024 14:18:54 -0700 Subject: [PATCH 161/191] Streaming tag resharding and match key collection in Hybrid In #1358 it was mentioned that waiting until all AAD tags and match keys have been collected before starting resharding process adds latency that is unnecessary. We can start resharding process right when we received the first tag and do everything in parallel. 
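The underlying pattern is small enough to sketch on its own: buffer one half of
each pair on the side while the other half keeps flowing downstream, so the
downstream consumer can start before the whole input has been read. A minimal
illustration with placeholder types (plain integers standing in for decrypted
reports and AAD tags; the real adapter added below is lazy and error-aware):

```rust
use futures::{executor::block_on, stream, StreamExt};

fn main() {
    // Placeholder input: (report, tag) pairs, produced one at a time.
    let pairs = stream::iter(vec![(1u64, 10u8), (2, 20), (3, 30)]);

    let mut reports = Vec::new();
    // Forward tags downstream while buffering reports on the side. Here the
    // downstream consumer is just `collect`; in the query runner it is the
    // resharding machinery.
    let tags: Vec<u8> = block_on(
        pairs
            .map(|(report, tag)| {
                reports.push(report);
                tag
            })
            .collect(),
    );

    assert_eq!(reports, vec![1, 2, 3]);
    assert_eq!(tags, vec![10, 20, 30]);
}
```
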
This change does that by leveraging newly added `reshard_try_stream` and a few helper structs and functions to make it ergonomic to use --- ipa-core/src/query/runner/hybrid.rs | 54 +++----- ipa-core/src/query/runner/mod.rs | 1 + ipa-core/src/query/runner/reshard_tag.rs | 149 +++++++++++++++++++++++ 3 files changed, 170 insertions(+), 34 deletions(-) create mode 100644 ipa-core/src/query/runner/reshard_tag.rs diff --git a/ipa-core/src/query/runner/hybrid.rs b/ipa-core/src/query/runner/hybrid.rs index cc3b861c7..7154c066b 100644 --- a/ipa-core/src/query/runner/hybrid.rs +++ b/ipa-core/src/query/runner/hybrid.rs @@ -10,12 +10,9 @@ use crate::{ BodyStream, LengthDelimitedStream, }, hpke::PrivateKeyRegistry, - protocol::{ - context::{reshard_iter, ShardedContext}, - hybrid::step::HybridStep, - step::ProtocolStep::Hybrid, - }, - report::hybrid::{EncryptedHybridReport, HybridReport, UniqueTag, UniqueTagValidator}, + protocol::{context::ShardedContext, hybrid::step::HybridStep, step::ProtocolStep::Hybrid}, + query::runner::reshard_tag::reshard_aad, + report::hybrid::{EncryptedHybridReport, UniqueTag, UniqueTagValidator}, secret_sharing::{replicated::semi_honest::AdditiveShare as ReplicatedShare, SharedValue}, }; @@ -61,35 +58,24 @@ where )); } - let (_decrypted_reports, tags): (Vec>, Vec) = - LengthDelimitedStream::::new(input_stream) - .map_err(Into::::into) - .map_ok(|enc_reports| { - iter(enc_reports.into_iter().map({ - |enc_report| { - let dec_report = enc_report - .decrypt::(key_registry.as_ref()) - .map_err(Into::::into); - let unique_tag = UniqueTag::from_unique_bytes(&enc_report); - dec_report.map(|dec_report1| (dec_report1, unique_tag)) - } - })) - }) - .try_flatten() - .take(sz) - .try_fold( - (Vec::with_capacity(sz), Vec::with_capacity(sz)), - |mut acc, result| async move { - acc.0.push(result.0); - acc.1.push(result.1); - Ok(acc) - }, - ) - .await?; - - let resharded_tags = reshard_iter( + let stream = LengthDelimitedStream::::new(input_stream) + .map_err(Into::::into) + .map_ok(|enc_reports| { + iter(enc_reports.into_iter().map({ + |enc_report| { + let dec_report = enc_report + .decrypt::(key_registry.as_ref()) + .map_err(Into::::into); + let unique_tag = UniqueTag::from_unique_bytes(&enc_report); + dec_report.map(|dec_report1| (dec_report1, unique_tag)) + } + })) + }) + .try_flatten() + .take(sz); + let (_decrypted_reports, resharded_tags) = reshard_aad( ctx.narrow(&HybridStep::ReshardByTag), - tags, + stream, |ctx, _, tag| tag.shard_picker(ctx.shard_count()), ) .await?; diff --git a/ipa-core/src/query/runner/mod.rs b/ipa-core/src/query/runner/mod.rs index 9bd739db9..3f1b59f55 100644 --- a/ipa-core/src/query/runner/mod.rs +++ b/ipa-core/src/query/runner/mod.rs @@ -2,6 +2,7 @@ mod add_in_prime_field; mod hybrid; mod oprf_ipa; +mod reshard_tag; #[cfg(any(test, feature = "cli", feature = "test-fixture"))] mod test_multiply; diff --git a/ipa-core/src/query/runner/reshard_tag.rs b/ipa-core/src/query/runner/reshard_tag.rs new file mode 100644 index 000000000..5ef7b6311 --- /dev/null +++ b/ipa-core/src/query/runner/reshard_tag.rs @@ -0,0 +1,149 @@ +use std::{ + pin::{pin, Pin}, + task::{Context, Poll}, +}; + +use futures::{ready, Stream}; +use pin_project::pin_project; + +use crate::{ + error::Error, + helpers::Message, + protocol::{ + context::{reshard_try_stream, ShardedContext}, + RecordId, + }, + sharding::ShardIndex, +}; + +type DataWithTag = Result<(D, A), Error>; + +/// Helper function to work with inputs to hybrid queries. 
Each encryption needs +/// to be checked for uniqueness and we use AAD tag for that. While match keys are +/// being collected, AAD tags need to be resharded. This function does both at the same +/// time which should reduce the perceived latency of queries. +/// +/// The output contains two separate collections: one for data and another one +/// for AAD tags that are "owned" by this shard. The tags can later be checked for +/// uniqueness. +/// +/// ## Errors +/// This will return an error, if input stream contains at least one `Err` element. +#[allow(dead_code)] +pub async fn reshard_aad( + ctx: C, + input: L, + shard_picker: S, +) -> Result<(Vec, Vec), crate::error::Error> +where + L: Stream>, + S: Fn(C, RecordId, &A) -> ShardIndex + Send, + A: Message + Clone, + C: ShardedContext, +{ + let mut k_buf = Vec::with_capacity(input.size_hint().1.unwrap_or(0)); + let splitter = StreamSplitter { + inner: input, + buf: &mut k_buf, + }; + let a_buf = reshard_try_stream(ctx, splitter, shard_picker).await?; + + Ok((k_buf, a_buf)) +} + +/// Takes a fallible input stream that yields a tuple `(K, A)` and produces a new stream +/// over `A` while collecting `K` elements into the provided buffer. +/// Any error encountered from the input stream is propagated. +#[pin_project] +struct StreamSplitter<'a, S: Stream>, K, A> { + #[pin] + inner: S, + buf: &'a mut Vec, +} + +impl>, K, A> Stream for StreamSplitter<'_, S, K, A> { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + match ready!(this.inner.poll_next(cx)) { + Some(Ok((k, a))) => { + this.buf.push(k); + Poll::Ready(Some(Ok(a))) + } + Some(Err(e)) => Poll::Ready(Some(Err(e))), + None => Poll::Ready(None), + } + } + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } +} + +#[cfg(test)] +mod tests { + use futures::{stream, StreamExt}; + + use crate::{ + error::Error, + ff::{boolean_array::BA8, U128Conversions}, + query::runner::reshard_tag::reshard_aad, + secret_sharing::SharedValue, + sharding::{ShardConfiguration, ShardIndex}, + test_executor::run, + test_fixture::{Runner, TestWorld, TestWorldConfig, WithShards}, + }; + + #[test] + fn reshard_basic() { + run(|| async { + let world: TestWorld> = + TestWorld::with_shards(TestWorldConfig::default()); + world + .semi_honest( + vec![BA8::truncate_from(1u128), BA8::truncate_from(2u128)].into_iter(), + |ctx, input| async move { + let shard_id = ctx.shard_id(); + let sz = input.len(); + let (values, tags) = reshard_aad( + ctx, + stream::iter(input).map(|v| Ok((v, BA8::ZERO))), + |_, _, _| ShardIndex::FIRST, + ) + .await + .unwrap(); + assert_eq!(sz, values.len()); + match shard_id { + ShardIndex::FIRST => assert_eq!(2, tags.len()), + _ => assert_eq!(0, tags.len()), + } + }, + ) + .await; + }); + } + + #[test] + #[should_panic(expected = "InconsistentShares")] + fn reshard_err() { + run(|| async { + let world: TestWorld> = + TestWorld::with_shards(TestWorldConfig::default()); + world + .semi_honest( + vec![BA8::truncate_from(1u128), BA8::truncate_from(2u128)].into_iter(), + |ctx, input| async move { + reshard_aad( + ctx, + stream::iter(input) + .map(|_| Err::<(BA8, BA8), _>(Error::InconsistentShares)), + |_, _, _| ShardIndex::FIRST, + ) + .await + .unwrap(); + }, + ) + .await; + }); + } +} From d48a18bcc2dccd4d81fdd139ef61ea42d5996cbd Mon Sep 17 00:00:00 2001 From: Thomas James Yurek Date: Mon, 21 Oct 2024 01:47:09 -0700 Subject: [PATCH 162/191] manual seralize test --- ipa-core/src/report/hybrid.rs | 26 
+++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/ipa-core/src/report/hybrid.rs b/ipa-core/src/report/hybrid.rs index a0d565416..7afbee256 100644 --- a/ipa-core/src/report/hybrid.rs +++ b/ipa-core/src/report/hybrid.rs @@ -532,7 +532,9 @@ mod test { }, hpke::{KeyPair, KeyRegistry}, report::{ - hybrid::NonAsciiStringError, hybrid_info::HybridImpressionInfo, EventType, OprfReport, + hybrid::{NonAsciiStringError, BA64}, + hybrid_info::HybridImpressionInfo, + EventType, OprfReport, }, secret_sharing::replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing}, }; @@ -658,6 +660,28 @@ mod test { ) .unwrap(); assert_eq!(hybrid_impression_report, hybrid_impression_report2); + + let hybrid_report3 = HybridImpressionReport::::deserialize(GenericArray::from_slice( + &hex::decode("4123a6e38ef1d6d9785c948797cb744d38f4").unwrap(), + )) + .unwrap(); + + let match_key = AdditiveShare::::deserialize(GenericArray::from_slice( + &hex::decode("4123a6e38ef1d6d9785c948797cb744d").unwrap(), + )) + .unwrap(); + let breakdown_key = AdditiveShare::::deserialize(GenericArray::from_slice( + &hex::decode("38f4").unwrap(), + )) + .unwrap(); + + assert_eq!( + hybrid_report3, + HybridImpressionReport:: { + match_key, + breakdown_key + } + ); } #[test] From c38d472857ba8f4ae0e25c9f9111a277b8d97ecb Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Mon, 21 Oct 2024 11:02:58 -0700 Subject: [PATCH 163/191] Fix another clippy error --- ipa-metrics/src/context.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ipa-metrics/src/context.rs b/ipa-metrics/src/context.rs index 1c1ff3d7f..2020f14fd 100644 --- a/ipa-metrics/src/context.rs +++ b/ipa-metrics/src/context.rs @@ -130,7 +130,7 @@ mod tests { fn local_store() { use crate::{context::CurrentThreadContext, CurrentThreadPartitionContext}; - CurrentThreadPartitionContext::set(0xdeadbeef); + CurrentThreadPartitionContext::set(0xdead_beef); counter!("foo", 7); std::thread::spawn(|| { From 18c0960ad695f06e5aa43b9d947c1ecf94f9618e Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Mon, 21 Oct 2024 11:07:11 -0700 Subject: [PATCH 164/191] Add metric-tracing crate This will be used to support metric partitioning in unit tests --- Cargo.toml | 2 +- ipa-metrics-tracing/Cargo.toml | 10 +++ ipa-metrics-tracing/src/layer.rs | 123 +++++++++++++++++++++++++++++++ ipa-metrics-tracing/src/lib.rs | 7 ++ 4 files changed, 141 insertions(+), 1 deletion(-) create mode 100644 ipa-metrics-tracing/Cargo.toml create mode 100644 ipa-metrics-tracing/src/layer.rs create mode 100644 ipa-metrics-tracing/src/lib.rs diff --git a/Cargo.toml b/Cargo.toml index 1aed2b4b7..deb437919 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [workspace] resolver = "2" -members = ["ipa-core", "ipa-step", "ipa-step-derive", "ipa-step-test", "ipa-metrics"] +members = ["ipa-core", "ipa-step", "ipa-step-derive", "ipa-step-test", "ipa-metrics", "ipa-metrics-tracing"] [profile.release] incremental = true diff --git a/ipa-metrics-tracing/Cargo.toml b/ipa-metrics-tracing/Cargo.toml new file mode 100644 index 000000000..ac7314c19 --- /dev/null +++ b/ipa-metrics-tracing/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "ipa-metrics-tracing" +version = "0.1.0" +edition = "2021" + +[dependencies] +# requires partitions feature because without it, it does not make sense to use +ipa-metrics = { version = "*", path = "../ipa-metrics", features = ["partitions"] } +tracing = "0.1" +tracing-subscriber = "0.3" \ No newline at end of file diff --git 
a/ipa-metrics-tracing/src/layer.rs b/ipa-metrics-tracing/src/layer.rs new file mode 100644 index 000000000..85d07d910 --- /dev/null +++ b/ipa-metrics-tracing/src/layer.rs @@ -0,0 +1,123 @@ +use std::fmt::Debug; + +use ipa_metrics::{CurrentThreadPartitionContext, MetricPartition, MetricsCurrentThreadContext}; +use tracing::{ + field::{Field, Visit}, + span::{Attributes, Record}, + Id, Subscriber, +}; +use tracing_subscriber::{ + layer::Context, + registry::{Extensions, ExtensionsMut, LookupSpan}, + Layer, +}; + +pub const FIELD: &str = concat!(env!("CARGO_PKG_NAME"), "-", "metrics-partition"); + +/// This layer allows partitioning metric stores. +/// This can be used in tests, where each unit test +/// creates its own unique root span. Upon entering +/// this span, this layer sets a unique partition key +#[derive(Default)] +pub struct MetricsPartitioningLayer; + +impl LookupSpan<'s>> Layer for MetricsPartitioningLayer { + fn on_new_span(&self, attrs: &Attributes<'_>, id: &Id, ctx: Context<'_, S>) { + #[derive(Default)] + struct MaybeMetricPartition(Option); + + impl Visit for MaybeMetricPartition { + fn record_u64(&mut self, field: &Field, value: u64) { + if field.name() == FIELD { + self.0 = Some(value); + } + } + + fn record_debug(&mut self, _field: &Field, _value: &dyn Debug) { + // not interested in anything else except MetricPartition values. + } + } + + let record = Record::new(attrs.values()); + let mut metric_partition = MaybeMetricPartition::default(); + record.record(&mut metric_partition); + if let Some(v) = metric_partition.0 { + let span = ctx.span(id).expect("Span should exists upon entering"); + span.extensions_mut().insert(MetricPartitionExt { + prev: None, + current: v, + }); + } + } + + fn on_enter(&self, id: &Id, ctx: Context<'_, S>) { + let span = ctx.span(id).expect("Span should exists upon entering"); + MetricPartitionExt::span_enter(span.extensions_mut()); + } + + fn on_exit(&self, id: &Id, ctx: Context<'_, S>) { + let span = ctx.span(id).expect("Span should exists upon exiting"); + MetricPartitionExt::span_exit(span.extensions_mut()); + } + + fn on_close(&self, id: Id, ctx: Context<'_, S>) { + let span = ctx.span(&id).expect("Span should exists before closing it"); + MetricPartitionExt::span_close(&span.extensions()); + } +} + +struct MetricPartitionExt { + // Partition active before span is entered. + prev: Option, + // Partition that must be set when this span is entered. + current: MetricPartition, +} + +impl MetricPartitionExt { + fn span_enter(mut span_ext: ExtensionsMut<'_>) { + if let Some(MetricPartitionExt { current, prev }) = span_ext.get_mut() { + *prev = CurrentThreadPartitionContext::get(); + CurrentThreadPartitionContext::set(*current); + } + } + + fn span_exit(mut span_ext: ExtensionsMut) { + if let Some(MetricPartitionExt { prev, .. }) = span_ext.get_mut() { + CurrentThreadPartitionContext::toggle(prev.take()); + } + } + + fn span_close(span_ext: &Extensions) { + if let Some(MetricPartitionExt { .. 
}) = span_ext.get() { + MetricsCurrentThreadContext::flush(); + } + } +} + +#[cfg(test)] +mod tests { + use ipa_metrics::CurrentThreadPartitionContext; + use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + + use crate::{layer::FIELD, MetricsPartitioningLayer}; + + #[test] + fn basic() { + CurrentThreadPartitionContext::set(0); + tracing_subscriber::registry() + .with(MetricsPartitioningLayer) + .init(); + let span1 = tracing::info_span!("", { FIELD } = 1_u64); + let span2 = tracing::info_span!("", { FIELD } = 2_u64); + { + let _guard1 = span1.enter(); + assert_eq!(Some(1), CurrentThreadPartitionContext::get()); + { + let _guard2 = span2.enter(); + assert_eq!(Some(2), CurrentThreadPartitionContext::get()); + } + assert_eq!(Some(1), CurrentThreadPartitionContext::get()); + } + assert_eq!(Some(0), CurrentThreadPartitionContext::get()); + } +} diff --git a/ipa-metrics-tracing/src/lib.rs b/ipa-metrics-tracing/src/lib.rs new file mode 100644 index 000000000..c72bb9e54 --- /dev/null +++ b/ipa-metrics-tracing/src/lib.rs @@ -0,0 +1,7 @@ +#![deny(clippy::pedantic)] +#![allow(clippy::similar_names)] +#![allow(clippy::module_name_repetitions)] + +mod layer; + +pub use layer::{MetricsPartitioningLayer, FIELD as PARTITION_FIELD}; From a0341191b40679a435a9eae9bf128e596cc2014c Mon Sep 17 00:00:00 2001 From: Christian Berkhoff Date: Mon, 21 Oct 2024 13:45:29 -0700 Subject: [PATCH 165/191] Addressing comments --- ipa-core/src/cli/playbook/mod.rs | 2 +- ipa-core/src/config.rs | 42 +++++++++++++------------ ipa-core/src/net/server/handlers/mod.rs | 2 +- ipa-core/src/net/server/mod.rs | 25 ++++++++------- ipa-core/src/net/test.rs | 2 +- ipa-core/src/net/transport.rs | 2 +- ipa-core/src/sharding.rs | 15 +++++++++ 7 files changed, 54 insertions(+), 36 deletions(-) diff --git a/ipa-core/src/cli/playbook/mod.rs b/ipa-core/src/cli/playbook/mod.rs index fbf7843ec..4b5164f56 100644 --- a/ipa-core/src/cli/playbook/mod.rs +++ b/ipa-core/src/cli/playbook/mod.rs @@ -199,7 +199,7 @@ pub async fn make_clients( let network = if let Some(path) = network_path { NetworkConfig::from_toml_str(&fs::read_to_string(path).unwrap()).unwrap() } else { - NetworkConfig::::new_ring( + NetworkConfig::::new_mpc( vec![ PeerConfig::new("localhost:3000".parse().unwrap(), None), PeerConfig::new("localhost:3001".parse().unwrap(), None), diff --git a/ipa-core/src/config.rs b/ipa-core/src/config.rs index 090c65b40..50bd90f4b 100644 --- a/ipa-core/src/config.rs +++ b/ipa-core/src/config.rs @@ -1,6 +1,5 @@ use std::{ borrow::{Borrow, Cow}, - collections::HashMap, fmt::{Debug, Formatter}, iter::zip, path::PathBuf, @@ -46,7 +45,7 @@ pub enum Error { /// The most important thing this contains is discovery information for each of the participating /// peers. #[derive(Clone, Debug, Deserialize)] -pub struct NetworkConfig { +pub struct NetworkConfig { peers: Vec, /// HTTP client configuration. @@ -56,10 +55,10 @@ pub struct NetworkConfig { /// The identities of the index-matching peers. Separating this from [`Self::peers`](field) so /// that parsing is easy to implement. #[serde(skip)] - identities: Vec, + identities: Vec, } -impl NetworkConfig { +impl NetworkConfig { /// # Panics /// If `PathAndQuery::from_str("")` fails #[must_use] @@ -102,7 +101,7 @@ impl NetworkConfig { /// the certificate against a truststore and identifying the peer by the certificate /// subject). This could be changed if the need arises. 
#[must_use] - pub fn identify_cert(&self, cert: Option<&CertificateDer>) -> Option { + pub fn identify_cert(&self, cert: Option<&CertificateDer>) -> Option { let cert = cert?; for (id, p) in zip(self.identities.iter(), self.peers.iter()) { if p.certificate.as_ref() == Some(cert) { @@ -119,35 +118,27 @@ impl NetworkConfig { } impl NetworkConfig { + /// # Panics + /// In the unlikely event a usize cannot be turned into a u32 #[must_use] pub fn new_shards(peers: Vec, client: ClientConfig) -> Self { - let mut identities = Vec::with_capacity(peers.len()); - for (i, _p) in zip(0u32.., peers.iter()) { - identities.push(ShardIndex(i)); - } + let identities = (0u32..peers.len().try_into().unwrap()) + .map(ShardIndex::from) + .collect(); Self { peers, client, identities, } } - - #[must_use] - pub fn peers_map(&self) -> HashMap { - let mut indexed_peers = HashMap::new(); - for (ix, p) in zip(self.identities.iter(), self.peers.iter()) { - indexed_peers.insert(*ix, p); - } - indexed_peers - } } impl NetworkConfig { - /// Creates a new ring configuration. + /// Creates a new configuration for 3 MPC clients (ring) configuration. /// # Panics /// If the vector doesn't contain exactly 3 items. #[must_use] - pub fn new_ring(ring: Vec, client: ClientConfig) -> Self { + pub fn new_mpc(ring: Vec, client: ClientConfig) -> Self { assert_eq!(3, ring.len()); Self { peers: ring, @@ -532,10 +523,12 @@ mod tests { use rand::rngs::StdRng; use rand_core::SeedableRng; + use super::{NetworkConfig, PeerConfig}; use crate::{ config::{ClientConfig, HpkeClientConfig, Http2Configurator, HttpClientConfigurator}, helpers::HelperIdentity, net::test::TestConfigBuilder, + sharding::ShardIndex, }; const URI_1: &str = "http://localhost:3000"; @@ -610,4 +603,13 @@ mod tests { }), ); } + + #[test] + fn indexing_peer_happy_case() { + let uri1 = URI_1.parse::().unwrap(); + let pc1 = PeerConfig::new(uri1, None); + let client = ClientConfig::default(); + let conf = NetworkConfig::new_shards(vec![pc1.clone()], client); + assert_eq!(conf.peers[ShardIndex(0)].url, pc1.url); + } } diff --git a/ipa-core/src/net/server/handlers/mod.rs b/ipa-core/src/net/server/handlers/mod.rs index a72b9e8f2..604cd06a0 100644 --- a/ipa-core/src/net/server/handlers/mod.rs +++ b/ipa-core/src/net/server/handlers/mod.rs @@ -8,7 +8,7 @@ use crate::{ sync::Arc, }; -pub fn ring_router(transport: Arc) -> Router { +pub fn mpc_router(transport: Arc) -> Router { echo::router().nest( http_serde::query::BASE_AXUM_PATH, Router::new() diff --git a/ipa-core/src/net/server/mod.rs b/ipa-core/src/net/server/mod.rs index 3ca8277a4..c5c996f08 100644 --- a/ipa-core/src/net/server/mod.rs +++ b/ipa-core/src/net/server/mod.rs @@ -89,12 +89,12 @@ pub struct MpcHelperServer { } impl MpcHelperServer { - pub fn new_ring( + pub fn new_mpc( transport: Arc, config: ServerConfig, network_config: NetworkConfig, ) -> Self { - let router = handlers::ring_router(transport); + let router = handlers::mpc_router(transport); MpcHelperServer { config, network_config, @@ -309,12 +309,13 @@ async fn rustls_config( Ok(RustlsConfig::from_config(Arc::new(config))) } -/// Axum `Extension` indicating the authenticated remote helper identity, if any. +/// Axum `Extension` indicating the authenticated remote identity, if any. This can be either a +/// Shard authenticating or another Helper. // -// Presence or absence of authentication is indicated by presence or absence of the extension. Even -// at some inconvenience (e.g. 
`MaybeExtensionExt`), we avoid using `Option` within the extension, -// to avoid possible confusion about how many times the return from `req.extensions().get()` must be -// unwrapped to ensure valid authentication. +/// Presence or absence of authentication is indicated by presence or absence of the extension. Even +/// at some inconvenience (e.g. `MaybeExtensionExt`), we avoid using `Option` within the extension, +/// to avoid possible confusion about how many times the return from `req.extensions().get()` must be +/// unwrapped to ensure valid authentication. #[derive(Clone, Copy, Debug, PartialEq)] struct ClientIdentity(pub I); @@ -326,10 +327,10 @@ impl Deref for ClientIdentity { } } -impl TryFrom for ClientIdentity { +impl TryFrom<&HeaderValue> for ClientIdentity { type Error = Error; - fn try_from(value: HeaderValue) -> Result { + fn try_from(value: &HeaderValue) -> Result { let header_str = value.to_str()?; I::from_str(header_str) .map_err(|e| Error::InvalidHeader(Box::new(e))) @@ -461,7 +462,7 @@ where fn call(&mut self, mut req: Request) -> Self::Future { if let Some(header_value) = req.headers().get(F::identity_header()) { - let id_result = ClientIdentity::::try_from(header_value.clone()); + let id_result = ClientIdentity::::try_from(header_value); match id_result { Ok(id) => req.extensions_mut().insert(id), Err(err) => return ready(Ok(err.into_response())).right_future(), @@ -480,7 +481,7 @@ mod tests { #[test] fn identify_from_header_happy_case() { let h = HeaderValue::from_static("A"); - let id = ClientIdentity::::try_from(h); + let id = ClientIdentity::::try_from(&h); assert_eq!(id.unwrap(), ClientIdentity(HelperIdentity::ONE)); } @@ -488,7 +489,7 @@ mod tests { #[should_panic = "The string H1 is an invalid Helper Identity"] fn identify_from_header_wrong_header() { let h = HeaderValue::from_static("H1"); - let id = ClientIdentity::::try_from(h); + let id = ClientIdentity::::try_from(&h); id.unwrap(); } } diff --git a/ipa-core/src/net/test.rs b/ipa-core/src/net/test.rs index 9fdcf5877..bfa29f308 100644 --- a/ipa-core/src/net/test.rs +++ b/ipa-core/src/net/test.rs @@ -175,7 +175,7 @@ impl TestConfigBuilder { }, }) .collect::>(); - let network = NetworkConfig::::new_ring( + let network = NetworkConfig::::new_mpc( peers, self.use_http1 .then(ClientConfig::use_http1) diff --git a/ipa-core/src/net/transport.rs b/ipa-core/src/net/transport.rs index 7ada0da2f..f0cf4c477 100644 --- a/ipa-core/src/net/transport.rs +++ b/ipa-core/src/net/transport.rs @@ -74,7 +74,7 @@ impl HttpTransport { ) -> (Arc, MpcHelperServer) { let transport = Self::new_internal(runtime, identity, clients, handler); let server = - MpcHelperServer::new_ring(Arc::clone(&transport), server_config, network_config); + MpcHelperServer::new_mpc(Arc::clone(&transport), server_config, network_config); (transport, server) } diff --git a/ipa-core/src/sharding.rs b/ipa-core/src/sharding.rs index 195573f45..433254a19 100644 --- a/ipa-core/src/sharding.rs +++ b/ipa-core/src/sharding.rs @@ -1,6 +1,7 @@ use std::{ fmt::{Debug, Display, Formatter}, num::TryFromIntError, + ops::{Index, IndexMut}, }; /// A unique zero-based index of the helper shard. 
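The `ops::{Index, IndexMut}` import added above backs the implementations in the
next hunk, which let callers write `conf.peers[ShardIndex(0)]` as in the
`indexing_peer_happy_case` test earlier in this commit. A self-contained sketch
of the same pattern, using a local wrapper type so it compiles on its own (the
patch itself implements the trait directly on `Vec<T>`):

```rust
use std::ops::Index;

// Local stand-ins: a newtype shard index over a dense vector of peers.
#[derive(Clone, Copy)]
struct ShardIndex(u32);

struct Peers<T>(Vec<T>);

impl<T> Index<ShardIndex> for Peers<T> {
    type Output = T;

    // Indexing by `ShardIndex` hides the `usize` conversion at every call site.
    fn index(&self, index: ShardIndex) -> &T {
        &self.0[index.0 as usize]
    }
}

fn main() {
    let peers = Peers(vec!["h1.example", "h2.example", "h3.example"]);
    assert_eq!(peers[ShardIndex(2)], "h3.example");
}
```
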
@@ -112,6 +113,20 @@ impl TryFrom for ShardIndex { } } +impl Index for Vec { + type Output = T; + + fn index(&self, index: ShardIndex) -> &Self::Output { + self.as_slice().index(usize::from(index)) + } +} + +impl IndexMut for Vec { + fn index_mut(&mut self, index: ShardIndex) -> &mut Self::Output { + self.as_mut_slice().index_mut(usize::from(index)) + } +} + #[cfg(all(test, unit_test))] mod tests { use std::iter::empty; From e448f5e14f704c5f11cc9c747f48638b43fec017 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Mon, 21 Oct 2024 15:04:07 -0700 Subject: [PATCH 166/191] Fix compile error --- ipa-core/src/protocol/context/mod.rs | 2 +- ipa-core/src/query/runner/reshard_tag.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ipa-core/src/protocol/context/mod.rs b/ipa-core/src/protocol/context/mod.rs index bafc17be9..6fe6d673d 100644 --- a/ipa-core/src/protocol/context/mod.rs +++ b/ipa-core/src/protocol/context/mod.rs @@ -601,7 +601,7 @@ pub trait DZKPContext: Context { async fn validate_record(&self, record_id: RecordId) -> Result<(), Error>; } -#[cfg(test)] +#[cfg(all(test, unit_test))] mod tests { use std::{iter, iter::repeat, pin::Pin, task::Poll}; diff --git a/ipa-core/src/query/runner/reshard_tag.rs b/ipa-core/src/query/runner/reshard_tag.rs index 5ef7b6311..92ef8d53e 100644 --- a/ipa-core/src/query/runner/reshard_tag.rs +++ b/ipa-core/src/query/runner/reshard_tag.rs @@ -80,7 +80,7 @@ impl>, K, A> Stream for StreamSplitter<'_ } } -#[cfg(test)] +#[cfg(all(test, unit_test))] mod tests { use futures::{stream, StreamExt}; From 239dc687167c159c4e9bc97bd7937ce42de1b1a6 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Mon, 21 Oct 2024 15:28:22 -0700 Subject: [PATCH 167/191] Add a test to verify stream rejection in case when there are more records than specified --- ipa-core/src/protocol/context/mod.rs | 64 +++++++++++++++++++++++++++- 1 file changed, 63 insertions(+), 1 deletion(-) diff --git a/ipa-core/src/protocol/context/mod.rs b/ipa-core/src/protocol/context/mod.rs index 6fe6d673d..d84dd0926 100644 --- a/ipa-core/src/protocol/context/mod.rs +++ b/ipa-core/src/protocol/context/mod.rs @@ -455,11 +455,17 @@ where (input, &mut send_channels, &mut counter), |(mut input, send_channels, i)| { let ctx = ctx.clone(); - async { // Process more data as it comes in, or close the sending channels, if there is nothing // left. if let Some(val) = input.try_next().await? 
{ + if usize::try_from(*i).unwrap() >= input_len { + return Err(crate::error::Error::RecordIdOutOfRange { + record_id: RecordId::from(*i), + total_records: input_len, + }); + } + let dest_shard = shard_picker(ctx, RecordId::from(*i), &val); *i += 1; if dest_shard == my_shard { @@ -980,6 +986,62 @@ mod tests { }); } + #[test] + #[should_panic(expected = "RecordIdOutOfRange { record_id: RecordId(1), total_records: 1 }")] + fn reshard_try_stream_more_items_than_expected() { + #[pin_project] + struct AdversaryStream { + #[pin] + inner: S, + wrong_length: usize, + } + + impl AdversaryStream { + fn new(inner: S, wrong_length: usize) -> Self { + assert!(wrong_length > 0); + Self { + inner, + wrong_length, + } + } + } + + impl Stream for AdversaryStream { + type Item = S::Item; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let this = self.project(); + + this.inner.poll_next(cx) + } + + fn size_hint(&self) -> (usize, Option) { + (0, Some(self.wrong_length)) + } + } + + run(|| async move { + const SHARDS: u32 = 5; + let world: TestWorld> = + TestWorld::with_shards(TestWorldConfig::default()); + let input: Vec<_> = (0..5 * SHARDS).map(BA8::truncate_from).collect(); + world + .semi_honest(input.clone().into_iter(), |ctx, shard_input| async move { + reshard_try_stream( + ctx, + AdversaryStream::new(stream::iter(shard_input).map(Ok), 1), + |_, _, _| ShardIndex::FIRST, + ) + .await + .unwrap() + }) + .await; + }); + } + #[test] fn reshard_try_stream_less_items_than_expected() { /// This allows advertising higher upper bound limit From 0b2f19f12678d5755e77ec65eaba34fd5e049283 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Mon, 21 Oct 2024 16:13:29 -0700 Subject: [PATCH 168/191] Fix reshard_aad doc --- ipa-core/src/query/runner/reshard_tag.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ipa-core/src/query/runner/reshard_tag.rs b/ipa-core/src/query/runner/reshard_tag.rs index 92ef8d53e..5d1c3b8f5 100644 --- a/ipa-core/src/query/runner/reshard_tag.rs +++ b/ipa-core/src/query/runner/reshard_tag.rs @@ -19,7 +19,7 @@ use crate::{ type DataWithTag = Result<(D, A), Error>; /// Helper function to work with inputs to hybrid queries. Each encryption needs -/// to be checked for uniqueness and we use AAD tag for that. While match keys are +/// to be checked for uniqueness and we use AAD tag for that. While reports are /// being collected, AAD tags need to be resharded. This function does both at the same /// time which should reduce the perceived latency of queries. 
/// From 64eccd983a153562058a30fc118a2ebc27626b87 Mon Sep 17 00:00:00 2001 From: Christian Berkhoff Date: Mon, 21 Oct 2024 16:16:37 -0700 Subject: [PATCH 169/191] ConnectionFlavor-aware HttpTransport --- ipa-core/src/bin/helper.rs | 26 +- ipa-core/src/helpers/gateway/mod.rs | 4 +- ipa-core/src/net/client/mod.rs | 2 +- ipa-core/src/net/mod.rs | 2 +- ipa-core/src/net/server/handlers/mod.rs | 9 +- .../src/net/server/handlers/query/create.rs | 11 +- .../src/net/server/handlers/query/input.rs | 10 +- .../src/net/server/handlers/query/kill.rs | 12 +- ipa-core/src/net/server/handlers/query/mod.rs | 17 +- .../src/net/server/handlers/query/prepare.rs | 11 +- .../src/net/server/handlers/query/results.rs | 10 +- .../src/net/server/handlers/query/status.rs | 10 +- .../src/net/server/handlers/query/step.rs | 15 +- ipa-core/src/net/server/mod.rs | 26 +- ipa-core/src/net/test.rs | 7 +- ipa-core/src/net/transport.rs | 326 ++++++++++++------ 16 files changed, 318 insertions(+), 180 deletions(-) diff --git a/ipa-core/src/bin/helper.rs b/ipa-core/src/bin/helper.rs index 1b2a50398..e11074cd8 100644 --- a/ipa-core/src/bin/helper.rs +++ b/ipa-core/src/bin/helper.rs @@ -1,4 +1,5 @@ use std::{ + collections::HashMap, fs, io::BufReader, net::TcpListener, @@ -17,7 +18,8 @@ use ipa_core::{ error::BoxError, executor::IpaRuntime, helpers::HelperIdentity, - net::{ClientIdentity, HttpShardTransport, HttpTransport, MpcHelperClient}, + net::{ClientIdentity, HttpTransport, MpcHelperClient, MpcHttpTransport, ShardHttpTransport}, + sharding::ShardIndex, AppConfig, AppSetup, NonZeroU32PowerOfTwo, }; use tokio::runtime::Runtime; @@ -158,13 +160,19 @@ async fn server(args: ServerArgs) -> Result<(), BoxError> { let network_config_path = args.network.as_deref().unwrap(); let network_config = NetworkConfig::from_toml_str(&fs::read_to_string(network_config_path)?)? .override_scheme(&scheme); + + // TODO: Following is just temporary until Shard Transport is actually used. + let shard_clients_config = network_config.client.clone(); + let shard_server_config = server_config.clone(); + // --- + let http_runtime = new_http_runtime(); let clients = MpcHelperClient::from_conf( &IpaRuntime::from_tokio_runtime(&http_runtime), &network_config, &identity, ); - let (transport, server) = HttpTransport::new( + let (transport, server) = MpcHttpTransport::new( IpaRuntime::from_tokio_runtime(&http_runtime), my_identity, server_config, @@ -173,7 +181,19 @@ async fn server(args: ServerArgs) -> Result<(), BoxError> { Some(handler), ); - let _app = setup.connect(transport.clone(), HttpShardTransport); + // TODO: Following is just temporary until Shard Transport is actually used. 
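+    // The placeholder shard network below has no peers and no request handler,
+    // so the transport can be constructed and cloned but serves no shard traffic.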
+ let shard_network_config = NetworkConfig::new_shards(vec![], shard_clients_config); + let (shard_transport, _shard_server) = ShardHttpTransport::new( + IpaRuntime::from_tokio_runtime(&http_runtime), + ShardIndex::FIRST, + shard_server_config, + shard_network_config, + HashMap::new(), + None, + ); + // --- + + let _app = setup.connect(transport.clone(), shard_transport.clone()); let listener = args.server_socket_fd .map(|fd| { diff --git a/ipa-core/src/helpers/gateway/mod.rs b/ipa-core/src/helpers/gateway/mod.rs index 2d48380db..a25321da4 100644 --- a/ipa-core/src/helpers/gateway/mod.rs +++ b/ipa-core/src/helpers/gateway/mod.rs @@ -45,9 +45,9 @@ pub type MpcTransportImpl = TransportImpl; pub type ShardTransportImpl = TransportImpl; #[cfg(feature = "real-world-infra")] -pub type MpcTransportImpl = crate::sync::Arc; +pub type MpcTransportImpl = crate::net::MpcHttpTransport; #[cfg(feature = "real-world-infra")] -pub type ShardTransportImpl = crate::net::HttpShardTransport; +pub type ShardTransportImpl = crate::net::ShardHttpTransport; pub type MpcTransportError = ::Error; diff --git a/ipa-core/src/net/client/mod.rs b/ipa-core/src/net/client/mod.rs index d602e5ce2..3693431f2 100644 --- a/ipa-core/src/net/client/mod.rs +++ b/ipa-core/src/net/client/mod.rs @@ -691,7 +691,7 @@ pub(crate) mod tests { resp_ok(resp).await.unwrap(); - let mut stream = Arc::clone(&transport) + let mut stream = transport .receive(HelperIdentity::ONE, (QueryId, expected_step.clone())) .into_bytes_stream(); diff --git a/ipa-core/src/net/mod.rs b/ipa-core/src/net/mod.rs index 6f60116ca..e0fdca35a 100644 --- a/ipa-core/src/net/mod.rs +++ b/ipa-core/src/net/mod.rs @@ -26,7 +26,7 @@ mod transport; pub use client::{ClientIdentity, MpcHelperClient}; pub use error::Error; pub use server::{MpcHelperServer, TracingSpanMaker}; -pub use transport::{HttpShardTransport, HttpTransport}; +pub use transport::{HttpTransport, MpcHttpTransport, ShardHttpTransport}; const APPLICATION_JSON: &str = "application/json"; const APPLICATION_OCTET_STREAM: &str = "application/octet-stream"; diff --git a/ipa-core/src/net/server/handlers/mod.rs b/ipa-core/src/net/server/handlers/mod.rs index 604cd06a0..3e83c6568 100644 --- a/ipa-core/src/net/server/handlers/mod.rs +++ b/ipa-core/src/net/server/handlers/mod.rs @@ -3,16 +3,13 @@ mod query; use axum::Router; -use crate::{ - net::{http_serde, HttpTransport}, - sync::Arc, -}; +use crate::net::{http_serde, transport::MpcHttpTransport}; -pub fn mpc_router(transport: Arc) -> Router { +pub fn mpc_router(transport: MpcHttpTransport) -> Router { echo::router().nest( http_serde::query::BASE_AXUM_PATH, Router::new() - .merge(query::query_router(Arc::clone(&transport))) + .merge(query::query_router(transport.clone())) .merge(query::h2h_router(transport)), ) } diff --git a/ipa-core/src/net/server/handlers/query/create.rs b/ipa-core/src/net/server/handlers/query/create.rs index f56c0b8d2..58bf71e3b 100644 --- a/ipa-core/src/net/server/handlers/query/create.rs +++ b/ipa-core/src/net/server/handlers/query/create.rs @@ -2,22 +2,21 @@ use axum::{routing::post, Extension, Json, Router}; use hyper::StatusCode; use crate::{ - helpers::{ApiError, BodyStream, Transport}, + helpers::{ApiError, BodyStream}, net::{ http_serde::{self, query::QueryConfigQueryParams}, - Error, HttpTransport, + transport::MpcHttpTransport, + Error, }, query::NewQueryError, - sync::Arc, }; /// Takes details from the HTTP request and creates a `[TransportCommand]::CreateQuery` that is sent /// to the [`HttpTransport`]. 
async fn handler( - transport: Extension>, + transport: Extension, QueryConfigQueryParams(query_config): QueryConfigQueryParams, ) -> Result, Error> { - let transport = Transport::clone_ref(&*transport); match transport.dispatch(query_config, BodyStream::empty()).await { Ok(resp) => Ok(Json(resp.try_into()?)), Err(err @ ApiError::NewQuery(NewQueryError::State { .. })) => { @@ -27,7 +26,7 @@ async fn handler( } } -pub fn router(transport: Arc) -> Router { +pub fn router(transport: MpcHttpTransport) -> Router { Router::new() .route(http_serde::query::create::AXUM_PATH, post(handler)) .layer(Extension(transport)) diff --git a/ipa-core/src/net/server/handlers/query/input.rs b/ipa-core/src/net/server/handlers/query/input.rs index 844604485..da47e9386 100644 --- a/ipa-core/src/net/server/handlers/query/input.rs +++ b/ipa-core/src/net/server/handlers/query/input.rs @@ -2,14 +2,13 @@ use axum::{extract::Path, routing::post, Extension, Router}; use hyper::StatusCode; use crate::{ - helpers::{query::QueryInput, routing::RouteId, BodyStream, Transport}, - net::{http_serde, Error, HttpTransport}, + helpers::{query::QueryInput, routing::RouteId, BodyStream}, + net::{http_serde, transport::MpcHttpTransport, Error}, protocol::QueryId, - sync::Arc, }; async fn handler( - transport: Extension>, + transport: Extension, Path(query_id): Path, input_stream: BodyStream, ) -> Result<(), Error> { @@ -17,7 +16,6 @@ async fn handler( query_id, input_stream, }; - let transport = Transport::clone_ref(&*transport); let _ = transport .dispatch( (RouteId::QueryInput, query_input.query_id), @@ -29,7 +27,7 @@ async fn handler( Ok(()) } -pub fn router(transport: Arc) -> Router { +pub fn router(transport: MpcHttpTransport) -> Router { Router::new() .route(http_serde::query::input::AXUM_PATH, post(handler)) .layer(Extension(transport)) diff --git a/ipa-core/src/net/server/handlers/query/kill.rs b/ipa-core/src/net/server/handlers/query/kill.rs index aae68b993..f97fa1657 100644 --- a/ipa-core/src/net/server/handlers/query/kill.rs +++ b/ipa-core/src/net/server/handlers/query/kill.rs @@ -2,24 +2,22 @@ use axum::{extract::Path, routing::post, Extension, Json, Router}; use hyper::StatusCode; use crate::{ - helpers::{ApiError, BodyStream, Transport}, + helpers::{ApiError, BodyStream}, net::{ - http_serde::query::{kill, kill::Request}, + http_serde::query::kill::{self, Request}, server::Error, + transport::MpcHttpTransport, Error::QueryIdNotFound, - HttpTransport, }, protocol::QueryId, query::QueryKillStatus, - sync::Arc, }; async fn handler( - transport: Extension>, + transport: Extension, Path(query_id): Path, ) -> Result, Error> { let req = Request { query_id }; - let transport = Transport::clone_ref(&*transport); match transport.dispatch(req, BodyStream::empty()).await { Ok(state) => Ok(Json(kill::ResponseBody::from(state))), Err(ApiError::QueryKill(QueryKillStatus::NoSuchQuery(query_id))) => Err( @@ -29,7 +27,7 @@ async fn handler( } } -pub fn router(transport: Arc) -> Router { +pub fn router(transport: MpcHttpTransport) -> Router { Router::new() .route(kill::AXUM_PATH, post(handler)) .layer(Extension(transport)) diff --git a/ipa-core/src/net/server/handlers/query/mod.rs b/ipa-core/src/net/server/handlers/query/mod.rs index c15817a42..13b3b962d 100644 --- a/ipa-core/src/net/server/handlers/query/mod.rs +++ b/ipa-core/src/net/server/handlers/query/mod.rs @@ -19,8 +19,7 @@ use tower::{layer::layer_fn, Service}; use crate::{ helpers::HelperIdentity, - net::{server::ClientIdentity, HttpTransport}, - sync::Arc, + 
net::{server::ClientIdentity, transport::MpcHttpTransport}, }; /// Construct router for IPA query web service @@ -28,12 +27,12 @@ use crate::{ /// In principle, this web service could be backed by either an HTTP-interconnected helper network or /// an in-memory helper network. These are the APIs used by external callers (report collectors) to /// examine attribution results. -pub fn query_router(transport: Arc) -> Router { +pub fn query_router(transport: MpcHttpTransport) -> Router { Router::new() - .merge(create::router(Arc::clone(&transport))) - .merge(input::router(Arc::clone(&transport))) - .merge(status::router(Arc::clone(&transport))) - .merge(kill::router(Arc::clone(&transport))) + .merge(create::router(transport.clone())) + .merge(input::router(transport.clone())) + .merge(status::router(transport.clone())) + .merge(kill::router(transport.clone())) .merge(results::router(transport)) } @@ -44,9 +43,9 @@ pub fn query_router(transport: Arc) -> Router { /// particular query, to coordinate servicing that query. // // It might make sense to split the query and h2h handlers into two modules. -pub fn h2h_router(transport: Arc) -> Router { +pub fn h2h_router(transport: MpcHttpTransport) -> Router { Router::new() - .merge(prepare::router(Arc::clone(&transport))) + .merge(prepare::router(transport.clone())) .merge(step::router(transport)) .layer(layer_fn(HelperAuthentication::new)) } diff --git a/ipa-core/src/net/server/handlers/query/prepare.rs b/ipa-core/src/net/server/handlers/query/prepare.rs index add6c7a95..51ed1019d 100644 --- a/ipa-core/src/net/server/handlers/query/prepare.rs +++ b/ipa-core/src/net/server/handlers/query/prepare.rs @@ -2,24 +2,24 @@ use axum::{extract::Path, response::IntoResponse, routing::post, Extension, Json use hyper::StatusCode; use crate::{ - helpers::{query::PrepareQuery, BodyStream, HelperIdentity, Transport}, + helpers::{query::PrepareQuery, BodyStream, HelperIdentity}, net::{ http_serde::{ self, query::{prepare::RequestBody, QueryConfigQueryParams}, }, server::ClientIdentity, - Error, HttpTransport, + transport::MpcHttpTransport, + Error, }, protocol::QueryId, query::PrepareQueryError, - sync::Arc, }; /// Called by whichever peer helper is the leader for an individual query, to initiatialize /// processing of that query. 
async fn handler( - transport: Extension>, + transport: Extension, _: Extension>, // require that client is an authenticated helper Path(query_id): Path, QueryConfigQueryParams(config): QueryConfigQueryParams, @@ -30,7 +30,6 @@ async fn handler( config, roles, }; - let transport = Transport::clone_ref(&*transport); let _ = transport .dispatch(data, BodyStream::empty()) .await @@ -45,7 +44,7 @@ impl IntoResponse for PrepareQueryError { } } -pub fn router(transport: Arc) -> Router { +pub fn router(transport: MpcHttpTransport) -> Router { Router::new() .route(http_serde::query::prepare::AXUM_PATH, post(handler)) .layer(Extension(transport)) diff --git a/ipa-core/src/net/server/handlers/query/results.rs b/ipa-core/src/net/server/handlers/query/results.rs index abd77b947..1c359b659 100644 --- a/ipa-core/src/net/server/handlers/query/results.rs +++ b/ipa-core/src/net/server/handlers/query/results.rs @@ -2,31 +2,29 @@ use axum::{extract::Path, routing::get, Extension, Router}; use hyper::StatusCode; use crate::{ - helpers::{BodyStream, Transport}, + helpers::BodyStream, net::{ http_serde::{self, query::results::Request}, server::Error, - HttpTransport, + transport::MpcHttpTransport, }, protocol::QueryId, - sync::Arc, }; /// Handles the completion of the query by blocking the sender until query is completed. async fn handler( - transport: Extension>, + transport: Extension, Path(query_id): Path, ) -> Result, Error> { let req = Request { query_id }; // TODO: we may be able to stream the response - let transport = Transport::clone_ref(&*transport); match transport.dispatch(req, BodyStream::empty()).await { Ok(resp) => Ok(resp.into_body()), Err(e) => Err(Error::application(StatusCode::INTERNAL_SERVER_ERROR, e)), } } -pub fn router(transport: Arc) -> Router { +pub fn router(transport: MpcHttpTransport) -> Router { Router::new() .route(http_serde::query::results::AXUM_PATH, get(handler)) .layer(Extension(transport)) diff --git a/ipa-core/src/net/server/handlers/query/status.rs b/ipa-core/src/net/server/handlers/query/status.rs index dcd4e1c62..0056b76d0 100644 --- a/ipa-core/src/net/server/handlers/query/status.rs +++ b/ipa-core/src/net/server/handlers/query/status.rs @@ -2,29 +2,27 @@ use axum::{extract::Path, routing::get, Extension, Json, Router}; use hyper::StatusCode; use crate::{ - helpers::{BodyStream, Transport}, + helpers::BodyStream, net::{ http_serde::query::status::{self, Request}, server::Error, - HttpTransport, + transport::MpcHttpTransport, }, protocol::QueryId, - sync::Arc, }; async fn handler( - transport: Extension>, + transport: Extension, Path(query_id): Path, ) -> Result, Error> { let req = Request { query_id }; - let transport = Transport::clone_ref(&*transport); match transport.dispatch(req, BodyStream::empty()).await { Ok(state) => Ok(Json(status::ResponseBody::from(state))), Err(e) => Err(Error::application(StatusCode::INTERNAL_SERVER_ERROR, e)), } } -pub fn router(transport: Arc) -> Router { +pub fn router(transport: MpcHttpTransport) -> Router { Router::new() .route(status::AXUM_PATH, get(handler)) .layer(Extension(transport)) diff --git a/ipa-core/src/net/server/handlers/query/step.rs b/ipa-core/src/net/server/handlers/query/step.rs index b0b5d06ba..2c112be92 100644 --- a/ipa-core/src/net/server/handlers/query/step.rs +++ b/ipa-core/src/net/server/handlers/query/step.rs @@ -1,30 +1,28 @@ use axum::{extract::Path, routing::post, Extension, Router}; use crate::{ - helpers::{BodyStream, HelperIdentity, Transport}, + helpers::{BodyStream, HelperIdentity}, net::{ http_serde, 
server::{ClientIdentity, Error}, - HttpTransport, + transport::MpcHttpTransport, }, protocol::{Gate, QueryId}, - sync::Arc, }; #[allow(clippy::unused_async)] // axum doesn't like synchronous handler #[tracing::instrument(level = "trace", "step", skip_all, fields(from = ?**from, gate = ?gate))] async fn handler( - transport: Extension>, + transport: Extension, from: Extension>, Path((query_id, gate)): Path<(QueryId, Gate)>, body: BodyStream, ) -> Result<(), Error> { - let transport = Transport::clone_ref(&*transport); transport.receive_stream(query_id, gate, **from, body); Ok(()) } -pub fn router(transport: Arc) -> Router { +pub fn router(transport: MpcHttpTransport) -> Router { Router::new() .route(http_serde::query::step::AXUM_PATH, post(handler)) .layer(Extension(transport)) @@ -41,7 +39,7 @@ mod tests { use super::*; use crate::{ - helpers::{HelperIdentity, MESSAGE_PAYLOAD_SIZE_BYTES}, + helpers::{HelperIdentity, Transport, MESSAGE_PAYLOAD_SIZE_BYTES}, net::{ server::handlers::query::test_helpers::{assert_fails_with, MaybeExtensionExt}, test::TestServer, @@ -65,7 +63,8 @@ mod tests { test_server.server.handle_req(req.into()).await; - let mut stream = Arc::clone(&test_server.transport) + let mut stream = test_server + .transport .receive(HelperIdentity::TWO, (QueryId, step)) .into_bytes_stream(); diff --git a/ipa-core/src/net/server/mod.rs b/ipa-core/src/net/server/mod.rs index c5c996f08..f0229f4e5 100644 --- a/ipa-core/src/net/server/mod.rs +++ b/ipa-core/src/net/server/mod.rs @@ -39,6 +39,10 @@ use tower::{layer::layer_fn, Service}; use tower_http::trace::TraceLayer; use tracing::{error, Span}; +use super::{ + transport::{MpcHttpTransport, ShardHttpTransport}, + Shard, +}; use crate::{ config::{ NetworkConfig, OwnedCertificate, OwnedPrivateKey, PeerConfig, ServerConfig, TlsConfig, @@ -48,7 +52,7 @@ use crate::{ helpers::TransportIdentity, net::{ parse_certificate_and_private_key_bytes, server::config::HttpServerConfig, - ConnectionFlavor, Error, Helper, HttpTransport, CRYPTO_PROVIDER, + ConnectionFlavor, Error, Helper, CRYPTO_PROVIDER, }, sync::Arc, telemetry::metrics::{web::RequestProtocolVersion, REQUESTS_RECEIVED}, @@ -89,12 +93,13 @@ pub struct MpcHelperServer { } impl MpcHelperServer { + #[must_use] pub fn new_mpc( - transport: Arc, + transport: &MpcHttpTransport, config: ServerConfig, network_config: NetworkConfig, ) -> Self { - let router = handlers::mpc_router(transport); + let router = handlers::mpc_router(transport.clone()); MpcHelperServer { config, network_config, @@ -103,6 +108,21 @@ impl MpcHelperServer { } } +impl MpcHelperServer { + #[must_use] + pub fn new_shards( + _transport: &ShardHttpTransport, + config: ServerConfig, + network_config: NetworkConfig, + ) -> Self { + MpcHelperServer { + config, + network_config, + router: Router::new(), + } + } +} + impl MpcHelperServer { #[cfg(all(test, unit_test))] async fn handle_req(&self, req: hyper::Request) -> axum::response::Response { diff --git a/ipa-core/src/net/test.rs b/ipa-core/src/net/test.rs index bfa29f308..fa6e95ed8 100644 --- a/ipa-core/src/net/test.rs +++ b/ipa-core/src/net/test.rs @@ -16,6 +16,7 @@ use std::{ use once_cell::sync::Lazy; use rustls_pki_types::CertificateDer; +use super::transport::MpcHttpTransport; use crate::{ config::{ ClientConfig, HpkeClientConfig, HpkeServerConfig, NetworkConfig, PeerConfig, ServerConfig, @@ -24,7 +25,7 @@ use crate::{ executor::{IpaJoinHandle, IpaRuntime}, helpers::{HandlerBox, HelperIdentity, RequestHandler}, hpke::{Deserializable as _, IpaPublicKey}, - 
net::{ClientIdentity, Helper, HttpTransport, MpcHelperClient, MpcHelperServer}, + net::{ClientIdentity, Helper, MpcHelperClient, MpcHelperServer}, sync::Arc, test_fixture::metrics::MetricsHandle, }; @@ -199,7 +200,7 @@ impl TestConfigBuilder { pub struct TestServer { pub addr: SocketAddr, pub handle: IpaJoinHandle<()>, - pub transport: Arc, + pub transport: MpcHttpTransport, pub server: MpcHelperServer, pub client: MpcHelperClient, pub request_handler: Option>>, @@ -295,7 +296,7 @@ impl TestServerBuilder { ); let handler = self.handler.as_ref().map(HandlerBox::owning_ref); let client = clients[0].clone(); - let (transport, server) = HttpTransport::new( + let (transport, server) = MpcHttpTransport::new( IpaRuntime::current(), HelperIdentity::ONE, server_config, diff --git a/ipa-core/src/net/transport.rs b/ipa-core/src/net/transport.rs index f0cf4c477..eecff2b92 100644 --- a/ipa-core/src/net/transport.rs +++ b/ipa-core/src/net/transport.rs @@ -1,5 +1,6 @@ use std::{ borrow::Borrow, + collections::HashMap, future::Future, pin::Pin, task::{Context, Poll}, @@ -9,7 +10,7 @@ use async_trait::async_trait; use futures::{Stream, TryFutureExt}; use pin_project::{pin_project, pinned_drop}; -use super::{client::resp_ok, Helper}; +use super::{client::resp_ok, ConnectionFlavor, Helper, Shard}; use crate::{ config::{NetworkConfig, ServerConfig}, executor::IpaRuntime, @@ -26,21 +27,27 @@ use crate::{ sync::Arc, }; -/// HTTP transport for IPA helper service. -/// TODO: rename to MPC -pub struct HttpTransport { +/// Shared implementation used by [`MpcHttpTransport`] and [`ShardHttpTransport`] +pub struct HttpTransport { http_runtime: IpaRuntime, - identity: HelperIdentity, - clients: [MpcHelperClient; 3], - // TODO(615): supporting multiple queries likely require a hashmap here. It will be ok if we - // only allow one query at a time. - record_streams: StreamCollection, - handler: Option, + identity: F::Identity, + clients: HashMap>, + record_streams: StreamCollection, + handler: Option>, } -/// A stub for HTTP transport implementation, suitable for serviing inter-shard traffic -#[derive(Clone, Default)] -pub struct HttpShardTransport; +/// HTTP transport for helper to helper traffic. +#[derive(Clone)] +pub struct MpcHttpTransport { + inner_transport: Arc>, +} + +/// A stub for HTTP transport implementation, suitable for serving shard-to-shard traffic +#[derive(Clone)] +pub struct ShardHttpTransport { + #[allow(dead_code)] + inner_transport: Arc>, +} impl RouteParams for QueryConfig { type Params = String; @@ -62,35 +69,64 @@ impl RouteParams for QueryConfig { } } -impl HttpTransport { - #[must_use] - pub fn new( - runtime: IpaRuntime, - identity: HelperIdentity, - server_config: ServerConfig, - network_config: NetworkConfig, - clients: [MpcHelperClient; 3], - handler: Option, - ) -> (Arc, MpcHelperServer) { - let transport = Self::new_internal(runtime, identity, clients, handler); - let server = - MpcHelperServer::new_mpc(Arc::clone(&transport), server_config, network_config); - (transport, server) +impl HttpTransport { + async fn send< + D: Stream> + Send + 'static, + Q: QueryIdBinding, + S: StepBinding, + R: RouteParams, + >( + &self, + dest: F::Identity, + route: R, + data: D, + ) -> Result<(), Error> + where + Option: From, + Option: From, + { + let route_id = route.resource_identifier(); + match route_id { + RouteId::Records => { + // TODO(600): These fallible extractions aren't really necessary. 
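+                // Record traffic always carries both a query id and a gate, so a
+                // missing binding here is a programming error; `expect` is used
+                // below rather than propagating an `Option`.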
+ let query_id = >::from(route.query_id()) + .expect("query_id required when sending records"); + let step = + >::from(route.gate()).expect("step required when sending records"); + let resp_future = self.clients[&dest].step(query_id, &step, data)?; + // Use a dedicated HTTP runtime to poll this future for several reasons: + // - avoid blocking this task, if the current runtime is overloaded + // - use the runtime that enables IO (current runtime may not). + self.http_runtime + .spawn(resp_future.map_err(Into::into).and_then(resp_ok)) + .await?; + Ok(()) + } + RouteId::PrepareQuery => { + let req = serde_json::from_str(route.extra().borrow()).unwrap(); + self.clients[&dest].prepare_query(req).await + } + evt @ (RouteId::QueryInput + | RouteId::ReceiveQuery + | RouteId::QueryStatus + | RouteId::CompleteQuery + | RouteId::KillQuery) => { + unimplemented!( + "attempting to send client-specific request {evt:?} to another helper" + ) + } + } } - fn new_internal( - runtime: IpaRuntime, - identity: HelperIdentity, - clients: [MpcHelperClient; 3], - handler: Option, - ) -> Arc { - Arc::new(Self { - http_runtime: runtime, - identity, - clients, - handler, - record_streams: StreamCollection::default(), - }) + fn receive>( + &self, + from: F::Identity, + route: &R, + ) -> ReceiveRecords { + ReceiveRecords::new( + (route.query_id(), from, route.gate()), + self.record_streams.clone(), + ) } /// Dispatches the given request to the [`RequestHandler`] connected to this transport. @@ -114,13 +150,13 @@ impl HttpTransport { /// This implementation is a poor man's safety net and only works because we run /// one query at a time and don't use query identifiers. #[pin_project(PinnedDrop)] - struct ClearOnDrop { - transport: Arc, + struct ClearOnDrop { + transport: Arc>, #[pin] inner: F, } - impl Future for ClearOnDrop { + impl Future for ClearOnDrop { type Output = F::Output; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -129,7 +165,7 @@ impl HttpTransport { } #[pinned_drop] - impl PinnedDrop for ClearOnDrop { + impl PinnedDrop for ClearOnDrop { fn drop(self: Pin<&mut Self>) { self.transport.record_streams.clear(); } @@ -152,30 +188,88 @@ impl HttpTransport { r.await } } +} + +impl MpcHttpTransport { + #[must_use] + pub fn new( + http_runtime: IpaRuntime, + identity: HelperIdentity, + server_config: ServerConfig, + network_config: NetworkConfig, + clients: [MpcHelperClient; 3], + handler: Option>, + ) -> (Self, MpcHelperServer) { + let transport = Self::new_internal(http_runtime, identity, clients, handler); + let server = MpcHelperServer::new_mpc(&transport, server_config, network_config); + (transport, server) + } + + fn new_internal( + http_runtime: IpaRuntime, + identity: HelperIdentity, + clients: [MpcHelperClient; 3], + handler: Option>, + ) -> Self { + let mut id_clients = HashMap::new(); + let [c1, c2, c3] = clients; + id_clients.insert(HelperIdentity::ONE, c1); + id_clients.insert(HelperIdentity::TWO, c2); + id_clients.insert(HelperIdentity::THREE, c3); + Self { + inner_transport: Arc::new(HttpTransport { + http_runtime, + identity, + clients: id_clients, + handler, + record_streams: StreamCollection::default(), + }), + } + } - /// Connect an inbound stream of MPC record data. + /// Connect an inbound stream of record data. /// /// This is called by peer helpers via the HTTP server. 
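+    ///
+    /// A sketch of the call shape, assuming the HTTP handler has already
+    /// extracted the query id, gate, peer identity and request body (this is
+    /// exactly what the `step` handler above does):
+    ///
+    /// ```ignore
+    /// transport.receive_stream(query_id, gate, **from, body);
+    /// ```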
pub fn receive_stream( - self: Arc, + &self, query_id: QueryId, gate: Gate, from: HelperIdentity, stream: BodyStream, ) { - self.record_streams + self.inner_transport + .record_streams .add_stream((query_id, from, gate), stream); } + + /// Dispatches the given request to the [`RequestHandler`] connected to this transport. + /// + /// ## Errors + /// Returns an error, if handler rejects the request for any reason. + /// + /// ## Panics + /// This will panic if request handler hasn't been previously set for this transport. + pub async fn dispatch>( + &self, + req: R, + body: BodyStream, + ) -> Result + where + Option: From, + { + let t = Arc::clone(&self.inner_transport); + t.dispatch(req, body).await + } } #[async_trait] -impl Transport for Arc { +impl Transport for MpcHttpTransport { type Identity = HelperIdentity; - type RecordsStream = ReceiveRecords; + type RecordsStream = ReceiveRecords; type Error = Error; - fn identity(&self) -> HelperIdentity { - self.identity + fn identity(&self) -> Self::Identity { + self.inner_transport.identity } async fn send< @@ -185,7 +279,7 @@ impl Transport for Arc { R: RouteParams, >( &self, - dest: HelperIdentity, + dest: Self::Identity, route: R, data: D, ) -> Result<(), Error> @@ -193,67 +287,70 @@ impl Transport for Arc { Option: From, Option: From, { - let route_id = route.resource_identifier(); - match route_id { - RouteId::Records => { - // TODO(600): These fallible extractions aren't really necessary. - let query_id = >::from(route.query_id()) - .expect("query_id required when sending records"); - let step = - >::from(route.gate()).expect("step required when sending records"); - let resp_future = self.clients[dest].step(query_id, &step, data)?; - - // Use a dedicated HTTP runtime to poll this future for several reasons: - // - avoid blocking this task, if the current runtime is overloaded - // - use the runtime that enables IO (current runtime may not). 
- self.http_runtime - .spawn(resp_future.map_err(Into::into).and_then(resp_ok)) - .await?; - Ok(()) - } - RouteId::PrepareQuery => { - let req = serde_json::from_str(route.extra().borrow()).unwrap(); - self.clients[dest].prepare_query(req).await - } - evt @ (RouteId::QueryInput - | RouteId::ReceiveQuery - | RouteId::QueryStatus - | RouteId::CompleteQuery - | RouteId::KillQuery) => { - unimplemented!( - "attempting to send client-specific request {evt:?} to another helper" - ) - } - } + self.inner_transport.send(dest, route, data).await } fn receive>( &self, - from: HelperIdentity, + from: Self::Identity, route: R, ) -> Self::RecordsStream { - ReceiveRecords::new( - (route.query_id(), from, route.gate()), - self.record_streams.clone(), - ) + self.inner_transport.receive(from, &route) + } +} + +impl ShardHttpTransport { + #[must_use] + pub fn new( + http_runtime: IpaRuntime, + identity: ShardIndex, + server_config: ServerConfig, + network_config: NetworkConfig, + clients: HashMap>, + handler: Option>, + ) -> (Self, MpcHelperServer) { + let transport = Self::new_internal(http_runtime, identity, clients, handler); + let server = MpcHelperServer::new_shards(&transport, server_config, network_config); + (transport, server) + } + + fn new_internal( + http_runtime: IpaRuntime, + identity: ShardIndex, + clients: HashMap>, + handler: Option>, + ) -> Self { + let mut base_clients = HashMap::new(); + for (ix, client) in clients { + base_clients.insert(ix, client); + } + Self { + inner_transport: Arc::new(HttpTransport { + http_runtime, + identity, + clients: base_clients, + handler, + record_streams: StreamCollection::default(), + }), + } } } #[async_trait] -impl Transport for HttpShardTransport { +impl Transport for ShardHttpTransport { type Identity = ShardIndex; type RecordsStream = ReceiveRecords; - type Error = (); + type Error = Error; fn identity(&self) -> Self::Identity { - unimplemented!() + self.inner_transport.identity } async fn send( &self, - _dest: Self::Identity, - _route: R, - _data: D, + dest: Self::Identity, + route: R, + data: D, ) -> Result<(), Self::Error> where Option: From, @@ -263,15 +360,15 @@ impl Transport for HttpShardTransport { R: RouteParams, D: Stream> + Send + 'static, { - unimplemented!() + self.inner_transport.send(dest, route, data).await } fn receive>( &self, - _from: Self::Identity, - _route: R, + from: Self::Identity, + route: R, ) -> Self::RecordsStream { - unimplemented!() + self.inner_transport.receive(from, &route) } } @@ -319,18 +416,18 @@ mod tests { .build() .await; - transport.record_streams.add_stream( + transport.inner_transport.record_streams.add_stream( (QueryId, HelperIdentity::ONE, Gate::default()), BodyStream::empty(), ); - assert_eq!(1, transport.record_streams.len()); + assert_eq!(1, transport.inner_transport.record_streams.len()); Transport::clone_ref(&transport) .dispatch((RouteId::KillQuery, QueryId), BodyStream::empty()) .await .unwrap(); - assert!(transport.record_streams.is_empty()); + assert!(transport.inner_transport.record_streams.is_empty()); } #[tokio::test] @@ -344,10 +441,10 @@ mod tests { let body = BodyStream::from_bytes_stream(ReceiverStream::new(rx)); // Register the stream with the transport (normally called by step data HTTP API handler) - Arc::clone(&transport).receive_stream(QueryId, STEP.clone(), HelperIdentity::TWO, body); + transport.receive_stream(QueryId, STEP.clone(), HelperIdentity::TWO, body); // Request step data reception (normally called by protocol) - let mut stream = Arc::clone(&transport) + let mut stream = 
transport .receive(HelperIdentity::TWO, (QueryId, STEP.clone())) .into_bytes_stream(); @@ -396,19 +493,34 @@ mod tests { network_config, &identity, ); - let (transport, server) = HttpTransport::new( + let (transport, server) = MpcHttpTransport::new( IpaRuntime::current(), id, - server_config, + server_config.clone(), network_config.clone(), clients, Some(handler), ); + // TODO: Following is just temporary until Shard Transport is actually used. + let shard_clients_config = network_config.client.clone(); + let shard_server_config = server_config; + let shard_network_config = + NetworkConfig::new_shards(vec![], shard_clients_config); + let (shard_transport, _shard_server) = ShardHttpTransport::new( + IpaRuntime::current(), + ShardIndex::FIRST, + shard_server_config, + shard_network_config, + HashMap::new(), + None, + ); + // --- + server .start_on(&IpaRuntime::current(), Some(socket), ()) .await; - setup.connect(transport, HttpShardTransport) + setup.connect(transport, shard_transport) }, ), ) From 8d77817facf4d28ac5ab178d0e010b15bc437221 Mon Sep 17 00:00:00 2001 From: Christian Berkhoff Date: Mon, 21 Oct 2024 17:04:39 -0700 Subject: [PATCH 170/191] Clippy --- ipa-core/src/bin/helper.rs | 2 +- ipa-core/src/net/transport.rs | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/ipa-core/src/bin/helper.rs b/ipa-core/src/bin/helper.rs index e11074cd8..e309720ea 100644 --- a/ipa-core/src/bin/helper.rs +++ b/ipa-core/src/bin/helper.rs @@ -18,7 +18,7 @@ use ipa_core::{ error::BoxError, executor::IpaRuntime, helpers::HelperIdentity, - net::{ClientIdentity, HttpTransport, MpcHelperClient, MpcHttpTransport, ShardHttpTransport}, + net::{ClientIdentity, MpcHelperClient, MpcHttpTransport, ShardHttpTransport}, sharding::ShardIndex, AppConfig, AppSetup, NonZeroU32PowerOfTwo, }; diff --git a/ipa-core/src/net/transport.rs b/ipa-core/src/net/transport.rs index eecff2b92..28bf1aac8 100644 --- a/ipa-core/src/net/transport.rs +++ b/ipa-core/src/net/transport.rs @@ -45,7 +45,6 @@ pub struct MpcHttpTransport { /// A stub for HTTP transport implementation, suitable for serving shard-to-shard traffic #[derive(Clone)] pub struct ShardHttpTransport { - #[allow(dead_code)] inner_transport: Arc>, } From 45cadd29486140f79c4917594a910a703715f133 Mon Sep 17 00:00:00 2001 From: Thomas James Yurek Date: Mon, 21 Oct 2024 18:03:05 -0700 Subject: [PATCH 171/191] final changes --- ipa-core/src/report/hybrid.rs | 71 +++++++++++++++-------------------- ipa-core/src/report/ipa.rs | 45 +++++++--------------- 2 files changed, 43 insertions(+), 73 deletions(-) diff --git a/ipa-core/src/report/hybrid.rs b/ipa-core/src/report/hybrid.rs index 7afbee256..0956495ae 100644 --- a/ipa-core/src/report/hybrid.rs +++ b/ipa-core/src/report/hybrid.rs @@ -1,7 +1,6 @@ use std::{ collections::HashSet, convert::Infallible, - fmt::{Display, Formatter}, marker::PhantomData, ops::{Add, Deref}, }; @@ -31,24 +30,13 @@ use crate::{ // TODO(679): This needs to come from configuration. 
static HELPER_ORIGIN: &str = "github.com/private-attribution"; -#[derive(Debug)] -pub struct NonAsciiStringError { - input: String, -} - -impl Display for NonAsciiStringError { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "string contains non-ascii symbols: {}", self.input) - } -} - -impl std::error::Error for NonAsciiStringError {} +#[derive(Debug, thiserror::Error)] +#[error("string contains non-ascii symbols: {0}")] +pub struct NonAsciiStringError(String); impl From<&'_ str> for NonAsciiStringError { fn from(input: &str) -> Self { - Self { - input: input.to_string(), - } + Self(input.to_string()) } } @@ -153,17 +141,17 @@ where self.breakdown_key .serialize(GenericArray::from_mut_slice(&mut plaintext_btt[..])); + let pk = key_registry.public_key(key_id).ok_or(CryptError::NoSuchKey(key_id))?; + let (encap_key_mk, ciphertext_mk, tag_mk) = seal_in_place( - key_registry.public_key(key_id) - .ok_or(CryptError::NoSuchKey(key_id))?, + pk, plaintext_mk.as_mut(), &info.to_bytes(), rng, )?; let (encap_key_btt, ciphertext_btt, tag_btt) = seal_in_place( - key_registry.public_key(key_id) - .ok_or(CryptError::NoSuchKey(key_id))?, + pk, plaintext_btt.as_mut(), &info.to_bytes(), rng, @@ -299,29 +287,16 @@ where let mut ct_mk: GenericArray = *GenericArray::from_slice(self.mk_ciphertext()); - let plaintext_mk = open_in_place( - key_registry - .private_key(self.key_id()) - .ok_or(CryptError::NoSuchKey(self.key_id()))?, - self.encap_key_mk(), - &mut ct_mk, - &info.to_bytes(), - )?; + let sk = key_registry + .private_key(self.key_id()) + .ok_or(CryptError::NoSuchKey(self.key_id()))?; + let plaintext_mk = open_in_place(sk, self.encap_key_mk(), &mut ct_mk, &info.to_bytes())?; let mut ct_btt: GenericArray> = GenericArray::from_slice(self.btt_ciphertext()).clone(); - let plaintext_btt = open_in_place( - key_registry - .private_key(self.key_id()) - .ok_or(CryptError::NoSuchKey(self.key_id()))?, - self.encap_key_btt(), - &mut ct_btt, - &info.to_bytes(), - )?; + let plaintext_btt = open_in_place(sk, self.encap_key_btt(), &mut ct_btt, &info.to_bytes())?; Ok(HybridImpressionReport:: { - //match_key: Replicated::::deserialize(GenericArray::from_slice(plaintext_mk)) - // .unwrap_infallible(), match_key: Replicated::::deserialize_infallible(GenericArray::from_slice( plaintext_mk, )), @@ -660,8 +635,11 @@ mod test { ) .unwrap(); assert_eq!(hybrid_impression_report, hybrid_impression_report2); + } - let hybrid_report3 = HybridImpressionReport::::deserialize(GenericArray::from_slice( + #[test] + fn deserialzation_from_constant() { + let hybrid_report = HybridImpressionReport::::deserialize(GenericArray::from_slice( &hex::decode("4123a6e38ef1d6d9785c948797cb744d38f4").unwrap(), )) .unwrap(); @@ -676,12 +654,23 @@ mod test { .unwrap(); assert_eq!( - hybrid_report3, + hybrid_report, HybridImpressionReport:: { match_key, breakdown_key } ); + + let mut hybrid_impression_report_bytes = + [0u8; as Serializable>::Size::USIZE]; + hybrid_report.serialize(GenericArray::from_mut_slice( + &mut hybrid_impression_report_bytes[..], + )); + + assert_eq!( + hybrid_impression_report_bytes.to_vec(), + hex::decode("4123a6e38ef1d6d9785c948797cb744d38f4").unwrap() + ); } #[test] @@ -714,6 +703,6 @@ mod test { fn non_ascii_string() { let non_ascii_string = "☃️☃️☃️"; let err = HybridImpressionInfo::new(0, non_ascii_string).unwrap_err(); - assert!(matches!(err, NonAsciiStringError { input: _ })); + assert!(matches!(err, NonAsciiStringError(_))); } } diff --git a/ipa-core/src/report/ipa.rs b/ipa-core/src/report/ipa.rs 
index 6f8d6e865..cfaf4349f 100644 --- a/ipa-core/src/report/ipa.rs +++ b/ipa-core/src/report/ipa.rs @@ -407,25 +407,14 @@ where let mut ct_mk: GenericArray = *GenericArray::from_slice(self.mk_ciphertext()); - let plaintext_mk = open_in_place( - key_registry - .private_key(self.key_id()) - .ok_or(CryptError::NoSuchKey(self.key_id()))?, - self.encap_key_mk(), - &mut ct_mk, - &info.to_bytes(), - )?; + let sk = key_registry + .private_key(self.key_id()) + .ok_or(CryptError::NoSuchKey(self.key_id()))?; + let plaintext_mk = open_in_place(sk, self.encap_key_mk(), &mut ct_mk, &info.to_bytes())?; let mut ct_btt: GenericArray> = GenericArray::from_slice(self.btt_ciphertext()).clone(); - let plaintext_btt = open_in_place( - key_registry - .private_key(self.key_id()) - .ok_or(CryptError::NoSuchKey(self.key_id()))?, - self.encap_key_btt(), - &mut ct_btt, - &info.to_bytes(), - )?; + let plaintext_btt = open_in_place(sk, self.encap_key_btt(), &mut ct_btt, &info.to_bytes())?; Ok(OprfReport:: { timestamp: Replicated::::deserialize(GenericArray::from_slice( @@ -591,23 +580,15 @@ where ..(Self::TV_OFFSET + as Serializable>::Size::USIZE)], )); - let (encap_key_mk, ciphertext_mk, tag_mk) = seal_in_place( - key_registry - .public_key(key_id) - .ok_or(CryptError::NoSuchKey(key_id))?, - plaintext_mk.as_mut(), - &info.to_bytes(), - rng, - )?; + let pk = key_registry + .public_key(key_id) + .ok_or(CryptError::NoSuchKey(key_id))?; - let (encap_key_btt, ciphertext_btt, tag_btt) = seal_in_place( - key_registry - .public_key(key_id) - .ok_or(CryptError::NoSuchKey(key_id))?, - plaintext_btt.as_mut(), - &info.to_bytes(), - rng, - )?; + let (encap_key_mk, ciphertext_mk, tag_mk) = + seal_in_place(pk, plaintext_mk.as_mut(), &info.to_bytes(), rng)?; + + let (encap_key_btt, ciphertext_btt, tag_btt) = + seal_in_place(pk, plaintext_btt.as_mut(), &info.to_bytes(), rng)?; out.put_slice(&encap_key_mk.to_bytes()); out.put_slice(ciphertext_mk); From 008cee361ed1cbdd614df7542a0e2090a5ea846b Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Tue, 22 Oct 2024 09:59:07 -0700 Subject: [PATCH 172/191] Feedback --- ipa-metrics/src/context.rs | 14 +++++++------- ipa-metrics/src/key.rs | 18 ++++++------------ ipa-metrics/src/label.rs | 14 ++++++++------ ipa-metrics/src/lib.rs | 26 +++++++++++++++++++++++++- ipa-metrics/src/store.rs | 7 ++++--- 5 files changed, 50 insertions(+), 29 deletions(-) diff --git a/ipa-metrics/src/context.rs b/ipa-metrics/src/context.rs index 2020f14fd..938d4560b 100644 --- a/ipa-metrics/src/context.rs +++ b/ipa-metrics/src/context.rs @@ -84,21 +84,21 @@ impl MetricsContext { &mut self.store } - fn is_connected(&self) -> bool { - self.tx.is_some() - } - fn flush(&mut self) { if self.store.is_empty() { return; } - if self.is_connected() { + if let Some(tx) = self.tx.as_ref() { let store = mem::take(&mut self.store); - match self.tx.as_ref().unwrap().send(store) { + match tx.send(store) { Ok(()) => {} Err(e) => { - tracing::warn!("MetricsContext is not connected: {e}"); + // Note that the store is dropped at this point. + // If it becomes a problem with collector threads disconnecting + // somewhat randomly, we can keep the old store around + // and clone it when sending. 
+                    tracing::warn!("MetricsContext is disconnected from the collector: {e}");
                 }
             }
         } else {
diff --git a/ipa-metrics/src/key.rs b/ipa-metrics/src/key.rs
index eadb5e392..8f01ea2f4 100644
--- a/ipa-metrics/src/key.rs
+++ b/ipa-metrics/src/key.rs
@@ -112,7 +112,7 @@ impl<'lv, const LABELS: usize> Name<'lv, LABELS> {
 /// Same as [`Name`], but intended for internal use. This is an owned
 /// version of it, that does not borrow anything from outside.
 /// This is the key inside metric stores which are simple hashmaps.
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Eq)]
 pub struct OwnedName {
     key: &'static str,
     labels: [Option; 5],
@@ -140,7 +140,9 @@ impl OwnedName {
 
 impl Hash for Name<'_, LABELS> {
     fn hash(&self, state: &mut H) {
-        state.write(self.key.as_bytes());
+        Hash::hash(&self.key, state);
+        // to be consistent with `OwnedName` hashing, we need to
+        // serialize labels without slice length prefix.
         for label in &self.labels {
             label.hash(state);
         }
@@ -186,20 +188,13 @@ impl<'a, const LABELS: usize> PartialEq> for OwnedName {
 
 impl PartialEq for OwnedName {
     fn eq(&self, other: &OwnedName) -> bool {
-        self.key == other.key
-            && iter::zip(&self.labels, &other.labels).all(|(a, b)| match (a, b) {
-                (Some(a), Some(b)) => a == b,
-                (None, None) => true,
-                _ => false,
-            })
+        self.key == other.key && self.labels.eq(&other.labels)
     }
 }
 
-impl Eq for OwnedName {}
-
 impl Hash for OwnedName {
     fn hash(&self, state: &mut H) {
-        state.write(self.key.as_bytes());
+        Hash::hash(self.key, state);
         for label in self.labels.iter().flatten() {
             label.hash(state);
         }
@@ -216,7 +211,6 @@ pub fn compute_hash(value: V) -> u64 {
 
 #[cfg(test)]
 mod tests {
-
     use crate::{
         key::{compute_hash, Name},
         label::Label,
diff --git a/ipa-metrics/src/label.rs b/ipa-metrics/src/label.rs
index 9ee414e26..b4e37a704 100644
--- a/ipa-metrics/src/label.rs
+++ b/ipa-metrics/src/label.rs
@@ -3,16 +3,17 @@ use std::{
     hash::{Hash, Hasher},
 };
 
-pub use Value as LabelValue;
-
 pub const MAX_LABELS: usize = 5;
 
 /// Dimension value (or label value) must be sendable to another thread
 /// and there must be a way to show it
-pub trait Value: Display + Send {
+pub trait LabelValue: Display + Send {
     /// Creates a unique hash for this value.
     /// It is easy to create collisions, so better avoid them,
     /// by assigning a unique integer to each value
+    ///
+    /// Note that this value is used for uniqueness check inside
+    /// metric stores
     fn hash(&self) -> u64;
 
     /// Creates an owned copy of this value. Dynamic dispatch
@@ -31,10 +32,9 @@ impl LabelValue for u32 {
     }
 }
 
-#[derive()]
 pub struct Label<'lv> {
     pub name: &'static str,
-    pub val: &'lv dyn Value,
+    pub val: &'lv dyn LabelValue,
 }
 
 impl Label<'_> {
@@ -76,7 +76,7 @@ impl PartialEq for Label<'_> {
 /// inside metric hashmaps as they need to own the keys.
 pub struct OwnedLabel {
     pub name: &'static str,
-    pub val: Box,
+    pub val: Box,
 }
 
 impl Clone for OwnedLabel {
@@ -118,6 +118,8 @@ impl PartialEq for OwnedLabel {
     }
 }
 
+impl Eq for OwnedLabel {}
+
 #[cfg(test)]
 mod tests {
 
diff --git a/ipa-metrics/src/lib.rs b/ipa-metrics/src/lib.rs
index 2ee9d5be0..f84f8dc1c 100644
--- a/ipa-metrics/src/lib.rs
+++ b/ipa-metrics/src/lib.rs
@@ -32,6 +32,30 @@ pub use producer::Producer as MetricsProducer;
 #[cfg(not(feature = "partitions"))]
 pub use store::Store as MetricsStore;
 
+/// Creates metric infrastructure that is ready to use
+/// in the application code. It consists of a triple of
+/// [`MetricsCollector`], [`MetricsProducer`], and
+/// [`MetricsCollectorController`].
+///
+/// Collector is used in the centralized place (a dedicated thread)
+/// to collect metrics coming from thread local stores.
+///
+/// Metric producer must be installed on every thread that is used
+/// to emit telemetry, and it connects that thread to the collector.
+///
+/// Controller provides a command API for interacting with the collector.
+/// A thread that owns the controller can request the current snapshot.
+/// For more information about the API, see [`Command`].
+///
+/// ## Example
+/// ```rust
+/// let (collector, producer, controller) = ipa_metrics::install();
+/// ```
+///
+/// [`MetricsCollector`]: crate::MetricsCollector
+/// [`MetricsProducer`]: crate::MetricsProducer
+/// [`MetricsCollectorController`]: crate::MetricsCollectorController
+/// [`Command`]: crate::ControllerCommand
 #[must_use]
 pub fn install() -> (
     MetricsCollector,
@@ -51,7 +75,7 @@ pub fn install() -> (
     )
 }
 
-/// Same as [`installer]` but spawns a new thread to run the collector.
+/// Same as [`install`] but spawns a new thread to run the collector.
 ///
 /// ## Errors
 /// if thread cannot be started
diff --git a/ipa-metrics/src/store.rs b/ipa-metrics/src/store.rs
index 58160ed4c..501b875a2 100644
--- a/ipa-metrics/src/store.rs
+++ b/ipa-metrics/src/store.rs
@@ -63,9 +63,10 @@ impl Store {
         }
     }
 
-    /// Returns the value for the specified metric taking into account
-    /// its dimensionality. That is (foo, dim1 = 1, dim2 = 2) will be
-    /// different from (foo, dim1 = 1).
+    /// Returns the value for the specified metric, limited by any specified dimensions,
+    /// but not by any unspecified dimensions. If metric foo has dimensions dim1 and dim2,
+    /// a query for (foo, dim1 = 1) will sum the counter values having dim1 = 1
+    /// and any value of dim2.
     /// The cost of this operation is `O(N*M)` where `N` - number of unique metrics
     /// registered in this store and `M` number of dimensions.
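    /// As a worked example of these semantics (illustrative values only, not
    /// taken from a real store): with counters (foo, dim1 = 1, dim2 = 2) = 3
    /// and (foo, dim1 = 1, dim2 = 5) = 4, a query for (foo, dim1 = 1) returns
    /// 3 + 4 = 7, while a query for (foo, dim1 = 1, dim2 = 2) returns only 3.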
/// From e68e014fc9cf2eaaff4f41177cf95ebfcdb00559 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Tue, 22 Oct 2024 13:45:52 -0700 Subject: [PATCH 173/191] Remove some unneeded dead_code waivers; fix a test --- ipa-core/src/helpers/buffers/unordered_receiver.rs | 1 - ipa-core/src/protocol/dp/mod.rs | 2 +- .../src/protocol/ipa_prf/oprf_padding/insecure.rs | 13 +++++++------ ipa-core/src/protocol/ipa_prf/shuffle/base.rs | 2 -- ipa-core/src/query/state.rs | 1 - 5 files changed, 8 insertions(+), 11 deletions(-) diff --git a/ipa-core/src/helpers/buffers/unordered_receiver.rs b/ipa-core/src/helpers/buffers/unordered_receiver.rs index 92cfbf2e1..3377995cf 100644 --- a/ipa-core/src/helpers/buffers/unordered_receiver.rs +++ b/ipa-core/src/helpers/buffers/unordered_receiver.rs @@ -295,7 +295,6 @@ where inner: Arc>>, } -#[allow(dead_code)] impl UnorderedReceiver where S: Stream + Send, diff --git a/ipa-core/src/protocol/dp/mod.rs b/ipa-core/src/protocol/dp/mod.rs index b2d8b55e1..856073b77 100644 --- a/ipa-core/src/protocol/dp/mod.rs +++ b/ipa-core/src/protocol/dp/mod.rs @@ -553,7 +553,7 @@ fn delta_constraint(num_bernoulli: u32, noise_params: &NoiseParams) -> bool { lhs >= rhs } /// error of mechanism in Thm 1 -#[allow(dead_code)] +#[cfg(all(test, unit_test))] fn error(num_bernoulli: u32, noise_params: &NoiseParams) -> f64 { noise_params.dimensions * noise_params.quantization_scale.powi(2) diff --git a/ipa-core/src/protocol/ipa_prf/oprf_padding/insecure.rs b/ipa-core/src/protocol/ipa_prf/oprf_padding/insecure.rs index b7268900d..d37edfecb 100644 --- a/ipa-core/src/protocol/ipa_prf/oprf_padding/insecure.rs +++ b/ipa-core/src/protocol/ipa_prf/oprf_padding/insecure.rs @@ -1,5 +1,3 @@ -#![allow(dead_code)] - use std::f64::consts::E; use rand::distributions::{BernoulliError, Distribution}; @@ -77,6 +75,7 @@ impl Dp { }) } + #[cfg(all(test, unit_test))] fn apply(&self, mut input: I, rng: &mut R) where R: RngCore + CryptoRng, @@ -521,16 +520,18 @@ mod test { println!("A sample value equal to {sample} occurred {count} time(s)",); } } + + #[test] fn test_oprf_padding_dp_constructor() { let mut actual = OPRFPaddingDp::new(-1.0, 1e-6, 10); // (epsilon, delta, sensitivity) let mut expected = Err(Error::BadEpsilon(-1.0)); - assert_eq!(expected, Ok(actual)); + assert_eq!(expected, actual); actual = OPRFPaddingDp::new(1.0, -1e-6, 10); // (epsilon, delta, sensitivity) expected = Err(Error::BadDelta(-1e-6)); - assert_eq!(expected, Ok(actual)); - actual = OPRFPaddingDp::new(1.0, -1e-6, 1_000_001); // (epsilon, delta, sensitivity) + assert_eq!(expected, actual); + actual = OPRFPaddingDp::new(1.0, 1e-6, 1_000_001); // (epsilon, delta, sensitivity) expected = Err(Error::BadSensitivity(1_000_001)); - assert_eq!(expected, Ok(actual)); + assert_eq!(expected, actual); } #[test] diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/base.rs b/ipa-core/src/protocol/ipa_prf/shuffle/base.rs index d999db229..83343b739 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/base.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/base.rs @@ -68,7 +68,6 @@ where } } -#[allow(dead_code)] /// This struct stores some intermediate messages during the shuffle. /// In a maliciously secure shuffle, /// these messages need to be checked for consistency across helpers. 
@@ -79,7 +78,6 @@ pub struct IntermediateShuffleMessages { x2_or_y2: Option>, } -#[allow(dead_code)] impl IntermediateShuffleMessages { /// When `IntermediateShuffleMessages` is initialized correctly, /// this function returns `x1` when `Role = H1` diff --git a/ipa-core/src/query/state.rs b/ipa-core/src/query/state.rs index 9d42a0439..460296022 100644 --- a/ipa-core/src/query/state.rs +++ b/ipa-core/src/query/state.rs @@ -19,7 +19,6 @@ use crate::{ /// The status of query processing #[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] -#[allow(dead_code)] pub enum QueryStatus { /// Only query running on the coordinator helper can be in this state. Means that coordinator /// sent out requests to other helpers and asked them to assume a given role for this query. From 7331a1cd442a36f1bcfea607bca70bc7346abb95 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Tue, 22 Oct 2024 19:00:09 -0700 Subject: [PATCH 174/191] Fix `Label` Hash implementation --- ipa-metrics/src/label.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ipa-metrics/src/label.rs b/ipa-metrics/src/label.rs index b4e37a704..27da2b116 100644 --- a/ipa-metrics/src/label.rs +++ b/ipa-metrics/src/label.rs @@ -58,8 +58,8 @@ impl Debug for Label<'_> { impl Hash for Label<'_> { fn hash(&self, state: &mut H) { - state.write(self.name.as_bytes()); - state.write_u64(self.val.hash()); + Hash::hash(&self.name, state); + Hash::hash(&self.val.hash(), state); } } From 3cea96da24ce9476a54fe3e6a27efdafad89cb89 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Tue, 22 Oct 2024 19:38:43 -0700 Subject: [PATCH 175/191] Newline in cargo.toml --- ipa-metrics-tracing/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ipa-metrics-tracing/Cargo.toml b/ipa-metrics-tracing/Cargo.toml index ac7314c19..061bfe4ff 100644 --- a/ipa-metrics-tracing/Cargo.toml +++ b/ipa-metrics-tracing/Cargo.toml @@ -7,4 +7,4 @@ edition = "2021" # requires partitions feature because without it, it does not make sense to use ipa-metrics = { version = "*", path = "../ipa-metrics", features = ["partitions"] } tracing = "0.1" -tracing-subscriber = "0.3" \ No newline at end of file +tracing-subscriber = "0.3" From d358ec0f78cc5567668c31819283c1f7f49bc990 Mon Sep 17 00:00:00 2001 From: Christian Berkhoff Date: Wed, 23 Oct 2024 11:53:11 -0700 Subject: [PATCH 176/191] Addressing comments --- ipa-core/src/bin/helper.rs | 5 +- ipa-core/src/helpers/transport/mod.rs | 19 ++++++++ ipa-core/src/net/test.rs | 2 +- ipa-core/src/net/transport.rs | 67 +++++++++------------------ 4 files changed, 43 insertions(+), 50 deletions(-) diff --git a/ipa-core/src/bin/helper.rs b/ipa-core/src/bin/helper.rs index e309720ea..7c8190c20 100644 --- a/ipa-core/src/bin/helper.rs +++ b/ipa-core/src/bin/helper.rs @@ -1,5 +1,4 @@ use std::{ - collections::HashMap, fs, io::BufReader, net::TcpListener, @@ -177,7 +176,7 @@ async fn server(args: ServerArgs) -> Result<(), BoxError> { my_identity, server_config, network_config, - clients, + &clients, Some(handler), ); @@ -188,7 +187,7 @@ async fn server(args: ServerArgs) -> Result<(), BoxError> { ShardIndex::FIRST, shard_server_config, shard_network_config, - HashMap::new(), + vec![], None, ); // --- diff --git a/ipa-core/src/helpers/transport/mod.rs b/ipa-core/src/helpers/transport/mod.rs index de20a7d90..f72814614 100644 --- a/ipa-core/src/helpers/transport/mod.rs +++ b/ipa-core/src/helpers/transport/mod.rs @@ -50,6 +50,9 @@ pub trait Identity: /// # Errors /// If there where any problems parsing the 
identity. fn from_str(s: &str) -> Result; + + /// Returns a 0-based index suitable to index Vec or other containers. + fn as_index(&self) -> usize; } impl Identity for ShardIndex { @@ -64,6 +67,10 @@ impl Identity for ShardIndex { }) .map(ShardIndex::from) } + + fn as_index(&self) -> usize { + usize::from(*self) + } } impl Identity for HelperIdentity { fn as_str(&self) -> Cow<'static, str> { @@ -85,6 +92,10 @@ impl Identity for HelperIdentity { ))), } } + + fn as_index(&self) -> usize { + usize::from(self.id) - 1 + } } /// Role is an identifier of helper peer, only valid within a given query. For every query, there @@ -104,6 +115,14 @@ impl Identity for Role { ))), } } + + fn as_index(&self) -> usize { + match self { + Self::H1 => 0, + Self::H2 => 1, + Self::H3 => 2, + } + } } pub trait ResourceIdentifier: Sized {} diff --git a/ipa-core/src/net/test.rs b/ipa-core/src/net/test.rs index fa6e95ed8..e62bccce6 100644 --- a/ipa-core/src/net/test.rs +++ b/ipa-core/src/net/test.rs @@ -301,7 +301,7 @@ impl TestServerBuilder { HelperIdentity::ONE, server_config, network_config.clone(), - clients, + &clients, handler, ); let (addr, handle) = server diff --git a/ipa-core/src/net/transport.rs b/ipa-core/src/net/transport.rs index 28bf1aac8..a0fcd92a3 100644 --- a/ipa-core/src/net/transport.rs +++ b/ipa-core/src/net/transport.rs @@ -1,6 +1,5 @@ use std::{ borrow::Borrow, - collections::HashMap, future::Future, pin::Pin, task::{Context, Poll}, @@ -19,7 +18,7 @@ use crate::{ routing::{Addr, RouteId}, ApiError, BodyStream, HandlerRef, HelperIdentity, HelperResponse, NoQueryId, NoResourceIdentifier, NoStep, QueryIdBinding, ReceiveRecords, RequestHandler, RouteParams, - StepBinding, StreamCollection, Transport, + StepBinding, StreamCollection, Transport, TransportIdentity, }, net::{client::MpcHelperClient, error::Error, MpcHelperServer}, protocol::{Gate, QueryId}, @@ -31,7 +30,7 @@ use crate::{ pub struct HttpTransport { http_runtime: IpaRuntime, identity: F::Identity, - clients: HashMap>, + clients: Vec>, record_streams: StreamCollection, handler: Option>, } @@ -85,6 +84,7 @@ impl HttpTransport { Option: From, { let route_id = route.resource_identifier(); + let client_ix = dest.as_index(); match route_id { RouteId::Records => { // TODO(600): These fallible extractions aren't really necessary. @@ -92,7 +92,7 @@ impl HttpTransport { .expect("query_id required when sending records"); let step = >::from(route.gate()).expect("step required when sending records"); - let resp_future = self.clients[&dest].step(query_id, &step, data)?; + let resp_future = self.clients[client_ix].step(query_id, &step, data)?; // Use a dedicated HTTP runtime to poll this future for several reasons: // - avoid blocking this task, if the current runtime is overloaded // - use the runtime that enables IO (current runtime may not). 
@@ -103,7 +103,7 @@ impl HttpTransport { } RouteId::PrepareQuery => { let req = serde_json::from_str(route.extra().borrow()).unwrap(); - self.clients[&dest].prepare_query(req).await + self.clients[client_ix].prepare_query(req).await } evt @ (RouteId::QueryInput | RouteId::ReceiveQuery @@ -196,34 +196,21 @@ impl MpcHttpTransport { identity: HelperIdentity, server_config: ServerConfig, network_config: NetworkConfig, - clients: [MpcHelperClient; 3], + clients: &[MpcHelperClient; 3], handler: Option>, ) -> (Self, MpcHelperServer) { - let transport = Self::new_internal(http_runtime, identity, clients, handler); - let server = MpcHelperServer::new_mpc(&transport, server_config, network_config); - (transport, server) - } - - fn new_internal( - http_runtime: IpaRuntime, - identity: HelperIdentity, - clients: [MpcHelperClient; 3], - handler: Option>, - ) -> Self { - let mut id_clients = HashMap::new(); - let [c1, c2, c3] = clients; - id_clients.insert(HelperIdentity::ONE, c1); - id_clients.insert(HelperIdentity::TWO, c2); - id_clients.insert(HelperIdentity::THREE, c3); - Self { + let transport = Self { inner_transport: Arc::new(HttpTransport { http_runtime, identity, - clients: id_clients, + clients: clients.to_vec(), handler, record_streams: StreamCollection::default(), }), - } + }; + + let server = MpcHelperServer::new_mpc(&transport, server_config, network_config); + (transport, server) } /// Connect an inbound stream of record data. @@ -305,33 +292,21 @@ impl ShardHttpTransport { identity: ShardIndex, server_config: ServerConfig, network_config: NetworkConfig, - clients: HashMap>, + clients: Vec>, handler: Option>, ) -> (Self, MpcHelperServer) { - let transport = Self::new_internal(http_runtime, identity, clients, handler); - let server = MpcHelperServer::new_shards(&transport, server_config, network_config); - (transport, server) - } - - fn new_internal( - http_runtime: IpaRuntime, - identity: ShardIndex, - clients: HashMap>, - handler: Option>, - ) -> Self { - let mut base_clients = HashMap::new(); - for (ix, client) in clients { - base_clients.insert(ix, client); - } - Self { + let transport = Self { inner_transport: Arc::new(HttpTransport { http_runtime, identity, - clients: base_clients, + clients, handler, record_streams: StreamCollection::default(), }), - } + }; + + let server = MpcHelperServer::new_shards(&transport, server_config, network_config); + (transport, server) } } @@ -497,7 +472,7 @@ mod tests { id, server_config.clone(), network_config.clone(), - clients, + &clients, Some(handler), ); // TODO: Following is just temporary until Shard Transport is actually used. 
@@ -510,7 +485,7 @@ mod tests { ShardIndex::FIRST, shard_server_config, shard_network_config, - HashMap::new(), + vec![], None, ); // --- From 58d8340ae7a4ef66e0aacf10f4989e12cb27ffed Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Tue, 22 Oct 2024 20:41:17 -0700 Subject: [PATCH 177/191] Fix compact gate build errors This fixes #1301 and cleans up assertions --- ipa-core/Cargo.toml | 10 ++++++---- ipa-core/build.rs | 2 +- ipa-core/src/protocol/ipa_prf/mod.rs | 2 +- ipa-core/src/report/hybrid.rs | 9 ++------- scripts/coverage-ci | 3 +-- 5 files changed, 11 insertions(+), 15 deletions(-) diff --git a/ipa-core/Cargo.toml b/ipa-core/Cargo.toml index 0081a0a50..3faa2b95e 100644 --- a/ipa-core/Cargo.toml +++ b/ipa-core/Cargo.toml @@ -16,11 +16,14 @@ default = [ "stall-detection", "aggregate-circuit", "ipa-prf", - "ipa-step/string-step", + "descriptive-gate", ] cli = ["comfy-table", "clap"] -# Enabling compact gates disables any tests that rely on descriptive gates. -compact-gate = ["ipa-step/string-step"] +# Enable compact gate optimization +compact-gate = [] +# mutually exclusive with compact-gate and disables compact gate optimization. +# It is enabled by default +descriptive-gate = ["ipa-step/string-step"] disable-metrics = [] # TODO move web-app to a separate crate. It adds a lot of build time to people who mostly write protocols # TODO Consider moving out benches as well @@ -82,7 +85,6 @@ ipa-step = { version = "*", path = "../ipa-step" } ipa-step-derive = { version = "*", path = "../ipa-step-derive" } aes = "0.8.3" -assertions = "0.1.0" async-trait = "0.1.79" async-scoped = { version = "0.9.0", features = ["use-tokio"], optional = true } axum = { version = "0.7.5", optional = true, features = ["http2", "macros"] } diff --git a/ipa-core/build.rs b/ipa-core/build.rs index 768dc5040..ce1987c72 100644 --- a/ipa-core/build.rs +++ b/ipa-core/build.rs @@ -44,7 +44,7 @@ fn main() { // https://docs.rs/tectonic_cfg_support/latest/tectonic_cfg_support/struct.TargetConfiguration.html cfg_aliases! 
{ compact_gate: { feature = "compact-gate" }, - descriptive_gate: { not(compact_gate) }, + descriptive_gate: { all(not(feature = "compact-gate"), feature = "descriptive-gate") }, unit_test: { all(not(feature = "shuttle"), feature = "in-memory-infra", descriptive_gate) }, web_test: { all(not(feature = "shuttle"), feature = "real-world-infra") }, } diff --git a/ipa-core/src/protocol/ipa_prf/mod.rs b/ipa-core/src/protocol/ipa_prf/mod.rs index 9626731e4..3c6cddb61 100644 --- a/ipa-core/src/protocol/ipa_prf/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/mod.rs @@ -733,7 +733,7 @@ pub mod tests { } } -#[cfg(all(test, all(feature = "compact-gate", feature = "in-memory-infra")))] +#[cfg(all(test, all(compact_gate, feature = "in-memory-infra")))] mod compact_gate_tests { use ipa_step::{CompactStep, StepNarrow}; diff --git a/ipa-core/src/report/hybrid.rs b/ipa-core/src/report/hybrid.rs index 0956495ae..81f1f1b6d 100644 --- a/ipa-core/src/report/hybrid.rs +++ b/ipa-core/src/report/hybrid.rs @@ -5,7 +5,6 @@ use std::{ ops::{Add, Deref}, }; -use assertions::const_assert; use bytes::{BufMut, Bytes}; use generic_array::{ArrayLength, GenericArray}; use hpke::Serializable as _; @@ -13,6 +12,7 @@ use rand_core::{CryptoRng, RngCore}; use typenum::{Sum, Unsigned, U16}; use crate::{ + const_assert_eq, error::{BoxError, Error}, ff::{boolean_array::BA64, Serializable}, hpke::{ @@ -407,14 +407,9 @@ impl UniqueBytes for EncryptedHybridReport { } impl UniqueTag { - fn _compile_check() { - // This will vaild at compile time if TAG_SIZE doesn't match U16 - // the macro expansion needs to be wrapped in a function - const_assert!(TAG_SIZE == 16); - } - // Function to attempt to create a UniqueTag from a UniqueBytes implementor pub fn from_unique_bytes(item: &T) -> Self { + const_assert_eq!(16, TAG_SIZE); UniqueTag { bytes: item.unique_bytes(), } diff --git a/scripts/coverage-ci b/scripts/coverage-ci index c652e9f65..6f79c7599 100755 --- a/scripts/coverage-ci +++ b/scripts/coverage-ci @@ -14,8 +14,7 @@ cargo test --features "cli test-fixture relaxed-dp" # Provide code coverage stats for ipa-metrics crate with partitions enabled cargo test -p ipa-metrics --features "partitions" -# descriptive-gate does not require a feature flag. 
-for gate in "compact-gate" ""; do +for gate in "compact-gate" "descriptive-gate"; do cargo test --no-default-features --features "cli web-app real-world-infra test-fixture $gate" done From ab75fdee7305cfca112c7710d94cd4957374c173 Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Thu, 24 Oct 2024 09:12:39 -0700 Subject: [PATCH 178/191] Indistinguishable reports (#1367) * add IndistinguishableHybridReport struct * implement traits on both Conversion and Impression structs; add tests * convert decrypted hybrid reports into indistinguishable reports * PR feedback and comments --- ipa-core/src/query/runner/hybrid.rs | 11 +- ipa-core/src/report/hybrid.rs | 156 +++++++++++++++++++++++++++- 2 files changed, 162 insertions(+), 5 deletions(-) diff --git a/ipa-core/src/query/runner/hybrid.rs b/ipa-core/src/query/runner/hybrid.rs index 7154c066b..bdc5d9791 100644 --- a/ipa-core/src/query/runner/hybrid.rs +++ b/ipa-core/src/query/runner/hybrid.rs @@ -1,4 +1,4 @@ -use std::{marker::PhantomData, sync::Arc}; +use std::{convert::Into, marker::PhantomData, sync::Arc}; use futures::{stream::iter, StreamExt, TryStreamExt}; @@ -12,7 +12,9 @@ use crate::{ hpke::PrivateKeyRegistry, protocol::{context::ShardedContext, hybrid::step::HybridStep, step::ProtocolStep::Hybrid}, query::runner::reshard_tag::reshard_aad, - report::hybrid::{EncryptedHybridReport, UniqueTag, UniqueTagValidator}, + report::hybrid::{ + EncryptedHybridReport, IndistinguishableHybridReport, UniqueTag, UniqueTagValidator, + }, secret_sharing::{replicated::semi_honest::AdditiveShare as ReplicatedShare, SharedValue}, }; @@ -73,7 +75,7 @@ where }) .try_flatten() .take(sz); - let (_decrypted_reports, resharded_tags) = reshard_aad( + let (decrypted_reports, resharded_tags) = reshard_aad( ctx.narrow(&HybridStep::ReshardByTag), stream, |ctx, _, tag| tag.shard_picker(ctx.shard_count()), @@ -87,6 +89,9 @@ where .check_duplicates(&resharded_tags) .unwrap(); + let _indistinguishable_reports: Vec> = + decrypted_reports.into_iter().map(Into::into).collect(); + unimplemented!("query::runnner::HybridQuery.execute is not fully implemented") } } diff --git a/ipa-core/src/report/hybrid.rs b/ipa-core/src/report/hybrid.rs index 0956495ae..b13abfe90 100644 --- a/ipa-core/src/report/hybrid.rs +++ b/ipa-core/src/report/hybrid.rs @@ -1,3 +1,32 @@ +//! Provides report types which are aggregated by the Hybrid protocol +//! +//! The `IndistinguishableHybridReport` is the primary data type which each helpers uses +//! to aggreate in the Hybrid protocol. +//! +//! From each Helper's POV, the Report Collector POSTs a length delimited byte +//! stream, which is then processed as follows: +//! +//! `BodyStream` → `EncryptedHybridReport` → `HybridReport` → `IndistinguishableHybridReport` +//! +//! The difference between a `HybridReport` and a `IndistinguishableHybridReport` is that a +//! a `HybridReport` is an `enum` with two possible options: `Impression` and `Conversion`. +//! These two options are implemented as `HybridImpressionReport` and `HybridConversionReport`. +//! A `IndistinguishableHybridReport` contains the union of the fields across +//! `HybridImpressionReport` and `HybridConversionReport`. Those fields are secret sharings, +//! which allows for building a collection of `IndistinguishableHybridReport` which carry +//! the information of the underlying `HybridImpressionReport` and `HybridConversionReport` +//! (and secret sharings of zero in the fields unique to each report type) without the +//! 
ability to infer if a given report is a `HybridImpressionReport`
+//! or a `HybridConversionReport`.
+
+//! Note: immediately following conversion of a `HybridReport` into an
+//! `IndistinguishableHybridReport`, each helper will know which type it was built from,
+//! both from the position in the collection as well as the fact that both replicated
+//! secret shares for one or more fields are zero. A shuffle is required to delink
+//! an `IndistinguishableHybridReport`'s position in a collection, which also rerandomizes
+//! all secret sharings (including the sharings of zero), making the collection of reports
+//! cryptographically indistinguishable.
+
 use std::{
     collections::HashSet,
     convert::Infallible,
@@ -52,6 +81,7 @@ pub enum InvalidHybridReportError {
     Length(usize, usize),
 }
 
+/// Reports for impression events are represented here.
 #[derive(Clone, Debug, Eq, PartialEq)]
 pub struct HybridImpressionReport
 where
@@ -169,6 +199,7 @@ where
     }
 }
 
+/// Reports for conversion events are represented here.
 #[derive(Clone, Debug, Eq, PartialEq)]
 pub struct HybridConversionReport
 where
@@ -178,6 +209,7 @@ where
     value: Replicated,
 }
 
+/// This enum contains both report types, impression and conversion.
 #[derive(Clone, Debug, Eq, PartialEq)]
 pub enum HybridReport
 where
@@ -205,6 +237,9 @@ where
     }
 }
 
+/// `HybridImpressionReport`s are encrypted when they arrive at the helpers,
+/// which is represented here. An `EncryptedHybridImpressionReport` decrypts
+/// into a `HybridImpressionReport`.
 #[derive(Copy, Clone, Eq, PartialEq)]
 pub struct EncryptedHybridImpressionReport
 where
@@ -308,6 +343,62 @@ where
     }
 }
 
+/// This struct is designed to fit both `HybridConversionReport`s
+/// and `HybridImpressionReport`s so that they can be made indistinguishable.
+/// Note: these need to be shuffled (and secret shares need to be rerandomized)
+/// to provide any formal indistinguishability. 
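+/// As a sketch of the mapping implemented by the `From` impls below (this
+/// restates those impls, it is not additional behavior): an impression report
+/// keeps its `match_key` and `breakdown_key` and gets a sharing of zero for
+/// `value`, while a conversion report keeps its `match_key` and `value` and
+/// gets a sharing of zero for `breakdown_key`.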
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct IndistinguishableHybridReport
+where
+    BK: SharedValue,
+    V: SharedValue,
+{
+    match_key: Replicated,
+    value: Replicated,
+    breakdown_key: Replicated,
+}
+
+impl From> for IndistinguishableHybridReport
+where
+    BK: SharedValue,
+    V: SharedValue,
+{
+    fn from(report: HybridReport) -> Self {
+        match report {
+            HybridReport::Impression(r) => r.into(),
+            HybridReport::Conversion(r) => r.into(),
+        }
+    }
+}
+
+impl From> for IndistinguishableHybridReport
+where
+    BK: SharedValue,
+    V: SharedValue,
+{
+    fn from(impression_report: HybridImpressionReport) -> Self {
+        Self {
+            match_key: impression_report.match_key,
+            value: Replicated::ZERO,
+            breakdown_key: impression_report.breakdown_key,
+        }
+    }
+}
+
+impl From> for IndistinguishableHybridReport
+where
+    BK: SharedValue,
+    V: SharedValue,
+{
+    fn from(conversion_report: HybridConversionReport) -> Self {
+        Self {
+            match_key: conversion_report.match_key,
+            value: conversion_report.value,
+            breakdown_key: Replicated::ZERO,
+        }
+    }
+}
+
 #[derive(Clone)]
 pub struct EncryptedHybridReport {
     bytes: Bytes,
@@ -496,8 +587,8 @@ mod test {
 
     use super::{
         EncryptedHybridImpressionReport, EncryptedHybridReport, GenericArray,
-        HybridConversionReport, HybridImpressionReport, HybridReport, UniqueTag,
-        UniqueTagValidator,
+        HybridConversionReport, HybridImpressionReport, HybridReport,
+        IndistinguishableHybridReport, UniqueTag, UniqueTagValidator,
    };
    use crate::{
        error::Error,
@@ -594,6 +685,67 @@ mod test {
         assert_eq!(hybrid_report, hybrid_report2);
     }
 
+    /// We create a random `HybridConversionReport`, convert it into an
+    /// `IndistinguishableHybridReport`, and check that the field values are the same
+    /// (or zero for the breakdown key, which doesn't exist on the conversion report).
+    /// We then build a generic `HybridReport` from the conversion report, convert it
+    /// into an `IndistinguishableHybridReport`, and validate that it has the same value
+    /// as the previous `IndistinguishableHybridReport`.
+    #[test]
+    fn convert_hybrid_conversion_report_to_indistinguishable_report() {
+        let mut rng = thread_rng();
+
+        let conversion_report = HybridConversionReport:: {
+            match_key: AdditiveShare::new(rng.gen(), rng.gen()),
+            value: AdditiveShare::new(rng.gen(), rng.gen()),
+        };
+        let indistinguishable_report: IndistinguishableHybridReport =
+            conversion_report.clone().into();
+        assert_eq!(
+            conversion_report.match_key,
+            indistinguishable_report.match_key
+        );
+        assert_eq!(conversion_report.value, indistinguishable_report.value);
+        assert_eq!(AdditiveShare::ZERO, indistinguishable_report.breakdown_key);
+
+        let hybrid_report = HybridReport::Conversion::(conversion_report.clone());
+        let indistinguishable_report2: IndistinguishableHybridReport =
+            hybrid_report.clone().into();
+        assert_eq!(indistinguishable_report, indistinguishable_report2);
+    }
+
+    /// We create a random `HybridImpressionReport`, convert it into an
+    /// `IndistinguishableHybridReport`, and check that the field values are the same
+    /// (or zero for the value, which doesn't exist on the impression report).
+    /// We then build a generic `HybridReport` from the impression report, convert it
+    /// into an `IndistinguishableHybridReport`, and validate that it has the same value
+    /// as the previous `IndistinguishableHybridReport`. 
+ #[test] + fn convert_hybrid_impression_report_to_indistinguishable_report() { + let mut rng = thread_rng(); + + let impression_report = HybridImpressionReport:: { + match_key: AdditiveShare::new(rng.gen(), rng.gen()), + breakdown_key: AdditiveShare::new(rng.gen(), rng.gen()), + }; + let indistinguishable_report: IndistinguishableHybridReport = + impression_report.clone().into(); + assert_eq!( + impression_report.match_key, + indistinguishable_report.match_key + ); + assert_eq!(AdditiveShare::ZERO, indistinguishable_report.value); + assert_eq!( + impression_report.breakdown_key, + indistinguishable_report.breakdown_key + ); + + let hybrid_report = HybridReport::Impression::(impression_report.clone()); + let indistinguishable_report2: IndistinguishableHybridReport = + hybrid_report.clone().into(); + assert_eq!(indistinguishable_report, indistinguishable_report2); + } + #[test] fn unique_encrypted_hybrid_reports() { let tag1 = generate_random_tag(); From 48407bbba2b0bfada99b28377ffa22d64012894f Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Thu, 24 Oct 2024 10:49:05 -0700 Subject: [PATCH 179/191] Use record IDs to index proof batches (#1350) * Proof batching for quicksort Fixes #1311 * PR feedback from #1352 * Use record IDs to index proof batches Fixes #1269 * Fix conditional compilation issue * Check max proof size * Revise comment * PR feedback, and rename record_counter --- .clippy.toml | 2 +- ipa-core/src/helpers/hashing.rs | 2 +- ipa-core/src/helpers/mod.rs | 12 + .../src/protocol/context/dzkp_malicious.rs | 5 +- .../src/protocol/context/dzkp_validator.rs | 91 +++++- ipa-core/src/protocol/context/step.rs | 12 - .../ipa_prf/aggregation/breakdown_reveal.rs | 4 +- .../src/protocol/ipa_prf/aggregation/step.rs | 10 +- .../ipa_prf/malicious_security/lagrange.rs | 10 + .../ipa_prf/malicious_security/prover.rs | 32 +- ipa-core/src/protocol/ipa_prf/mod.rs | 4 +- .../src/protocol/ipa_prf/prf_sharding/step.rs | 2 +- ipa-core/src/protocol/ipa_prf/quicksort.rs | 74 ++++- ipa-core/src/protocol/ipa_prf/step.rs | 6 +- .../validation_protocol/proof_generation.rs | 163 +++++++---- .../ipa_prf/validation_protocol/validation.rs | 276 ++++++++++++------ ipa-core/src/protocol/mod.rs | 27 +- 17 files changed, 511 insertions(+), 221 deletions(-) diff --git a/.clippy.toml b/.clippy.toml index 9ed6287e0..800113fde 100644 --- a/.clippy.toml +++ b/.clippy.toml @@ -8,4 +8,4 @@ disallowed-methods = [ { path = "std::vec::Vec::leak", reason = "Not running the destructors on futures created inside seq_join module will cause UB in IPA. Make sure you don't leak any of those." 
}, ] -future-size-threshold = 8192 \ No newline at end of file +future-size-threshold = 10240 diff --git a/ipa-core/src/helpers/hashing.rs b/ipa-core/src/helpers/hashing.rs index 10f50484b..ae9579097 100644 --- a/ipa-core/src/helpers/hashing.rs +++ b/ipa-core/src/helpers/hashing.rs @@ -12,7 +12,7 @@ use crate::{ protocol::prss::FromRandomU128, }; -#[derive(Debug, PartialEq)] +#[derive(Clone, Debug, Default, PartialEq)] pub struct Hash(Output); impl Serializable for Hash { diff --git a/ipa-core/src/helpers/mod.rs b/ipa-core/src/helpers/mod.rs index 5f86f4305..d116181be 100644 --- a/ipa-core/src/helpers/mod.rs +++ b/ipa-core/src/helpers/mod.rs @@ -7,6 +7,7 @@ use std::{ convert::Infallible, fmt::{Debug, Display, Formatter}, num::NonZeroUsize, + ops::Not, }; use generic_array::GenericArray; @@ -271,6 +272,17 @@ pub enum Direction { Right, } +impl Not for Direction { + type Output = Self; + + fn not(self) -> Self { + match self { + Direction::Left => Direction::Right, + Direction::Right => Direction::Left, + } + } +} + impl Role { const H1_STR: &'static str = "H1"; const H2_STR: &'static str = "H2"; diff --git a/ipa-core/src/protocol/context/dzkp_malicious.rs b/ipa-core/src/protocol/context/dzkp_malicious.rs index 671dfa08d..951b5b4c4 100644 --- a/ipa-core/src/protocol/context/dzkp_malicious.rs +++ b/ipa-core/src/protocol/context/dzkp_malicious.rs @@ -13,7 +13,6 @@ use crate::{ context::{ dzkp_validator::{Batch, MaliciousDZKPValidatorInner, Segment}, prss::InstrumentedIndexedSharedRandomness, - step::DzkpBatchStep, Context as ContextTrait, DZKPContext, InstrumentedSequentialSharedRandomness, MaliciousContext, }, @@ -100,9 +99,7 @@ impl<'a, B: ShardBinding> DZKPContext for DZKPUpgraded<'a, B> { .batcher .lock() .unwrap() - .validate_record(record_id, |batch_idx, batch| { - batch.validate(ctx.narrow(&DzkpBatchStep(batch_idx))) - }); + .validate_record(record_id, |batch_idx, batch| batch.validate(ctx, batch_idx)); validation_future.await } diff --git a/ipa-core/src/protocol/context/dzkp_validator.rs b/ipa-core/src/protocol/context/dzkp_validator.rs index a734ce629..2b33181b1 100644 --- a/ipa-core/src/protocol/context/dzkp_validator.rs +++ b/ipa-core/src/protocol/context/dzkp_validator.rs @@ -15,11 +15,14 @@ use crate::{ dzkp_field::{DZKPBaseField, UVTupleBlock}, dzkp_malicious::DZKPUpgraded as MaliciousDZKPUpgraded, dzkp_semi_honest::DZKPUpgraded as SemiHonestDZKPUpgraded, - step::{DzkpSingleBatchStep, DzkpValidationProtocolStep as Step}, + step::DzkpValidationProtocolStep as Step, Base, Context, DZKPContext, MaliciousContext, MaliciousProtocolSteps, }, - ipa_prf::validation_protocol::{proof_generation::ProofBatch, validation::BatchToVerify}, - Gate, RecordId, + ipa_prf::{ + validation_protocol::{proof_generation::ProofBatch, validation::BatchToVerify}, + LargeProofGenerator, SmallProofGenerator, + }, + Gate, RecordId, RecordIdRange, }, seq_join::{seq_join, SeqJoin}, sharding::ShardBinding, @@ -52,6 +55,26 @@ pub const TARGET_PROOF_SIZE: usize = 8192; #[cfg(not(test))] pub const TARGET_PROOF_SIZE: usize = 50_000_000; +/// Maximum proof recursion depth. +// +// This is a hard limit. 
Each GF(2) multiply generates four G values and four H values, +// and the last level of the proof is limited to (small_recursion_factor - 1), so the +// restriction is: +// +// $$ large_recursion_factor * (small_recursion_factor - 1) +// * small_recursion_factor ^ (depth - 2) >= 4 * target_proof_size $$ +// +// With large_recursion_factor = 32 and small_recursion_factor = 8, this means: +// +// $$ depth >= log_8 (8/7 * target_proof_size) $$ +// +// Because the number of records in a proof batch is often rounded up to a power of two +// (and less significantly, because multiplication intermediate storage gets rounded up +// to blocks of 256), leaving some margin is advised. +// +// The implementation requires that MAX_PROOF_RECURSION is at least 2. +pub const MAX_PROOF_RECURSION: usize = 9; + /// `MultiplicationInputsBlock` is a block of fixed size of intermediate values /// that occur duringa multiplication. /// These values need to be verified since there might have been malicious behavior. @@ -573,9 +596,22 @@ impl Batch { /// ## Panics /// If `usize` to `u128` conversion fails. - pub(super) async fn validate(self, ctx: Base<'_, B>) -> Result<(), Error> { + pub(super) async fn validate( + self, + ctx: Base<'_, B>, + batch_index: usize, + ) -> Result<(), Error> { + const PRSS_RECORDS_PER_BATCH: usize = LargeProofGenerator::PROOF_LENGTH + + (MAX_PROOF_RECURSION - 1) * SmallProofGenerator::PROOF_LENGTH + + 2; // P and Q masks + let proof_ctx = ctx.narrow(&Step::GenerateProof); + let record_id = RecordId::from(batch_index); + let prss_record_id_start = RecordId::from(batch_index * PRSS_RECORDS_PER_BATCH); + let prss_record_id_end = RecordId::from((batch_index + 1) * PRSS_RECORDS_PER_BATCH); + let prss_record_ids = RecordIdRange::from(prss_record_id_start..prss_record_id_end); + if self.is_empty() { return Ok(()); } @@ -587,11 +623,12 @@ impl Batch { q_mask_from_left_prover, ) = { // generate BatchToVerify - ProofBatch::generate(&proof_ctx, self.get_field_values_prover()) + ProofBatch::generate(&proof_ctx, prss_record_ids, self.get_field_values_prover()) }; let chunk_batch = BatchToVerify::generate_batch_to_verify( proof_ctx, + record_id, my_batch_left_shares, shares_of_batch_from_left_prover, p_mask_from_right_prover, @@ -601,7 +638,7 @@ impl Batch { // generate challenges let (challenges_for_left_prover, challenges_for_right_prover) = chunk_batch - .generate_challenges(ctx.narrow(&Step::Challenge)) + .generate_challenges(ctx.narrow(&Step::Challenge), record_id) .await; let (sum_of_uv, p_r_right_prover, q_r_left_prover) = { @@ -635,6 +672,7 @@ impl Batch { chunk_batch .verify( ctx.narrow(&Step::VerifyProof), + record_id, sum_of_uv, p_r_right_prover, q_r_left_prover, @@ -670,7 +708,18 @@ pub trait DZKPValidator: Send + Sync { /// /// # Panics /// May panic if the above restrictions on validator usage are not followed. - async fn validate(self) -> Result<(), Error>; + async fn validate(self) -> Result<(), Error> + where + Self: Sized, + { + self.validate_indexed(0).await + } + + /// Validates all of the multiplies associated with this validator, specifying + /// an explicit batch index. + /// + /// This should be used when the protocol is explicitly managing batches. 
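+    ///
+    /// A minimal usage sketch (`chunks`, `steps`, and `batch_size` are
+    /// placeholder names, not part of this API; the chunked aggregation in
+    /// `breakdown_reveal` follows this pattern):
+    ///
+    /// ```ignore
+    /// for (batch_idx, chunk) in chunks.into_iter().enumerate() {
+    ///     // One explicitly managed proof batch per chunk.
+    ///     let validator = ctx.clone().dzkp_validator(steps, batch_size);
+    ///     // ... run this chunk's multiplications under the validator ...
+    ///     validator.validate_indexed(batch_idx).await?;
+    /// }
+    /// ```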
+ async fn validate_indexed(self, batch_index: usize) -> Result<(), Error>; /// `is_verified` checks that there are no `MultiplicationInputs` that have not been verified /// within the associated `DZKPBatch` @@ -716,6 +765,20 @@ pub trait DZKPValidator: Send + Sync { } } +// Wrapper to avoid https://github.com/rust-lang/rust/issues/100013. +pub fn validated_seq_join<'st, V, S, F, O>( + validator: V, + source: S, +) -> impl Stream> + Send + 'st +where + V: DZKPValidator + 'st, + S: Stream + Send + 'st, + F: Future> + Send + 'st, + O: Send + Sync + 'static, +{ + validator.validated_seq_join(source) +} + #[derive(Clone)] pub struct SemiHonestDZKPValidator<'a, B: ShardBinding> { context: SemiHonestDZKPUpgraded<'a, B>, @@ -741,7 +804,7 @@ impl<'a, B: ShardBinding> DZKPValidator for SemiHonestDZKPValidator<'a, B> { // Semi-honest validator doesn't do anything, so doesn't care. } - async fn validate(self) -> Result<(), Error> { + async fn validate_indexed(self, _batch_index: usize) -> Result<(), Error> { Ok(()) } @@ -787,7 +850,7 @@ impl<'a, B: ShardBinding> DZKPValidator for MaliciousDZKPValidator<'a, B> { .set_total_records(total_records); } - async fn validate(mut self) -> Result<(), Error> { + async fn validate_indexed(mut self, batch_index: usize) -> Result<(), Error> { let arc = self .inner_ref .take() @@ -802,7 +865,7 @@ impl<'a, B: ShardBinding> DZKPValidator for MaliciousDZKPValidator<'a, B> { batcher .into_single_batch() - .validate(validate_ctx.narrow(&DzkpSingleBatchStep)) + .validate(validate_ctx, batch_index) .await } @@ -1271,16 +1334,12 @@ mod tests { } proptest! { - #![proptest_config(ProptestConfig::with_cases(50))] + #![proptest_config(ProptestConfig::with_cases(20))] #[test] fn batching_proptest((record_count, max_multiplications_per_gate) in batching()) { println!("record_count {record_count} batch {max_multiplications_per_gate}"); - if record_count / max_multiplications_per_gate >= 192 { - // TODO: #1269, or even if we don't fix that, don't hardcode the limit. - println!("skipping config because batch count exceeds limit of 192"); - } // This condition is correct only for active_work = 16 and record size of 1 byte. - else if max_multiplications_per_gate != 1 && max_multiplications_per_gate % 16 != 0 { + if max_multiplications_per_gate != 1 && max_multiplications_per_gate % 16 != 0 { // TODO: #1300, read_size | batch_size. // Note: for active work < 2048, read size matches active work. diff --git a/ipa-core/src/protocol/context/step.rs b/ipa-core/src/protocol/context/step.rs index aeb6bd76f..24a8872be 100644 --- a/ipa-core/src/protocol/context/step.rs +++ b/ipa-core/src/protocol/context/step.rs @@ -28,18 +28,6 @@ pub(crate) enum ValidateStep { CheckZero, } -// This really is only for DZKPs and not for MACs. The MAC protocol uses record IDs to -// count batches. DZKP probably should do the same to avoid the fixed upper limit. -#[derive(CompactStep)] -#[step(count = 600, child = DzkpValidationProtocolStep)] -pub(crate) struct DzkpBatchStep(pub usize); - -// This is used when we don't do batched verification, to avoid paying for x256 as many -// steps in compact gate. 
-#[derive(CompactStep)] -#[step(child = DzkpValidationProtocolStep)] -pub(crate) struct DzkpSingleBatchStep; - #[derive(CompactStep)] pub(crate) enum DzkpValidationProtocolStep { /// Step for proof generation diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs b/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs index 7e10adf57..b0b17396a 100644 --- a/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs +++ b/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs @@ -118,7 +118,7 @@ where let validator = ctx.clone().dzkp_validator( MaliciousProtocolSteps { protocol: &Step::aggregate(depth), - validate: &Step::aggregate_validate(chunk_counter), + validate: &Step::AggregateValidate, }, usize::MAX, // See note about batching above. ); @@ -129,7 +129,7 @@ where Some(&mut record_ids), ) .await?; - validator.validate().await?; + validator.validate_indexed(chunk_counter).await?; chunk_counter += 1; next_intermediate_results.push(result); } diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/step.rs b/ipa-core/src/protocol/ipa_prf/aggregation/step.rs index 0995a8e54..8be4fdcd1 100644 --- a/ipa-core/src/protocol/ipa_prf/aggregation/step.rs +++ b/ipa-core/src/protocol/ipa_prf/aggregation/step.rs @@ -10,17 +10,17 @@ pub(crate) enum AggregationStep { #[step(child = crate::protocol::ipa_prf::shuffle::step::OPRFShuffleStep)] Shuffle, Reveal, - #[step(child = crate::protocol::context::step::DzkpSingleBatchStep)] + #[step(child = crate::protocol::context::step::DzkpValidationProtocolStep)] RevealValidate, // only partly used -- see code - #[step(count = 4, child = AggregateChunkStep)] + #[step(count = 4, child = AggregateChunkStep, name = "chunks")] Aggregate(usize), - #[step(count = 600, child = crate::protocol::context::step::DzkpSingleBatchStep)] - AggregateValidate(usize), + #[step(child = crate::protocol::context::step::DzkpValidationProtocolStep)] + AggregateValidate, } // The step count here is duplicated as the AGGREGATE_DEPTH constant in the code. #[derive(CompactStep)] -#[step(count = 24, child = AggregateValuesStep, name = "depth")] +#[step(count = 24, child = AggregateValuesStep, name = "fold")] pub(crate) struct AggregateChunkStep(usize); #[derive(CompactStep)] diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs index 2477f0867..32dbc3a18 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs @@ -50,6 +50,16 @@ where } } +impl Default for CanonicalLagrangeDenominator +where + F: PrimeField + TryFrom, + >::Error: Debug, +{ + fn default() -> Self { + Self::new() + } +} + /// `LagrangeTable` is a precomputed table for the Lagrange evaluation. /// Allows to compute points on the polynomial, i.e. output points, /// given enough points on the polynomial, i.e. 
input points, diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs index 780eabcf2..808dc4476 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs @@ -3,14 +3,14 @@ use std::{borrow::Borrow, iter::zip, marker::PhantomData}; #[cfg(all(test, unit_test))] use crate::ff::Fp31; use crate::{ - error::{Error, Error::DZKPMasks}, + error::Error::{self, DZKPMasks}, ff::{Fp61BitPrime, PrimeField}, helpers::hashing::{compute_hash, hash_to_field}, protocol::{ context::Context, ipa_prf::malicious_security::lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, prss::SharedRandomness, - RecordId, + RecordId, RecordIdRange, }, }; @@ -179,7 +179,7 @@ impl ProofGenerat .collect::>() } - fn gen_proof_shares_from_prss(ctx: &C, record_counter: &mut RecordId) -> ([F; P], [F; P]) + fn gen_proof_shares_from_prss(ctx: &C, record_ids: &mut RecordIdRange) -> ([F; P], [F; P]) where C: Context, { @@ -187,9 +187,9 @@ impl ProofGenerat let mut out_right = [F::ZERO; P]; // use PRSS for i in 0..P { - let (left, right) = ctx.prss().generate_fields::(*record_counter); - - *record_counter += 1; + let (left, right) = ctx + .prss() + .generate_fields::(record_ids.expect_next()); out_left[i] = left; out_right[i] = right; @@ -215,7 +215,7 @@ impl ProofGenerat /// `my_proof_left_share` has type `Vec<[F; P]>`, pub fn gen_artefacts_from_recursive_step( ctx: &C, - record_counter: &mut RecordId, + record_ids: &mut RecordIdRange, lagrange_table: &LagrangeTable, uv_iterator: J, ) -> (UVValues, [F; P], [F; P]) @@ -230,7 +230,7 @@ impl ProofGenerat // generate proof shares from prss let (share_of_proof_from_prover_left, my_proof_right_share) = - Self::gen_proof_shares_from_prss(ctx, record_counter); + Self::gen_proof_shares_from_prss(ctx, record_ids); // generate prover left proof let my_proof_left_share = Self::gen_other_proof_share(my_proof, my_proof_right_share); @@ -267,7 +267,7 @@ mod test { lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, prover::{LargeProofGenerator, SmallProofGenerator, TestProofGenerator, UVValues}, }, - RecordId, + RecordId, RecordIdRange, }, seq_join::SeqJoin, test_executor::run, @@ -396,11 +396,11 @@ mod test { // first iteration let world = TestWorld::default(); - let mut record_counter = RecordId::from(0); + let mut record_ids = RecordIdRange::ALL; let (uv_values, _, _) = TestProofGenerator::gen_artefacts_from_recursive_step::<_, _, _, 4>( &world.contexts()[0], - &mut record_counter, + &mut record_ids, &lagrange_table, uv_1.iter(), ); @@ -496,11 +496,11 @@ mod test { let world = TestWorld::default(); let [helper_1_proofs, helper_2_proofs, helper_3_proofs] = world .semi_honest((), |ctx, ()| async move { - let mut record_counter = RecordId::from(0); + let mut record_ids = RecordIdRange::ALL; (0..NUM_PROOFS) .map(|i| { - assert_eq!(i * 7, usize::from(record_counter)); - TestProofGenerator::gen_proof_shares_from_prss(&ctx, &mut record_counter) + assert_eq!(i * 7, usize::from(record_ids.peek_first())); + TestProofGenerator::gen_proof_shares_from_prss(&ctx, &mut record_ids) }) .collect::>() }) @@ -550,9 +550,9 @@ mod test { let [(h1_proof_left, h1_proof_right), (h2_proof_left, h2_proof_right), (h3_proof_left, h3_proof_right)] = world .semi_honest((), |ctx, ()| async move { - let mut record_counter = RecordId::from(0); + let mut record_ids = RecordIdRange::ALL; let (proof_share_left, my_share_of_right) = - 
TestProofGenerator::gen_proof_shares_from_prss(&ctx, &mut record_counter); + TestProofGenerator::gen_proof_shares_from_prss(&ctx, &mut record_ids); let proof_u128 = match ctx.role() { Role::H1 => PROOF_1, Role::H2 => PROOF_2, diff --git a/ipa-core/src/protocol/ipa_prf/mod.rs b/ipa-core/src/protocol/ipa_prf/mod.rs index 9626731e4..559705df5 100644 --- a/ipa-core/src/protocol/ipa_prf/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/mod.rs @@ -55,6 +55,8 @@ pub(crate) mod shuffle; pub(crate) mod step; pub mod validation_protocol; +pub use malicious_security::prover::{LargeProofGenerator, SmallProofGenerator}; + /// Match key type pub type MatchKey = BA64; /// Match key size @@ -755,7 +757,7 @@ mod compact_gate_tests { fn step_count_limit() { // This is an arbitrary limit intended to catch changes that unintentionally // blow up the step count. It can be increased, within reason. - const STEP_COUNT_LIMIT: u32 = 35_000; + const STEP_COUNT_LIMIT: u32 = 24_000; assert!( ProtocolStep::STEP_COUNT < STEP_COUNT_LIMIT, "Step count of {actual} exceeds limit of {STEP_COUNT_LIMIT}.", diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/step.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/step.rs index 710b0a7e3..d3f123bf3 100644 --- a/ipa-core/src/protocol/ipa_prf/prf_sharding/step.rs +++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/step.rs @@ -8,7 +8,7 @@ pub struct UserNthRowStep(usize); pub(crate) enum AttributionStep { #[step(child = UserNthRowStep)] Attribute, - #[step(child = crate::protocol::context::step::DzkpBatchStep)] + #[step(child = crate::protocol::context::step::DzkpValidationProtocolStep)] AttributeValidate, #[step(child = crate::protocol::ipa_prf::aggregation::step::AggregationStep)] Aggregate, diff --git a/ipa-core/src/protocol/ipa_prf/quicksort.rs b/ipa-core/src/protocol/ipa_prf/quicksort.rs index 30218e5c2..943dfb1ec 100644 --- a/ipa-core/src/protocol/ipa_prf/quicksort.rs +++ b/ipa-core/src/protocol/ipa_prf/quicksort.rs @@ -15,8 +15,8 @@ use crate::{ basics::reveal, boolean::{step::ThirtyTwoBitStep, NBitStep}, context::{ - dzkp_validator::DZKPValidator, Context, DZKPUpgraded, MaliciousProtocolSteps, - UpgradableContext, + dzkp_validator::{validated_seq_join, DZKPValidator, TARGET_PROOF_SIZE}, + Context, DZKPUpgraded, MaliciousProtocolSteps, UpgradableContext, }, ipa_prf::{ boolean_ops::comparison_and_subtraction_sequential::compare_gt, @@ -97,6 +97,10 @@ where } } +fn quicksort_proof_chunk(key_bits: usize) -> usize { + (TARGET_PROOF_SIZE / key_bits / SORT_CHUNK).next_power_of_two() +} + /// Insecure quicksort using MPC comparisons and a key extraction function `get_key`. 
///
/// `get_key` takes as input an element in the slice and outputs the key by which we sort
@@ -174,9 +178,7 @@ where
             protocol: &Step::quicksort_pass(quicksort_pass),
             validate: &Step::quicksort_pass_validate(quicksort_pass),
         },
-        // TODO: use something like this when validating in chunks
-        // `TARGET_PROOF_SIZE / usize::try_from(K::BITS).unwrap() / SORT_CHUNK``
-        total_records_usize.next_power_of_two(),
+        quicksort_proof_chunk(usize::try_from(K::BITS).unwrap()),
     );
     let c = v.context();
     let cmp_ctx = c.narrow(&QuicksortPassStep::Compare);
@@ -186,7 +188,7 @@ where
         stream::iter(ranges_to_sort.clone().into_iter().filter(|r| r.len() > 1))
             .flat_map(|range| {
                 // set up iterator
-                let mut iterator = list[range.clone()].iter().map(get_key).cloned();
+                let mut iterator = list[range].iter().map(get_key).cloned();
                 // first element is pivot, apply key extraction function f
                 let pivot = iterator.next().unwrap();
                 repeat(pivot).zip(stream::iter(iterator))
@@ -197,8 +199,8 @@ where
         K::BITS <= ThirtyTwoBitStep::BITS,
         "ThirtyTwoBitStep is not large enough to accommodate this sort"
     );
-    let compare_results = seq_join(
-        ctx.active_work(),
+    let compare_results = validated_seq_join(
+        v,
         process_stream_by_chunks::<_, _, _, _, _, _, SORT_CHUNK>(
             compare_index_pairs,
             (Vec::new(), Vec::new()),
@@ -218,9 +220,6 @@ where
     .try_collect::<Vec<_>>()
     .await?;
 
-    // TODO: validate in chunks rather than for the entire input
-    v.validate().await?;
-
     let revealed: BitVec = seq_join(
         ctx.active_work(),
         stream::iter(compare_results).enumerate().map(|(i, chunk)| {
@@ -275,7 +274,7 @@ where
 #[cfg(all(test, unit_test))]
 pub mod tests {
     use std::{
-        cmp::Ordering,
+        cmp::{min, Ordering},
         iter::{repeat, repeat_with},
     };
 
@@ -392,6 +391,57 @@ pub mod tests {
         });
     }
 
+    #[test]
+    fn test_quicksort_insecure_malicious_batching() {
+        run(|| async move {
+            const COUNT: usize = 600;
+            let world = TestWorld::default();
+            let mut rng = thread_rng();
+
+            // generate vector of random values
+            let records: Vec = repeat_with(|| rng.gen()).take(COUNT).collect();
+
+            // Smaller ranges mean fewer passes, which makes the test faster.
+            // (With no impact on proof size, because there is a proof per pass.)
+            let ranges = (0..COUNT)
+                .step_by(8)
+                .map(|i| i..min(i + 8, COUNT))
+                .collect::<Vec<_>>();
+
+            // convert expected into more readable format
+            let mut expected: Vec<u128> =
+                records.clone().into_iter().map(|x| x.as_u128()).collect();
+            // sort expected
+            for range in ranges.iter().cloned() {
+                expected[range].sort_unstable();
+            }
+
+            // compute mpc sort
+            let result: Vec<_> = world
+                .malicious(records.into_iter(), |ctx, mut r| {
+                    let ranges_copy = ranges.clone();
+                    async move {
+                        #[allow(clippy::single_range_in_vec_init)]
+                        quicksort_ranges_by_key_insecure(ctx, &mut r, false, |x| x, ranges_copy)
+                            .await
+                            .unwrap();
+                        r
+                    }
+                })
+                .await
+                .reconstruct();
+
+            assert_eq!(
+                // convert into more readable format
+                result
+                    .into_iter()
+                    .map(|x| x.as_u128())
+                    .collect::<Vec<_>>(),
+                expected
+            );
+        });
+    }
+
     #[test]
     fn test_quicksort_insecure_semi_honest_trivial() {
         run(|| async move {
diff --git a/ipa-core/src/protocol/ipa_prf/step.rs b/ipa-core/src/protocol/ipa_prf/step.rs
index 633f157cc..5020a7185 100644
--- a/ipa-core/src/protocol/ipa_prf/step.rs
+++ b/ipa-core/src/protocol/ipa_prf/step.rs
@@ -8,7 +8,7 @@ pub(crate) enum IpaPrfStep {
     Shuffle,
     #[step(child = crate::protocol::ipa_prf::boolean_ops::step::Fp25519ConversionStep)]
     ConvertFp25519,
-    #[step(child = crate::protocol::context::step::DzkpBatchStep)]
+    #[step(child = crate::protocol::context::step::DzkpValidationProtocolStep)]
     ConvertFp25519Validate,
     PrfKeyGen,
     #[step(child = crate::protocol::context::step::MaliciousProtocolStep)]
@@ -19,7 +19,7 @@ pub(crate) enum IpaPrfStep {
     Attribution,
     #[step(child = crate::protocol::dp::step::DPStep, name = "dp")]
     DifferentialPrivacy,
-    #[step(child = crate::protocol::context::step::DzkpSingleBatchStep)]
+    #[step(child = crate::protocol::context::step::DzkpValidationProtocolStep)]
     DifferentialPrivacyValidate,
 }
 
@@ -28,7 +28,7 @@ pub(crate) enum QuicksortStep {
     /// Sort up to 1B rows. We can't exceed that limit for other reasons as well, e.g. `record_id`.
     #[step(count = 30, child = crate::protocol::ipa_prf::step::QuicksortPassStep)]
     QuicksortPass(usize),
-    #[step(count = 30, child = crate::protocol::context::step::DzkpSingleBatchStep)]
+    #[step(count = 30, child = crate::protocol::context::step::DzkpValidationProtocolStep)]
     QuicksortPassValidate(usize),
 }
 
diff --git a/ipa-core/src/protocol/ipa_prf/validation_protocol/proof_generation.rs b/ipa-core/src/protocol/ipa_prf/validation_protocol/proof_generation.rs
index 5eccdc084..cb2754e5f 100644
--- a/ipa-core/src/protocol/ipa_prf/validation_protocol/proof_generation.rs
+++ b/ipa-core/src/protocol/ipa_prf/validation_protocol/proof_generation.rs
@@ -1,10 +1,16 @@
+use std::{array, iter::zip};
+
+use typenum::{UInt, UTerm, Unsigned, B0, B1};
+
 use crate::{
+    const_assert_eq,
     error::Error,
-    ff::Fp61BitPrime,
-    helpers::{Direction, TotalRecords},
+    ff::{Fp61BitPrime, Serializable},
+    helpers::{Direction, MpcMessage, TotalRecords},
     protocol::{
         context::{
             dzkp_field::{UVTupleBlock, BLOCK_SIZE},
+            dzkp_validator::MAX_PROOF_RECURSION,
             Context,
         },
         ipa_prf::malicious_security::{
@@ -12,8 +18,9 @@ use crate::{
             prover::{LargeProofGenerator, SmallProofGenerator},
         },
         prss::SharedRandomness,
-        RecordId,
+        RecordId, RecordIdRange,
     },
+    secret_sharing::SharedValue,
 };
 
 /// This is a `ProofBatch` generated by a prover.
@@ -47,11 +54,18 @@ impl ProofBatch {
         self.proofs.len() * SmallProofGenerator::PROOF_LENGTH + LargeProofGenerator::PROOF_LENGTH
     }
 
-    /// This function returns an iterator over the field elements of all proofs.
-    fn iter(&self) -> impl Iterator<Item = &Fp61BitPrime> {
-        self.first_proof
+    #[allow(clippy::unnecessary_box_returns)] // clippy bug? `Array` exceeds unnecessary-box-size
+    fn to_array(&self) -> Box<Array> {
+        assert!(self.len() <= ARRAY_LEN);
+        let iter = self
+            .first_proof
             .iter()
-            .chain(self.proofs.iter().flat_map(|x| x.iter()))
+            .chain(self.proofs.iter().flat_map(|x| x.iter()));
+        let mut array = Box::new(array::from_fn(|_| Fp61BitPrime::ZERO));
+        for (i, v) in iter.enumerate() {
+            array[i] = *v;
+        }
+        array
     }
 
     /// Each helper party generates a set of proofs, which are secret-shared.
@@ -66,7 +80,11 @@ impl ProofBatch {
     /// ## Panics
     /// Panics when the function fails to set the masks without overwriting `u` and `v` values.
     /// This only happens when there is an issue in the recursion.
-    pub fn generate<C, I>(ctx: &C, uv_tuple_inputs: I) -> (Self, Self, Fp61BitPrime, Fp61BitPrime)
+    pub fn generate<C, I>(
+        ctx: &C,
+        mut prss_record_ids: RecordIdRange,
+        uv_tuple_inputs: I,
+    ) -> (Self, Self, Fp61BitPrime, Fp61BitPrime)
     where
         C: Context,
         I: Iterator<Item = UVTupleBlock<Fp61BitPrime>> + Clone,
@@ -77,9 +95,6 @@ impl ProofBatch {
         const SLL: usize = SmallProofGenerator::LAGRANGE_LENGTH;
         const SPL: usize = SmallProofGenerator::PROOF_LENGTH;
 
-        // set up record counter
-        let mut record_counter = RecordId::FIRST;
-
         // precomputation for first proof
         let first_denominator = CanonicalLagrangeDenominator::new();
         let first_lagrange_table = LagrangeTable::from(first_denominator);
 
        let (mut uv_values, first_proof_from_left, my_first_proof_left_share) =
             LargeProofGenerator::gen_artefacts_from_recursive_step(
                 ctx,
-                &mut record_counter,
+                &mut prss_record_ids,
                 &first_lagrange_table,
                 ProofBatch::polynomials_from_inputs(uv_tuple_inputs),
             );
 
-        // approximate length of proof vector (rounded up)
-        let uv_len_bits: u32 = usize::BITS - uv_values.len().leading_zeros();
-        let small_recursion_factor_bits: u32 = usize::BITS - SRF.leading_zeros();
-        let expected_len = 1 << (uv_len_bits - small_recursion_factor_bits);
+        // `MAX_PROOF_RECURSION - 2` because:
+        //  * The first level of recursion has already happened.
+        //  * We need (SRF - 1) at the last level to have room for the masks.
+        let max_uv_values: usize =
+            (SRF - 1) * SRF.pow(u32::try_from(MAX_PROOF_RECURSION - 2).unwrap());
+        assert!(
+            uv_values.len() <= max_uv_values,
+            "Proof batch is too large: have {} uv_values, max is {}",
+            uv_values.len(),
+            max_uv_values,
+        );
 
         // storage for other proofs
-        let mut my_proofs_left_shares = Vec::<[Fp61BitPrime; SPL]>::with_capacity(expected_len);
+        let mut my_proofs_left_shares =
+            Vec::<[Fp61BitPrime; SPL]>::with_capacity(MAX_PROOF_RECURSION - 1);
         let mut shares_of_proofs_from_prover_left =
-            Vec::<[Fp61BitPrime; SPL]>::with_capacity(expected_len);
+            Vec::<[Fp61BitPrime; SPL]>::with_capacity(MAX_PROOF_RECURSION - 1);
 
         // generate masks
         // Prover `P_i` and verifier `P_{i-1}` both compute p(x)
         // therefore the "right" share computed by this verifier corresponds to that which
         // was used by the prover to the right.
-        let (my_p_mask, p_mask_from_right_prover) = ctx.prss().generate_fields(record_counter);
-        record_counter += 1;
+        let (my_p_mask, p_mask_from_right_prover) =
+            ctx.prss().generate_fields(prss_record_ids.expect_next());
         // Prover `P_i` and verifier `P_{i+1}` both compute q(x)
         // therefore the "left" share computed by this verifier corresponds to that which
         // was used by the prover to the left.
- let (q_mask_from_left_prover, my_q_mask) = ctx.prss().generate_fields(record_counter); - record_counter += 1; + let (q_mask_from_left_prover, my_q_mask) = + ctx.prss().generate_fields(prss_record_ids.expect_next()); let denominator = CanonicalLagrangeDenominator::::new(); let lagrange_table = LagrangeTable::::from(denominator); @@ -135,7 +158,7 @@ impl ProofBatch { let (uv_values_new, share_of_proof_from_prover_left, my_proof_left_share) = SmallProofGenerator::gen_artefacts_from_recursive_step( ctx, - &mut record_counter, + &mut prss_record_ids, &lagrange_table, uv_values.iter(), ); @@ -165,52 +188,40 @@ impl ProofBatch { /// /// ## Errors /// Propagates error from sending values over the network channel. - pub async fn send_to_left(&self, ctx: &C) -> Result<(), Error> + pub async fn send_to_left(&self, ctx: &C, record_id: RecordId) -> Result<(), Error> where C: Context, { - // set up context for the communication over the network - let communication_ctx = ctx.set_total_records(TotalRecords::specified(self.len())?); - - // set up channel - let send_channel_left = - &communication_ctx.send_channel::(ctx.role().peer(Direction::Left)); - - // send to left - // we send the proof batch via sending the individual field elements - communication_ctx - .parallel_join( - self.iter().enumerate().map(|(i, x)| async move { - send_channel_left.send(RecordId::from(i), x).await - }), - ) - .await?; - Ok(()) + Ok(ctx + .set_total_records(TotalRecords::Indeterminate) + .send_channel::>(ctx.role().peer(Direction::Left)) + .send(record_id, self.to_array()) + .await?) } /// This function receives a `Proof` from the party on the right. /// /// ## Errors /// Propagates errors from receiving values over the network channel. - pub async fn receive_from_right(ctx: &C, length: usize) -> Result + /// + /// ## Panics + /// If the recursion depth implied by `length` exceeds `MAX_PROOF_RECURSION`. + pub async fn receive_from_right( + ctx: &C, + record_id: RecordId, + length: usize, + ) -> Result where C: Context, { - // set up context - let communication_ctx = ctx.set_total_records(TotalRecords::specified(length)?); - - // set up channel - let receive_channel_right = - &communication_ctx.recv_channel::(ctx.role().peer(Direction::Right)); - - // receive from the right + assert!(length <= ARRAY_LEN); Ok(ctx - .parallel_join( - (0..length) - .map(|i| async move { receive_channel_right.receive(RecordId::from(i)).await }), - ) + .set_total_records(TotalRecords::Indeterminate) + .recv_channel::>(ctx.role().peer(Direction::Right)) + .receive(record_id) .await? .into_iter() + .take(length) .collect()) } @@ -251,6 +262,47 @@ impl ProofBatch { } } +const_assert_eq!( + MAX_PROOF_RECURSION, + 9, + "following impl valid only for MAX_PROOF_RECURSION = 9" +); + +#[rustfmt::skip] +type U1464 = UInt, B0>, B1>, B1>, B0>, B1>, B1>, B1>, B0>, B0>, B0>; + +const ARRAY_LEN: usize = 183; +type Array = [Fp61BitPrime; ARRAY_LEN]; + +impl Serializable for Box { + type Size = U1464; + + type DeserializationError = ::DeserializationError; + + fn serialize(&self, buf: &mut generic_array::GenericArray) { + for (hash, buf) in zip( + **self, + buf.chunks_mut(<::Size as Unsigned>::to_usize()), + ) { + hash.serialize(buf.try_into().unwrap()); + } + } + + fn deserialize( + buf: &generic_array::GenericArray, + ) -> Result { + Ok(buf + .chunks(<::Size as Unsigned>::to_usize()) + .map(|buf| Fp61BitPrime::deserialize(buf.try_into().unwrap())) + .collect::, _>>()? 
+ .try_into() + .unwrap()) + } +} + +impl MpcMessage for Box {} + #[cfg(all(test, unit_test))] mod test { use rand::{thread_rng, Rng}; @@ -263,6 +315,7 @@ mod test { proof_generation::ProofBatch, validation::{test::simple_proof_check, BatchToVerify}, }, + RecordId, RecordIdRange, }, secret_sharing::replicated::ReplicatedSecretSharing, test_executor::run, @@ -312,11 +365,13 @@ mod test { q_mask_from_left_prover, ) = ProofBatch::generate( &ctx.narrow("generate_batch"), + RecordIdRange::ALL, uv_tuple_vec.into_iter(), ); let batch_to_verify = BatchToVerify::generate_batch_to_verify( ctx.narrow("generate_batch"), + RecordId::FIRST, my_batch_left_shares, shares_of_batch_from_left_prover, p_mask_from_right_prover, diff --git a/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs b/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs index bbb994f70..f0430e996 100644 --- a/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs +++ b/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs @@ -1,16 +1,23 @@ -use std::iter::{once, repeat}; +use std::{ + array, + iter::{once, repeat, zip}, +}; use futures_util::future::{try_join, try_join4}; +use typenum::{Unsigned, U288, U80}; use crate::{ - error::Error, - ff::Fp61BitPrime, + const_assert_eq, + error::{Error, UnwrapInfallible}, + ff::{Fp61BitPrime, Serializable}, helpers::{ hashing::{compute_hash, hash_to_field, Hash}, - Direction, TotalRecords, + Direction, MpcMessage, TotalRecords, }, protocol::{ - context::{step::DzkpProofVerifyStep as Step, Context}, + context::{ + dzkp_validator::MAX_PROOF_RECURSION, step::DzkpProofVerifyStep as Step, Context, + }, ipa_prf::{ malicious_security::{ prover::{LargeProofGenerator, SmallProofGenerator}, @@ -55,6 +62,7 @@ impl BatchToVerify { /// Panics when send and receive over the network channels fail. pub async fn generate_batch_to_verify( ctx: C, + record_id: RecordId, my_batch_left_shares: ProofBatch, shares_of_batch_from_left_prover: ProofBatch, p_mask_from_right_prover: Fp61BitPrime, @@ -66,8 +74,8 @@ impl BatchToVerify { // send one batch left and receive one batch from the right let length = my_batch_left_shares.len(); let ((), shares_of_batch_from_right_prover) = try_join( - my_batch_left_shares.send_to_left(&ctx), - ProofBatch::receive_from_right(&ctx, length), + my_batch_left_shares.send_to_left(&ctx, record_id), + ProofBatch::receive_from_right(&ctx, record_id, length), ) .await .unwrap(); @@ -88,7 +96,11 @@ impl BatchToVerify { /// ## Panics /// Panics when recursion factor constant cannot be converted to `u128` /// or when sending and receiving hashes over the network fails. 
- pub async fn generate_challenges(&self, ctx: C) -> (Vec, Vec) + pub async fn generate_challenges( + &self, + ctx: C, + record_id: RecordId, + ) -> (Vec, Vec) where C: Context, { @@ -101,15 +113,25 @@ impl BatchToVerify { let exclude_small = u128::try_from(SRF).unwrap(); // generate hashes - let my_hashes_prover_left = ProofHashes::generate_hashes(self, Side::Left); - let my_hashes_prover_right = ProofHashes::generate_hashes(self, Side::Right); + let my_hashes_prover_left = ProofHashes::generate_hashes(self, Direction::Left); + let my_hashes_prover_right = ProofHashes::generate_hashes(self, Direction::Right); // receive hashes from the other verifier let ((), (), other_hashes_prover_left, other_hashes_prover_right) = try_join4( - my_hashes_prover_left.send_hashes(&ctx, Side::Left), - my_hashes_prover_right.send_hashes(&ctx, Side::Right), - ProofHashes::receive_hashes(&ctx, my_hashes_prover_left.hashes.len(), Side::Left), - ProofHashes::receive_hashes(&ctx, my_hashes_prover_right.hashes.len(), Side::Right), + my_hashes_prover_left.send_hashes(&ctx, record_id, Direction::Left), + my_hashes_prover_right.send_hashes(&ctx, record_id, Direction::Right), + ProofHashes::receive_hashes( + &ctx, + record_id, + my_hashes_prover_left.hashes.len(), + Direction::Left, + ), + ProofHashes::receive_hashes( + &ctx, + record_id, + my_hashes_prover_right.hashes.len(), + Direction::Right, + ), ) .await .unwrap(); @@ -174,6 +196,7 @@ impl BatchToVerify { /// This function computes and outputs the final `p_r_right_prover * q_r_right_prover` value. async fn compute_p_times_q( ctx: C, + record_id: RecordId, p_r_right_prover: Fp61BitPrime, q_r_left_prover: Fp61BitPrime, ) -> Result @@ -181,7 +204,7 @@ impl BatchToVerify { C: Context, { // send to the left - let communication_ctx = ctx.set_total_records(TotalRecords::specified(1usize)?); + let communication_ctx = ctx.set_total_records(TotalRecords::Indeterminate); let send_right = communication_ctx.send_channel::(ctx.role().peer(Direction::Right)); @@ -189,8 +212,8 @@ impl BatchToVerify { communication_ctx.recv_channel::(ctx.role().peer(Direction::Left)); let ((), q_r_right_prover) = try_join( - send_right.send(RecordId::FIRST, q_r_left_prover), - receive_left.receive(RecordId::FIRST), + send_right.send(record_id, q_r_left_prover), + receive_left.receive(record_id), ) .await?; @@ -201,9 +224,14 @@ impl BatchToVerify { /// /// ## Errors /// Propagates network errors or when the proof fails to verify. + /// + /// ## Panics + /// If the proof exceeds `MAX_PROOF_RECURSION`. 
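+    ///
+    /// A rough sketch of the expected call order (a hedged outline only; the
+    /// argument list is abbreviated, see `Batch::validate` for the real usage):
+    /// ```ignore
+    /// let (challenges_left, challenges_right) = batch
+    ///     .generate_challenges(ctx.narrow(&Step::Challenge), record_id)
+    ///     .await;
+    /// // evaluate p(r) and q(r) from the challenges, then:
+    /// batch.verify(ctx.narrow(&Step::VerifyProof), record_id, /* ... */).await?;
+    /// ```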
+    #[allow(clippy::too_many_arguments)]
     pub async fn verify<C>(
         &self,
         ctx: C,
+        record_id: RecordId,
         sum_of_uv_right: Fp61BitPrime,
         p_r_right_prover: Fp61BitPrime,
         q_r_left_prover: Fp61BitPrime,
@@ -221,6 +249,7 @@ impl BatchToVerify {
 
         let p_times_q_right = Self::compute_p_times_q(
             ctx.narrow(&Step::PTimesQ),
+            record_id,
             p_r_right_prover,
             q_r_left_prover,
         )
@@ -243,33 +272,26 @@ impl BatchToVerify {
             p_times_q_right,
         );
 
-        // send dif_left to the right
+        // send diff_left to the right
         let length = diff_left.len();
+        assert!(length <= MAX_PROOF_RECURSION + 1);
+
         let communication_ctx = ctx
             .narrow(&Step::Diff)
-            .set_total_records(TotalRecords::specified(length)?);
-
-        let send_channel =
-            communication_ctx.send_channel::<Fp61BitPrime>(ctx.role().peer(Direction::Right));
-        let receive_channel =
-            communication_ctx.recv_channel::<Fp61BitPrime>(ctx.role().peer(Direction::Left));
-
-        let send_channel_ref = &send_channel;
-        let receive_channel_ref = &receive_channel;
+            .set_total_records(TotalRecords::Indeterminate);
 
-        let send_future = communication_ctx.parallel_join(
-            diff_left
-                .iter()
-                .enumerate()
-                .map(|(i, f)| async move { send_channel_ref.send(RecordId::from(i), f).await }),
-        );
-
-        let receive_future = communication_ctx.parallel_join(
-            (0..length)
-                .map(|i| async move { receive_channel_ref.receive(RecordId::from(i)).await }),
-        );
-
-        let (_, diff_right_from_other_verifier) = try_join(send_future, receive_future).await?;
+        let send_data = array::from_fn(|i| *diff_left.get(i).unwrap_or(&Fp61BitPrime::ZERO));
+
+        let ((), receive_data) = try_join(
+            communication_ctx
+                .send_channel::<ProofDiff>(ctx.role().peer(Direction::Right))
+                .send(record_id, send_data),
+            communication_ctx
+                .recv_channel::<ProofDiff>(ctx.role().peer(Direction::Left))
+                .receive(record_id),
+        )
+        .await?;
+        let diff_right_from_other_verifier = receive_data[0..length].to_vec();
 
         // compare the recombined diff to zero
         for i in 0..length {
@@ -286,21 +308,15 @@ struct ProofHashes {
     hashes: Vec<Hash>,
 }
 
-#[derive(Clone, Copy, Debug)]
-enum Side {
-    Left,
-    Right,
-}
-
 impl ProofHashes {
-    // Generates hashes for proofs received from prover indicated by `side`
-    fn generate_hashes(batch_to_verify: &BatchToVerify, side: Side) -> Self {
-        let (first_proof, other_proofs) = match side {
-            Side::Left => (
+    // Generates hashes for proofs received from prover indicated by `direction`
+    fn generate_hashes(batch_to_verify: &BatchToVerify, direction: Direction) -> Self {
+        let (first_proof, other_proofs) = match direction {
+            Direction::Left => (
                 &batch_to_verify.first_proof_from_left_prover,
                 &batch_to_verify.proofs_from_left_prover,
             ),
-            Side::Right => (
+            Direction::Right => (
                 &batch_to_verify.first_proof_from_right_prover,
                 &batch_to_verify.proofs_from_right_prover,
             ),
@@ -314,54 +330,116 @@ impl ProofHashes {
     }
 
     /// Sends one verifier's hashes to the other verifier.
-    /// `side` indicates the direction of the prover.
- async fn send_hashes(&self, ctx: &C, side: Side) -> Result<(), Error> { - let communication_ctx = ctx.set_total_records(TotalRecords::specified(self.hashes.len())?); - - let send_channel = match side { - // send left hashes to the right - Side::Left => communication_ctx.send_channel::(ctx.role().peer(Direction::Right)), - // send right hashes to the left - Side::Right => communication_ctx.send_channel::(ctx.role().peer(Direction::Left)), - }; - let send_channel_ref = &send_channel; - - communication_ctx - .parallel_join(self.hashes.iter().enumerate().map(|(i, hash)| async move { - send_channel_ref.send(RecordId::from(i), hash).await - })) + /// `direction` indicates the direction of the prover. + async fn send_hashes( + &self, + ctx: &C, + record_id: RecordId, + direction: Direction, + ) -> Result<(), Error> { + assert!(self.hashes.len() <= MAX_PROOF_RECURSION); + let hashes_send = + array::from_fn(|i| self.hashes.get(i).unwrap_or(&Hash::default()).clone()); + let verifier_direction = !direction; + ctx.set_total_records(TotalRecords::Indeterminate) + .send_channel::<[Hash; MAX_PROOF_RECURSION]>(ctx.role().peer(verifier_direction)) + .send(record_id, hashes_send) .await?; Ok(()) } /// This function receives hashes from the other verifier - /// `side` indicates the direction of the prover. - async fn receive_hashes(ctx: &C, length: usize, side: Side) -> Result { - // set up context for the communication over the network - let communication_ctx = ctx.set_total_records(TotalRecords::specified(length)?); - - let recv_channel = match side { - // receive left hashes from the right helper - Side::Left => communication_ctx.recv_channel::(ctx.role().peer(Direction::Right)), - // reeive right hashes from the left helper - Side::Right => communication_ctx.recv_channel::(ctx.role().peer(Direction::Left)), - }; - let recv_channel_ref = &recv_channel; - - let hashes_received = communication_ctx - .parallel_join( - (0..length) - .map(|i| async move { recv_channel_ref.receive(RecordId::from(i)).await }), - ) + /// `direction` indicates the direction of the prover. 
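+    ///
+    /// The sender always transmits a fixed-size `[Hash; MAX_PROOF_RECURSION]`
+    /// array, so only the first `length` entries received here are meaningful;
+    /// the remainder is padding.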
+ async fn receive_hashes( + ctx: &C, + record_id: RecordId, + length: usize, + direction: Direction, + ) -> Result { + assert!(length <= MAX_PROOF_RECURSION); + let verifier_direction = !direction; + let hashes_received = ctx + .set_total_records(TotalRecords::Indeterminate) + .recv_channel::<[Hash; MAX_PROOF_RECURSION]>(ctx.role().peer(verifier_direction)) + .receive(record_id) .await?; - Ok(Self { - hashes: hashes_received, + hashes: hashes_received[0..length].to_vec(), }) } } +const_assert_eq!( + MAX_PROOF_RECURSION, + 9, + "following impl valid only for MAX_PROOF_RECURSION = 9" +); + +impl Serializable for [Hash; MAX_PROOF_RECURSION] { + type Size = U288; + + type DeserializationError = ::DeserializationError; + + fn serialize(&self, buf: &mut generic_array::GenericArray) { + for (hash, buf) in zip( + self, + buf.chunks_mut(<::Size as Unsigned>::to_usize()), + ) { + hash.serialize(buf.try_into().unwrap()); + } + } + + fn deserialize( + buf: &generic_array::GenericArray, + ) -> Result { + Ok(buf + .chunks(<::Size as Unsigned>::to_usize()) + .map(|buf| Hash::deserialize(buf.try_into().unwrap()).unwrap_infallible()) + .collect::>() + .try_into() + .unwrap()) + } +} + +impl MpcMessage for [Hash; MAX_PROOF_RECURSION] {} + +const_assert_eq!( + MAX_PROOF_RECURSION, + 9, + "following impl valid only for MAX_PROOF_RECURSION = 9" +); + +type ProofDiff = [Fp61BitPrime; MAX_PROOF_RECURSION + 1]; + +impl Serializable for ProofDiff { + type Size = U80; + + type DeserializationError = ::DeserializationError; + + fn serialize(&self, buf: &mut generic_array::GenericArray) { + for (hash, buf) in zip( + self, + buf.chunks_mut(<::Size as Unsigned>::to_usize()), + ) { + hash.serialize(buf.try_into().unwrap()); + } + } + + fn deserialize( + buf: &generic_array::GenericArray, + ) -> Result { + Ok(buf + .chunks(<::Size as Unsigned>::to_usize()) + .map(|buf| Fp61BitPrime::deserialize(buf.try_into().unwrap())) + .collect::, _>>()? 
+ .try_into() + .unwrap()) + } +} + +impl MpcMessage for ProofDiff {} + #[cfg(all(test, unit_test))] pub mod test { use futures_util::future::try_join; @@ -384,7 +462,7 @@ pub mod test { validation_protocol::{proof_generation::ProofBatch, validation::BatchToVerify}, }, prss::SharedRandomness, - RecordId, + RecordId, RecordIdRange, }, secret_sharing::{replicated::ReplicatedSecretSharing, SharedValue}, test_executor::run, @@ -528,11 +606,13 @@ pub mod test { q_mask_from_left_prover, ) = ProofBatch::generate( &ctx.narrow("generate_batch"), + RecordIdRange::ALL, uv_tuple_vec.into_iter(), ); let batch_to_verify = BatchToVerify::generate_batch_to_verify( ctx.narrow("generate_batch"), + RecordId::FIRST, my_batch_left_shares, shares_of_batch_from_left_prover, p_mask_from_right_prover, @@ -541,7 +621,9 @@ pub mod test { .await; // generate and output challenges - batch_to_verify.generate_challenges(ctx).await + batch_to_verify + .generate_challenges(ctx, RecordId::FIRST) + .await }) .await; @@ -639,11 +721,13 @@ pub mod test { q_mask_from_left_prover, ) = ProofBatch::generate( &ctx.narrow("generate_batch"), + RecordIdRange::ALL, vec_my_u_and_v.into_iter(), ); let batch_to_verify = BatchToVerify::generate_batch_to_verify( ctx.narrow("generate_batch"), + RecordId::FIRST, my_batch_left_shares, shares_of_batch_from_left_prover, p_mask_from_right_prover, @@ -654,7 +738,7 @@ pub mod test { // generate challenges let (challenges_for_left_prover, challenges_for_right_prover) = batch_to_verify - .generate_challenges(ctx.narrow("generate_hash")) + .generate_challenges(ctx.narrow("generate_hash"), RecordId::FIRST) .await; assert_eq!( @@ -743,11 +827,13 @@ pub mod test { q_mask_from_left_prover, ) = ProofBatch::generate( &ctx.narrow("generate_batch"), + RecordIdRange::ALL, vec_my_u_and_v.into_iter(), ); let batch_to_verify = BatchToVerify::generate_batch_to_verify( ctx.narrow("generate_batch"), + RecordId::FIRST, my_batch_left_shares, shares_of_batch_from_left_prover, p_mask_from_right_prover, @@ -758,7 +844,7 @@ pub mod test { // generate challenges let (challenges_for_left_prover, challenges_for_right_prover) = batch_to_verify - .generate_challenges(ctx.narrow("generate_hash")) + .generate_challenges(ctx.narrow("generate_hash"), RecordId::FIRST) .await; assert_eq!( @@ -773,7 +859,10 @@ pub mod test { vec_v_from_left_prover.into_iter(), ); - let p_times_q = BatchToVerify::compute_p_times_q(ctx, p, q).await.unwrap(); + let p_times_q = + BatchToVerify::compute_p_times_q(ctx, RecordId::FIRST, p, q) + .await + .unwrap(); let denominator = CanonicalLagrangeDenominator::< Fp61BitPrime, @@ -828,11 +917,13 @@ pub mod test { q_mask_from_left_prover, ) = ProofBatch::generate( &ctx.narrow("generate_batch"), + RecordIdRange::ALL, vec_my_u_and_v.into_iter(), ); let batch_to_verify = BatchToVerify::generate_batch_to_verify( ctx.narrow("generate_batch"), + RecordId::FIRST, my_batch_left_shares, shares_of_batch_from_left_prover, p_mask_from_right_prover, @@ -861,7 +952,7 @@ pub mod test { // generate challenges let (challenges_for_left_prover, challenges_for_right_prover) = batch_to_verify - .generate_challenges(ctx.narrow("generate_hash")) + .generate_challenges(ctx.narrow("generate_hash"), RecordId::FIRST) .await; let (p, q) = batch_to_verify.compute_p_and_q_r( @@ -874,6 +965,7 @@ pub mod test { batch_to_verify .verify( v_ctx, + RecordId::FIRST, sum_of_uv_right, p, q, diff --git a/ipa-core/src/protocol/mod.rs b/ipa-core/src/protocol/mod.rs index 28abf9741..18dfc6221 100644 --- a/ipa-core/src/protocol/mod.rs +++ 
b/ipa-core/src/protocol/mod.rs @@ -10,7 +10,7 @@ pub mod step; use std::{ fmt::{Debug, Display, Formatter}, hash::Hash, - ops::{Add, AddAssign}, + ops::{Add, AddAssign, Range}, }; pub use basics::{BasicProtocols, BooleanProtocols}; @@ -107,6 +107,7 @@ impl From for RecordId { impl RecordId { pub(crate) const FIRST: Self = Self(0); + pub(crate) const LAST: Self = Self(u32::MAX); } impl From for u128 { @@ -147,6 +148,30 @@ impl AddAssign for RecordId { } } +pub struct RecordIdRange(Range); + +impl RecordIdRange { + pub const ALL: RecordIdRange = RecordIdRange(RecordId::FIRST..RecordId::LAST); + + #[cfg(all(test, unit_test))] + fn peek_first(&self) -> RecordId { + self.0.start + } + + fn expect_next(&mut self) -> RecordId { + assert!(self.0.start < self.0.end, "RecordIdRange exhausted"); + let val = self.0.start; + self.0.start += 1; + val + } +} + +impl From> for RecordIdRange { + fn from(value: Range) -> Self { + Self(value) + } +} + /// Helper used when an operation may or may not be associated with a specific record. This is /// also used to prevent some kinds of invalid uses of record ID iteration. For example, trying to /// use the record ID to iterate over both the inner and outer vectors in a `Vec>` is an From ad8ea607b719b6967af33a9a3e8bbcf1e33168ed Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Thu, 24 Oct 2024 11:16:04 -0700 Subject: [PATCH 180/191] Use repeat_n from std now that it is stable --- ipa-core/Cargo.toml | 2 +- ipa-core/src/helpers/mod.rs | 30 ------------------- .../boolean_ops/addition_sequential.rs | 3 +- .../boolean_ops/share_conversion_aby.rs | 4 +-- .../prf_sharding/feature_label_dot_product.rs | 7 +++-- .../src/protocol/ipa_prf/prf_sharding/mod.rs | 8 ++--- 6 files changed, 12 insertions(+), 42 deletions(-) diff --git a/ipa-core/Cargo.toml b/ipa-core/Cargo.toml index 0081a0a50..0a64b7ca1 100644 --- a/ipa-core/Cargo.toml +++ b/ipa-core/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "ipa-core" version = "0.1.0" -rust-version = "1.80.0" +rust-version = "1.82.0" edition = "2021" build = "build.rs" diff --git a/ipa-core/src/helpers/mod.rs b/ipa-core/src/helpers/mod.rs index d116181be..a52516786 100644 --- a/ipa-core/src/helpers/mod.rs +++ b/ipa-core/src/helpers/mod.rs @@ -634,36 +634,6 @@ where } } -pub struct RepeatN { - element: T, - count: usize, -} - -// As of Apr. 2024, this is unstable in `std::iter`. It is also available in `itertools`. -// The advantage over `repeat(element).take(count)` that we care about is that this -// implements `ExactSizeIterator`. The other advantage is that `repeat_n` can return -// the original value (saving a clone) on the last iteration. 
-pub fn repeat_n(element: T, count: usize) -> RepeatN { - RepeatN { element, count } -} - -impl Iterator for RepeatN { - type Item = T; - - fn next(&mut self) -> Option { - (self.count > 0).then(|| { - self.count -= 1; - self.element.clone() - }) - } - - fn size_hint(&self) -> (usize, Option) { - (self.count, Some(self.count)) - } -} - -impl ExactSizeIterator for RepeatN {} - #[cfg(all(test, unit_test))] mod tests { use super::*; diff --git a/ipa-core/src/protocol/ipa_prf/boolean_ops/addition_sequential.rs b/ipa-core/src/protocol/ipa_prf/boolean_ops/addition_sequential.rs index 1c55fa578..f5fcefe20 100644 --- a/ipa-core/src/protocol/ipa_prf/boolean_ops/addition_sequential.rs +++ b/ipa-core/src/protocol/ipa_prf/boolean_ops/addition_sequential.rs @@ -1,11 +1,10 @@ -use std::iter::repeat; +use std::iter::{repeat, repeat_n}; use ipa_step::StepNarrow; use crate::{ error::Error, ff::boolean::Boolean, - helpers::repeat_n, protocol::{ basics::{BooleanProtocols, SecureMul}, boolean::{or::bool_or, NBitStep}, diff --git a/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs b/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs index a42bdbbbb..2dabdc3f4 100644 --- a/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs +++ b/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs @@ -367,7 +367,7 @@ where #[cfg(all(test, unit_test))] mod tests { - use std::iter::{self, repeat_with}; + use std::iter::{self, repeat_n, repeat_with}; use curve25519_dalek::Scalar; use futures::stream::TryStreamExt; @@ -378,7 +378,7 @@ mod tests { use super::*; use crate::{ ff::{boolean_array::BA64, Serializable}, - helpers::{repeat_n, stream::process_slice_by_chunks}, + helpers::stream::process_slice_by_chunks, protocol::{ context::{dzkp_validator::DZKPValidator, UpgradableContext, TEST_DZKP_STEPS}, ipa_prf::{CONV_CHUNK, CONV_PROOF_CHUNK, PRF_CHUNK}, diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/feature_label_dot_product.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/feature_label_dot_product.rs index e2ae77c96..708998ed0 100644 --- a/ipa-core/src/protocol/ipa_prf/prf_sharding/feature_label_dot_product.rs +++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/feature_label_dot_product.rs @@ -1,4 +1,7 @@ -use std::{convert::Infallible, iter::zip}; +use std::{ + convert::Infallible, + iter::{repeat_n, zip}, +}; use futures::stream; use futures_util::{future::try_join, stream::unfold, Stream, StreamExt}; @@ -6,7 +9,7 @@ use futures_util::{future::try_join, stream::unfold, Stream, StreamExt}; use crate::{ error::{Error, LengthError, UnwrapInfallible}, ff::{boolean::Boolean, boolean_array::BooleanArray, Field, U128Conversions}, - helpers::{repeat_n, stream::TryFlattenItersExt, TotalRecords}, + helpers::{stream::TryFlattenItersExt, TotalRecords}, protocol::{ basics::{SecureMul, ShareKnownValue}, boolean::{and::bool_and_8_bit, or::or}, diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs index 0490a7f87..03994ef6c 100644 --- a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs @@ -1,7 +1,6 @@ use std::{ convert::Infallible, - iter, - iter::zip, + iter::{self, repeat_n, zip}, num::NonZeroU32, ops::{Not, Range}, }; @@ -20,7 +19,7 @@ use crate::{ boolean_array::{BooleanArray, BA32, BA7}, ArrayAccess, Field, U128Conversions, }, - helpers::{repeat_n, stream::TryFlattenItersExt, TotalRecords}, + helpers::{stream::TryFlattenItersExt, TotalRecords}, protocol::{ 
basics::{select, BooleanArrayMul, BooleanProtocols, Reveal, SecureMul, ShareKnownValue},
         boolean::{
@@ -877,7 +876,7 @@ where
 
 #[cfg(all(test, unit_test))]
 pub mod tests {
-    use std::num::NonZeroU32;
+    use std::{iter::repeat_n, num::NonZeroU32};
 
     use super::{AttributionOutputs, PrfShardedIpaInputRow};
     use crate::{
@@ -886,7 +885,6 @@ pub mod tests {
             boolean_array::{BooleanArray, BA16, BA20, BA3, BA5, BA8},
             Field, U128Conversions,
         },
-        helpers::repeat_n,
         protocol::ipa_prf::{
             oprf_padding::PaddingParameters, prf_sharding::attribute_cap_aggregate,
         },

From 6293f7f49e5137d0c782ef757d97f9228cd7ea74 Mon Sep 17 00:00:00 2001
From: Alex Koshelev
Date: Thu, 24 Oct 2024 20:41:44 -0700
Subject: [PATCH 181/191] Extend the compact gate implementation to provide
 `index` method

This method essentially returns the unique index associated with the
current gate. The plan is to use that number for the metrics uniqueness
check: metrics greatly benefit from knowing that the compact gate
dimension is a `u64`.
---
 ipa-step-derive/src/lib.rs | 12 ++++++++++++
 ipa-step-test/src/lib.rs   | 13 +++++++++++++
 2 files changed, 25 insertions(+)

diff --git a/ipa-step-derive/src/lib.rs b/ipa-step-derive/src/lib.rs
index ea5d19268..2298e3969 100644
--- a/ipa-step-derive/src/lib.rs
+++ b/ipa-step-derive/src/lib.rs
@@ -165,6 +165,18 @@ fn derive_gate_impl(ast: &DeriveInput) -> TokenStream {
                 ::fmt(self, f)
             }
         }
+
+        impl #name {
+            /// Returns the current index. It matches the index of the latest step
+            /// this gate has been narrowed to.
+            ///
+            /// If the gate hasn't been narrowed yet, it returns the index of the default value.
+            #[must_use]
+            pub fn index(&self) -> ::ipa_step::CompactGateIndex {
+                self.0
+            }
+        }
+
     };
 
     // This environment variable is set by build scripts,
diff --git a/ipa-step-test/src/lib.rs b/ipa-step-test/src/lib.rs
index 84eab760d..3789e6fcf 100644
--- a/ipa-step-test/src/lib.rs
+++ b/ipa-step-test/src/lib.rs
@@ -17,15 +17,21 @@ mod tests {
 
     #[test]
     fn narrows() {
+        assert_eq!(ComplexGate::default().index(), 0);
         assert_eq!(ComplexGate::default().as_ref(), "/");
         assert_eq!(
             ComplexGate::default().narrow(&ComplexStep::One).as_ref(),
             "/one"
         );
+        assert_eq!(ComplexGate::default().narrow(&ComplexStep::One).index(), 1,);
         assert_eq!(
             ComplexGate::default().narrow(&ComplexStep::Two(2)).as_ref(),
             "/two2"
         );
+        assert_eq!(
+            ComplexGate::default().narrow(&ComplexStep::Two(2)).index(),
+            10,
+        );
         assert_eq!(
             ComplexGate::default()
                 .narrow(&ComplexStep::Two(2))
@@ -33,6 +39,13 @@ mod tests {
                 .as_ref(),
             "/two2/one"
         );
+        assert_eq!(
+            ComplexGate::default()
+                .narrow(&ComplexStep::Two(2))
+                .narrow(&BasicStep::One)
+                .index(),
+            11,
+        );
         assert_eq!(
             ComplexGate::from("/two2/one"),
             ComplexGate::default()

From f680adcc5433fd12bcedfce3a2030ea065636f0b Mon Sep 17 00:00:00 2001
From: Alex Koshelev
Date: Thu, 24 Oct 2024 20:54:34 -0700
Subject: [PATCH 182/191] Make OwnedName fields public

They are required to build snapshots on the IPA side
---
 ipa-metrics/src/key.rs | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/ipa-metrics/src/key.rs b/ipa-metrics/src/key.rs
index 8f01ea2f4..02229a11f 100644
--- a/ipa-metrics/src/key.rs
+++ b/ipa-metrics/src/key.rs
@@ -114,8 +114,8 @@ impl<'lv, const LABELS: usize> Name<'lv, LABELS> {
 /// This is the key inside metric stores which are simple hashmaps.
#[derive(Debug, Clone, Eq)] pub struct OwnedName { - key: &'static str, - labels: [Option; 5], + pub key: &'static str, + pub labels: [Option; 5], } impl OwnedName { From ffbba6ade30de380d1b820bddbaca24df3f99397 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Thu, 24 Oct 2024 20:55:33 -0700 Subject: [PATCH 183/191] Provide a default hasher for hashing labels We don't need to import fast hasher everywhere with this change. Labels don't need collision resistance, so implementors of `LabelValue` trait can use it --- ipa-metrics/src/label.rs | 11 +++++++++++ ipa-metrics/src/lib.rs | 2 +- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/ipa-metrics/src/label.rs b/ipa-metrics/src/label.rs index 27da2b116..86a1b11d1 100644 --- a/ipa-metrics/src/label.rs +++ b/ipa-metrics/src/label.rs @@ -3,8 +3,19 @@ use std::{ hash::{Hash, Hasher}, }; +use rustc_hash::FxHasher; + pub const MAX_LABELS: usize = 5; +/// Provides a fast, non-collision resistant implementation of [`Hasher`] +/// for label values. T +/// +/// [`Hasher`]: std::hash::Hasher +#[must_use] +pub fn label_hasher() -> impl Hasher { + FxHasher::default() +} + /// Dimension value (or label value) must be sendable to another thread /// and there must be a way to show it pub trait LabelValue: Display + Send { diff --git a/ipa-metrics/src/lib.rs b/ipa-metrics/src/lib.rs index f84f8dc1c..a637518c5 100644 --- a/ipa-metrics/src/lib.rs +++ b/ipa-metrics/src/lib.rs @@ -22,7 +22,7 @@ pub use controller::{ Status as ControllerStatus, }; pub use key::{MetricName, OwnedName, UniqueElements}; -pub use label::{Label, LabelValue}; +pub use label::{label_hasher, Label, LabelValue}; #[cfg(feature = "partitions")] pub use partitioned::{ CurrentThreadContext as CurrentThreadPartitionContext, Partition as MetricPartition, From d2bbbff263986099bc4d9bd3d00df313dbc81d1e Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Thu, 24 Oct 2024 20:56:14 -0700 Subject: [PATCH 184/191] Provide a method to get all counters from the store --- ipa-metrics/src/store.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/ipa-metrics/src/store.rs b/ipa-metrics/src/store.rs index 501b875a2..bd792ae6e 100644 --- a/ipa-metrics/src/store.rs +++ b/ipa-metrics/src/store.rs @@ -94,6 +94,13 @@ impl Store { pub fn is_empty(&self) -> bool { self.len() == 0 } + + /// Returns an iterator over the counters in the store. + /// + /// The iterator item is a tuple of the metric name and the counter value. 
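+    ///
+    /// A small sketch of consuming the iterator (assuming counters were
+    /// recorded into this store beforehand):
+    /// ```ignore
+    /// for (name, value) in store.counters() {
+    ///     println!("{name:?} = {value}");
+    /// }
+    /// ```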
+    pub fn counters(&self) -> impl Iterator<Item = (&OwnedName, CounterValue)> {
+        self.counters.iter().map(|(key, value)| (key, *value))
+    }
 }
 
 pub struct CounterHandle<'a, const LABELS: usize> {

From 2cbf4ed1fc359e619897dfd12b6ff9a21b269408 Mon Sep 17 00:00:00 2001
From: Alex Koshelev
Date: Thu, 24 Oct 2024 20:56:46 -0700
Subject: [PATCH 185/191] Provide a method that shows whether the current
 thread is connected to the collector thread

---
 ipa-metrics/src/context.rs | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/ipa-metrics/src/context.rs b/ipa-metrics/src/context.rs
index 938d4560b..7d19f2eb8 100644
--- a/ipa-metrics/src/context.rs
+++ b/ipa-metrics/src/context.rs
@@ -39,6 +39,11 @@ impl CurrentThreadContext {
     pub fn store_mut<F: FnOnce(&mut Store) -> T, T>(f: F) -> T {
         METRICS_CTX.with_borrow_mut(|ctx| f(ctx.store_mut()))
     }
+
+    #[must_use]
+    pub fn is_connected() -> bool {
+        METRICS_CTX.with_borrow(|ctx| ctx.tx.is_some())
+    }
 }
 
 /// This context is used inside thread-local storage,

From f3608f666c2698f7ac9241457911ea3ca5df8e76 Mon Sep 17 00:00:00 2001
From: Alex Koshelev
Date: Thu, 24 Oct 2024 20:58:53 -0700
Subject: [PATCH 186/191] Allow customizing channel size between collector and
 producer thread

I noticed that in unit tests it is very hard to get a correct snapshot
after the test finishes. Often the test thread sends the data and moves
on, but the collector thread hasn't processed it yet, so assertions on
the metric occasionally fail.

One way to deal with this is to block the sender until the receiver has
processed its data. Performance-wise this is obviously not great, but it
is acceptable for unit tests.
---
 ipa-metrics/src/collector.rs  | 27 +++++++++++++++++++++---
 ipa-metrics/src/controller.rs | 12 +++++------
 ipa-metrics/src/lib.rs        | 39 ++++++++++++++++++++++++++++++-----
 3 files changed, 64 insertions(+), 14 deletions(-)

diff --git a/ipa-metrics/src/collector.rs b/ipa-metrics/src/collector.rs
index 4d5995af3..94022e340 100644
--- a/ipa-metrics/src/collector.rs
+++ b/ipa-metrics/src/collector.rs
@@ -111,7 +111,10 @@ mod tests {
         thread::{Scope, ScopedJoinHandle},
     };
 
-    use crate::{controller::Status, counter, install, install_new_thread, producer::Producer};
+    use crate::{
+        controller::Status, counter, install, install_new_thread, producer::Producer,
+        MetricChannelType,
+    };
 
     struct MeteredScope<'scope, 'env: 'scope>(&'scope Scope<'scope, 'env>, Producer);
 
@@ -145,7 +148,7 @@ mod tests {
 
     #[test]
     fn start_stop() {
-        let (collector, producer, controller) = install();
+        let (collector, producer, controller) = install(MetricChannelType::Unbounded);
         let handle = thread::spawn(|| {
             let store = collector.install().block_until_shutdown();
             store.counter_val(counter!("foo"))
@@ -165,7 +168,8 @@ mod tests {
 
     #[test]
     fn with_thread() {
-        let (producer, controller, handle) = install_new_thread().unwrap();
+        let (producer, controller, handle) =
+            install_new_thread(MetricChannelType::Unbounded).unwrap();
         thread::scope(move |s| {
             let s = s.metered(producer);
             s.spawn(|| counter!("baz", 4));
@@ -179,4 +183,21 @@ mod tests {
         handle.join().unwrap(); // Collector thread should be terminated by now
     }
+
+    #[test]
+    fn with_thread_rendezvous() {
+        let (producer, controller, _handle) =
+            install_new_thread(MetricChannelType::Rendezvous).unwrap();
+        let counter = thread::scope(move |s| {
+            let s = s.metered(producer);
+            s.spawn(|| counter!("foo", 3)).join().unwrap();
+            s.spawn(|| counter!("foo", 5)).join().unwrap();
+            // we don't need to check the status because producer threads are now
+            // blocked until the collector receives their stores.
+            // This means the snapshot must be up to date by now.
+            controller.snapshot().unwrap().counter_val(counter!("foo"))
+        });
+
+        assert_eq!(8, counter);
+    }
 }
diff --git a/ipa-metrics/src/controller.rs b/ipa-metrics/src/controller.rs
index 265dacf45..52deed853 100644
--- a/ipa-metrics/src/controller.rs
+++ b/ipa-metrics/src/controller.rs
@@ -34,9 +34,9 @@ impl Controller {
     ///
     /// ## Example
     /// ```rust
-    /// use ipa_metrics::{install_new_thread, MetricsStore};
+    /// use ipa_metrics::{install_new_thread, MetricChannelType, MetricsStore};
     ///
-    /// let (_, controller, _handle) = install_new_thread().unwrap();
+    /// let (_, controller, _handle) = install_new_thread(MetricChannelType::Unbounded).unwrap();
     /// let snapshot = controller.snapshot().unwrap();
     /// println!("Current metrics: {snapshot:?}");
     /// ```
@@ -60,9 +60,9 @@ impl Controller {
     ///
     /// ## Example
     /// ```rust
-    /// use ipa_metrics::{install_new_thread, MetricsStore};
+    /// use ipa_metrics::{install_new_thread, MetricChannelType, MetricsStore};
     ///
-    /// let (_, controller, _handle) = install_new_thread().unwrap();
+    /// let (_, controller, _handle) = install_new_thread(MetricChannelType::Unbounded).unwrap();
     /// controller.stop().unwrap();
     /// ```
     pub fn stop(self) -> Result<(), String> {
@@ -81,9 +81,9 @@ impl Controller {
     ///
     /// ## Example
     /// ```rust
-    /// use ipa_metrics::{install_new_thread, ControllerStatus};
+    /// use ipa_metrics::{install_new_thread, ControllerStatus, MetricChannelType};
     ///
-    /// let (_, controller, _handle) = install_new_thread().unwrap();
+    /// let (_, controller, _handle) = install_new_thread(MetricChannelType::Unbounded).unwrap();
     /// let status = controller.status().unwrap();
     /// println!("Collector status: {status:?}");
     /// ```
diff --git a/ipa-metrics/src/lib.rs b/ipa-metrics/src/lib.rs
index a637518c5..2449d41a3 100644
--- a/ipa-metrics/src/lib.rs
+++ b/ipa-metrics/src/lib.rs
@@ -32,6 +32,19 @@ pub use producer::Producer as MetricsProducer;
 #[cfg(not(feature = "partitions"))]
 pub use store::Store as MetricsStore;
 
+/// Type of the communication channel between metric producers
+/// and the collector.
+#[derive(Copy, Clone)]
+pub enum MetricChannelType {
+    /// Each send must be paired with a receive. Sends that
+    /// don't get a pair block the thread until the collector processes
+    /// the request. This mode is suitable for unit tests where metric
+    /// consistency is important and gets more priority than availability.
+    Rendezvous,
+    /// Each channel between a producer and the collector gets unlimited capacity.
+    Unbounded,
+}
+
 /// Creates metric infrastructure that is ready to use
 /// in the application code. It consists of a triple of
 /// [`MetricsCollector`], [`MetricsProducer`], and
@@ -47,9 +60,19 @@ pub use store::Store as MetricsStore;
 /// A thread that owns the controller can request the current snapshot.
 /// For more information about the API, see [`Command`].
 ///
-/// ## Example
+/// The communication channel between producers and the collector is configured
+/// via the `channel_type` parameter. See [`MetricChannelType`] for details.
+///
+/// ## Example 1 (Rendezvous channels)
+/// ```rust
+/// use ipa_metrics::MetricChannelType;
+/// let (collector, producer, controller) = ipa_metrics::install(MetricChannelType::Rendezvous);
+/// ```
+///
+/// ## Example 2 (Unbounded channels)
 /// ```rust
-/// let (collector, producer, controller) = ipa_metrics::install();
+/// use ipa_metrics::MetricChannelType;
+/// let (collector, producer, controller) = ipa_metrics::install(MetricChannelType::Unbounded);
 /// ```
 ///
 /// [`MetricsCollector`]: crate::MetricsCollector
@@ -57,13 +80,18 @@ pub use store::Store as MetricsStore;
 /// [`MetricsCollectorController`]: crate::MetricsCollectorController
 /// [`Command`]: crate::ControllerCommand
 #[must_use]
-pub fn install() -> (
+pub fn install(
+    channel_type: MetricChannelType,
+) -> (
     MetricsCollector,
     MetricsProducer,
     MetricsCollectorController,
 ) {
     let (command_tx, command_rx) = crossbeam_channel::unbounded();
-    let (tx, rx) = crossbeam_channel::unbounded();
+    let (tx, rx) = match channel_type {
+        MetricChannelType::Rendezvous => crossbeam_channel::bounded(0),
+        MetricChannelType::Unbounded => crossbeam_channel::unbounded(),
+    };
     (
         MetricsCollector {
             rx,
@@ -80,8 +108,9 @@ pub fn install() -> (
 /// ## Errors
 /// if thread cannot be started
 pub fn install_new_thread(
+    channel_type: MetricChannelType,
 ) -> io::Result<(MetricsProducer, MetricsCollectorController, JoinHandle<()>)> {
-    let (collector, producer, controller) = install();
+    let (collector, producer, controller) = install(channel_type);
     let handle = std::thread::Builder::new()
         .name("metric-collector".to_string())
         .spawn(|| {

From 1bb477e76a8c76731ce1246e025f0edb449f8849 Mon Sep 17 00:00:00 2001
From: Alex Koshelev
Date: Fri, 25 Oct 2024 10:49:32 -0700
Subject: [PATCH 187/191] Improve coverage

---
 ipa-metrics/src/context.rs | 15 ++++++++++++++-
 ipa-metrics/src/key.rs     |  2 +-
 ipa-metrics/src/store.rs   | 11 +++++++++++
 3 files changed, 26 insertions(+), 2 deletions(-)

diff --git a/ipa-metrics/src/context.rs b/ipa-metrics/src/context.rs
index 7d19f2eb8..f166d610b 100644
--- a/ipa-metrics/src/context.rs
+++ b/ipa-metrics/src/context.rs
@@ -127,7 +127,7 @@ impl Drop for MetricsContext {
 mod tests {
     use std::thread;
 
-    use crate::MetricsContext;
+    use crate::{context::CurrentThreadContext, MetricsContext};
 
     /// Each thread has its local store by default, and it is exclusive to it
     #[test]
@@ -170,4 +170,17 @@ mod tests {
         drop(ctx);
         handle.join().unwrap();
     }
+
+    #[test]
+    fn is_connected() {
+        assert!(!CurrentThreadContext::is_connected());
+        let (tx, rx) = crossbeam_channel::unbounded();
+
+        CurrentThreadContext::init(tx);
+        CurrentThreadContext::store_mut(|store| store.counter(counter!("foo")).inc(1));
+        CurrentThreadContext::flush();
+
+        assert!(CurrentThreadContext::is_connected());
+        assert_eq!(1, rx.recv().unwrap().counter_val(counter!("foo")));
+    }
 }
diff --git a/ipa-metrics/src/key.rs b/ipa-metrics/src/key.rs
index 02229a11f..36eb60a65 100644
--- a/ipa-metrics/src/key.rs
+++ b/ipa-metrics/src/key.rs
@@ -203,7 +203,7 @@ impl Hash for OwnedName {
 
 #[cfg(test)]
 pub fn compute_hash<V: Hash>(value: V) -> u64 {
-    let mut hasher = std::hash::DefaultHasher::default();
+    let mut hasher = crate::label_hasher();
     value.hash(&mut hasher);
 
     hasher.finish()
diff --git a/ipa-metrics/src/store.rs b/ipa-metrics/src/store.rs
index bd792ae6e..e893ffd84 100644
--- a/ipa-metrics/src/store.rs
+++ b/ipa-metrics/src/store.rs
@@ -229,4 +229,15 @@ mod tests {
         store.counter(counter!("bar")).inc(1);
         assert_eq!(2, store.len());
     }
+
+    #[test]
fn counters() { + let mut store = Store::default(); + store.counter(counter!("foo")).inc(1); + store.counter(counter!("foo", "h1" => &1)).inc(1); + store.counter(counter!("foo", "h2" => &2)).inc(1); + store.counter(counter!("bar")).inc(1); + + assert_eq!((4, Some(4)), store.counters().size_hint()); + } } From c0b2e6b027a657ee2fafae4e20297cb38e626627 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Fri, 25 Oct 2024 15:52:29 -0700 Subject: [PATCH 188/191] Notes about updating rust version --- ipa-core/Cargo.toml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ipa-core/Cargo.toml b/ipa-core/Cargo.toml index 454da04b9..6ddfc2009 100644 --- a/ipa-core/Cargo.toml +++ b/ipa-core/Cargo.toml @@ -1,6 +1,11 @@ [package] name = "ipa-core" version = "0.1.0" +# When updating the rust version: +# 1. Check at https://hub.docker.com/_/rust that the relevant version of the +# rust:slim-bullseye docker image is available. +# 2. Update the rust version used for draft in +# https://github.com/private-attribution/draft/blob/main/sidecar/ansible/provision.yaml. rust-version = "1.82.0" edition = "2021" build = "build.rs" From 4549440083d89db1262791c0e9b762e77d23f314 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Fri, 25 Oct 2024 17:54:50 -0700 Subject: [PATCH 189/191] Fix pre-commit script --- ipa-core/src/lib.rs | 2 ++ pre-commit | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/ipa-core/src/lib.rs b/ipa-core/src/lib.rs index 1ce693e01..345bbe0ae 100644 --- a/ipa-core/src/lib.rs +++ b/ipa-core/src/lib.rs @@ -344,6 +344,8 @@ macro_rules! mutually_incompatible { } mutually_incompatible!("in-memory-infra", "real-world-infra"); +#[cfg(not(any(compact_gate, descriptive_gate)))] +compile_error!("At least one of `compact_gate` or `descriptive_gate` features must be enabled"); #[cfg(test)] mod tests { diff --git a/pre-commit b/pre-commit index 9b1c5cb37..b5319ac91 100755 --- a/pre-commit +++ b/pre-commit @@ -106,7 +106,7 @@ then cargo test -p ipa-core --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate" check "Web tests (descriptive gate)" \ - cargo test -p ipa-core --no-default-features --features "cli web-app real-world-infra test-fixture" + cargo test -p ipa-core --no-default-features --features "cli web-app real-world-infra test-fixture descriptive-gate" check "Concurrency tests" \ cargo test -p ipa-core --release --features "shuttle multi-threading" From fa98432d0828b6465d12c504b87135f4454eb646 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Fri, 25 Oct 2024 18:18:15 -0700 Subject: [PATCH 190/191] Feedback --- ipa-metrics/src/key.rs | 2 +- ipa-metrics/src/label.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ipa-metrics/src/key.rs b/ipa-metrics/src/key.rs index 36eb60a65..620e193e3 100644 --- a/ipa-metrics/src/key.rs +++ b/ipa-metrics/src/key.rs @@ -115,7 +115,7 @@ impl<'lv, const LABELS: usize> Name<'lv, LABELS> { #[derive(Debug, Clone, Eq)] pub struct OwnedName { pub key: &'static str, - pub labels: [Option; 5], + labels: [Option; 5], } impl OwnedName { diff --git a/ipa-metrics/src/label.rs b/ipa-metrics/src/label.rs index 86a1b11d1..dd822be86 100644 --- a/ipa-metrics/src/label.rs +++ b/ipa-metrics/src/label.rs @@ -8,7 +8,7 @@ use rustc_hash::FxHasher; pub const MAX_LABELS: usize = 5; /// Provides a fast, non-collision resistant implementation of [`Hasher`] -/// for label values. T +/// for label values. 
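The guard added in PATCH 189 turns a feature misconfiguration into a compile-time error. Here is a self-contained sketch of the pattern; the body of `mutually_incompatible!` shown below is an assumption rather than the crate's actual definition, and `compact_gate`/`descriptive_gate` appear to be `cfg` flags (presumably emitted by `build.rs`) rather than Cargo features, which is why they carry no `feature = "..."` wrapper:

```rust
// Hypothetical reconstruction of the feature-guard pattern used above.
macro_rules! mutually_incompatible {
    ($f1:literal, $f2:literal) => {
        #[cfg(all(feature = $f1, feature = $f2))]
        compile_error!(concat!(
            "feature `", $f1, "` and feature `", $f2, "` cannot be enabled at the same time"
        ));
    };
}

mutually_incompatible!("in-memory-infra", "real-world-infra");

// The converse guard from the patch: at least one gate flavor must be active.
#[cfg(not(any(compact_gate, descriptive_gate)))]
compile_error!("At least one of `compact_gate` or `descriptive_gate` features must be enabled");

fn main() {}
```

With both guards in place, the `pre-commit` fix above matters: the descriptive-gate test run has to opt into the `descriptive-gate` feature explicitly, or the crate no longer builds.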
From fa98432d0828b6465d12c504b87135f4454eb646 Mon Sep 17 00:00:00 2001
From: Alex Koshelev
Date: Fri, 25 Oct 2024 18:18:15 -0700
Subject: [PATCH 190/191] Feedback

---
 ipa-metrics/src/key.rs   | 2 +-
 ipa-metrics/src/label.rs | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/ipa-metrics/src/key.rs b/ipa-metrics/src/key.rs
index 36eb60a65..620e193e3 100644
--- a/ipa-metrics/src/key.rs
+++ b/ipa-metrics/src/key.rs
@@ -115,7 +115,7 @@ impl<'lv, const LABELS: usize> Name<'lv, LABELS> {
 #[derive(Debug, Clone, Eq)]
 pub struct OwnedName {
     pub key: &'static str,
-    pub labels: [Option<OwnedLabel>; 5],
+    labels: [Option<OwnedLabel>; 5],
 }
 
 impl OwnedName {
diff --git a/ipa-metrics/src/label.rs b/ipa-metrics/src/label.rs
index 86a1b11d1..dd822be86 100644
--- a/ipa-metrics/src/label.rs
+++ b/ipa-metrics/src/label.rs
@@ -8,7 +8,7 @@ use rustc_hash::FxHasher;
 pub const MAX_LABELS: usize = 5;
 
 /// Provides a fast, non-collision resistant implementation of [`Hasher`]
-/// for label values. T
+/// for label values.
 ///
 /// [`Hasher`]: std::hash::Hasher
 #[must_use]
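Both metrics changes above revolve around label hashing: PATCH 187 routed `compute_hash` through the crate's label hasher, and PATCH 190 cleans up the doc that advertises it as fast but not collision resistant. A standalone sketch of that contract, assuming only that `label_hasher()` hands out a fresh `FxHasher` (the `use rustc_hash::FxHasher;` context line above is the basis for that guess):

```rust
use std::hash::{Hash, Hasher};

use rustc_hash::FxHasher;

// Assumed stand-in for the crate's `label_hasher()` helper.
fn label_hasher() -> FxHasher {
    FxHasher::default()
}

fn compute_hash<V: Hash>(value: V) -> u64 {
    let mut hasher = label_hasher();
    value.hash(&mut hasher);
    hasher.finish()
}

fn main() {
    // Determinism is what label lookups rely on: equal values, equal hashes.
    assert_eq!(compute_hash("h1"), compute_hash("h1"));
    // FxHasher trades collision resistance for speed; distinct labels could in
    // principle collide, which is acceptable for in-process metric keys.
    assert_ne!(compute_hash("h1"), compute_hash("h2"));
}
```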
From 631bfc3deb5221b8efb58a1b627b152750151a9c Mon Sep 17 00:00:00 2001
From: Erik Taubeneck
Date: Fri, 25 Oct 2024 23:38:54 -0700
Subject: [PATCH 191/191] add hybrid_protocol function, unimplemented (#1375)

* add hybrid_protocol function, unimplemented

* remove traitbounds for now, add them back as needed

* update hybrid protocol comment

* add comment about copy pasted BreakdownKey trait
---
 ipa-core/src/protocol/hybrid/mod.rs | 78 +++++++++++++++++++++++++++++
 ipa-core/src/query/runner/hybrid.rs | 61 +++++++++++++++++-----
 2 files changed, 127 insertions(+), 12 deletions(-)

diff --git a/ipa-core/src/protocol/hybrid/mod.rs b/ipa-core/src/protocol/hybrid/mod.rs
index 71ff41f4c..482f6e939 100644
--- a/ipa-core/src/protocol/hybrid/mod.rs
+++ b/ipa-core/src/protocol/hybrid/mod.rs
@@ -1 +1,79 @@
 pub(crate) mod step;
+
+use crate::{
+    error::Error,
+    ff::{
+        boolean_array::{BooleanArray, BA5, BA8},
+        U128Conversions,
+    },
+    helpers::query::DpMechanism,
+    protocol::{
+        context::{ShardedContext, UpgradableContext},
+        ipa_prf::{oprf_padding::PaddingParameters, shuffle::Shuffle},
+    },
+    report::hybrid::IndistinguishableHybridReport,
+    secret_sharing::replicated::semi_honest::AdditiveShare as Replicated,
+};
+
+// In theory, we could support (runtime-configured breakdown count) ≤ (compile-time breakdown count)
+// ≤ 2^|bk|, with all three values distinct, but at present, there is no runtime configuration and
+// the latter two must be equal. The implementation of `move_single_value_to_bucket` does support a
+// runtime-specified count via the `breakdown_count` parameter, and implements a runtime check of
+// its value.
+//
+// It would usually be more appropriate to make `MAX_BREAKDOWNS` an associated constant rather than
+// a const parameter. However, we want to use it to enforce a correct pairing of the `BK` type
+// parameter and the `B` const parameter, and specifying a constraint like
+// `BreakdownKey<MAX_BREAKDOWNS = B>` on an associated constant is not currently supported. (Nor is
+// supplying an associated constant `<BK as BreakdownKey>::MAX_BREAKDOWNS` as the value of a const
+// parameter.) Structured the way we have it, it probably doesn't make sense to use the
+// `BreakdownKey` trait in places where the `B` const parameter is not already available.
+//
+// These could be imported from src/protocol/ipa_prf/mod.rs;
+// however, we've copy/pasted them here with the intention of deleting that file [TODO]
+pub trait BreakdownKey<const MAX_BREAKDOWNS: usize>: BooleanArray + U128Conversions {}
+impl BreakdownKey<32> for BA5 {}
+impl BreakdownKey<256> for BA8 {}
+
+/// The Hybrid Protocol
+///
+/// This protocol takes in a [`Vec<IndistinguishableHybridReport<BK, V>>`]
+/// and aggregates it into a summary report. `HybridReport`s are either
+/// impressions or conversions. The protocol joins these based on their matchkeys and
+/// sums the values from conversions, grouped by the breakdown keys of the matching impressions.
+/// To accomplish this, the protocol performs the following steps:
+/// 1. Generates a random number of "dummy records" (needed to mask the information that will
+///    be revealed in step 4, and thereby provide a differential privacy guarantee on
+///    that information leakage)
+/// 2. Shuffles the input
+/// 3. Computes an OPRF of the match keys (as elliptic curve points) and reveals this "pseudonym"
+/// 4. Groups together rows with the same OPRF and sums both the breakdown keys and values.
+/// 5. Generates a random number of "dummy records" (needed to mask the information that will
+///    be revealed in step 7)
+/// 6. Shuffles the input
+/// 7. Reveals breakdown keys
+/// 8. Sums the values by breakdown keys
+/// 9. Adds random noise to the total value for each breakdown key (to provide a
+///    differential privacy guarantee)
+///
+/// # Errors
+/// Propagates errors from config issues or while running the protocol
+/// # Panics
+/// Propagates errors from config issues or while running the protocol
+pub async fn hybrid_protocol<'ctx, C, BK, V, HV, const SS_BITS: usize, const B: usize>(
+    _ctx: C,
+    input_rows: Vec<IndistinguishableHybridReport<BK, V>>,
+    _dp_params: DpMechanism,
+    _dp_padding_params: PaddingParameters,
+) -> Result<Vec<Replicated<HV>>, Error>
+where
+    C: UpgradableContext + 'ctx + Shuffle + ShardedContext,
+    BK: BreakdownKey<B>,
+    V: BooleanArray + U128Conversions,
+    HV: BooleanArray + U128Conversions,
+{
+    if input_rows.is_empty() {
+        return Ok(vec![Replicated::ZERO; B]);
+    }
+    unimplemented!("protocol::hybrid::hybrid_protocol is not fully implemented")
+}
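The comment block above explains why `MAX_BREAKDOWNS` rides along as a const parameter: it lets the compiler reject a mismatched `BK`/`B` pair. A minimal sketch of that mechanism, with bare stand-in types instead of the crate's boolean arrays:

```rust
// Stand-ins for BA5/BA8; the real types also carry BooleanArray + U128Conversions.
struct BA5;
struct BA8;

trait BreakdownKey<const MAX_BREAKDOWNS: usize> {}
impl BreakdownKey<32> for BA5 {}
impl BreakdownKey<256> for BA8 {}

// B doubles as the bucket count inside the function body, and the trait bound
// ensures callers cannot pair a key type with the wrong breakdown count.
fn allocate_buckets<BK: BreakdownKey<B>, const B: usize>() -> [u64; B] {
    [0; B]
}

fn main() {
    assert_eq!(32, allocate_buckets::<BA5, 32>().len());
    assert_eq!(256, allocate_buckets::<BA8, 256>().len());
    // allocate_buckets::<BA5, 256>() would not compile:
    // `BA5` only implements `BreakdownKey<32>`.
}
```

This is also why, as the comment notes, the trait is only useful in places where `B` is already in scope.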
diff --git a/ipa-core/src/query/runner/hybrid.rs b/ipa-core/src/query/runner/hybrid.rs
index bdc5d9791..06cc2da4a 100644
--- a/ipa-core/src/query/runner/hybrid.rs
+++ b/ipa-core/src/query/runner/hybrid.rs
@@ -4,18 +4,26 @@ use futures::{stream::iter, StreamExt, TryStreamExt};
 
 use crate::{
     error::Error,
-    ff::boolean_array::{BA20, BA3, BA8},
+    ff::{
+        boolean_array::{BooleanArray, BA20, BA3, BA8},
+        U128Conversions,
+    },
     helpers::{
-        query::{HybridQueryParams, QuerySize},
+        query::{DpMechanism, HybridQueryParams, QuerySize},
         BodyStream, LengthDelimitedStream,
     },
     hpke::PrivateKeyRegistry,
-    protocol::{context::ShardedContext, hybrid::step::HybridStep, step::ProtocolStep::Hybrid},
+    protocol::{
+        context::{ShardedContext, UpgradableContext},
+        hybrid::{hybrid_protocol, step::HybridStep},
+        ipa_prf::{oprf_padding::PaddingParameters, shuffle::Shuffle},
+        step::ProtocolStep::Hybrid,
+    },
     query::runner::reshard_tag::reshard_aad,
     report::hybrid::{
         EncryptedHybridReport, IndistinguishableHybridReport, UniqueTag, UniqueTagValidator,
     },
-    secret_sharing::{replicated::semi_honest::AdditiveShare as ReplicatedShare, SharedValue},
+    secret_sharing::replicated::semi_honest::AdditiveShare as Replicated,
 };
 
 #[allow(dead_code)]
@@ -25,10 +33,8 @@ pub struct Query<C, HV, R: PrivateKeyRegistry> {
     phantom_data: PhantomData<(C, HV)>,
 }
 
-impl<C, HV, R: PrivateKeyRegistry> Query<C, HV, R>
-where
-    C: ShardedContext,
-{
+#[allow(dead_code)]
+impl<C, HV, R: PrivateKeyRegistry> Query<C, HV, R> {
     pub fn new(query_params: HybridQueryParams, key_registry: Arc<R>) -> Self {
         Self {
             config: query_params,
@@ -36,14 +42,21 @@ where
             phantom_data: PhantomData,
         }
     }
+}
 
+impl<C, HV, R> Query<C, HV, R>
+where
+    C: UpgradableContext + Shuffle + ShardedContext,
+    HV: BooleanArray + U128Conversions,
+    R: PrivateKeyRegistry,
+{
     #[tracing::instrument("hybrid_query", skip_all, fields(sz=%query_size))]
     pub async fn execute(
         self,
         ctx: C,
         query_size: QuerySize,
         input_stream: BodyStream,
-    ) -> Result<Vec<ReplicatedShare<HV>>, Error> {
+    ) -> Result<Vec<Replicated<HV>>, Error> {
         let Self {
             config,
             key_registry,
@@ -89,10 +102,34 @@ where
             .check_duplicates(&resharded_tags)
             .unwrap();
 
-        let _indistinguishable_reports: Vec<IndistinguishableHybridReport<BA8, BA3>> =
+        let indistinguishable_reports: Vec<IndistinguishableHybridReport<BA8, BA3>> =
             decrypted_reports.into_iter().map(Into::into).collect();
 
-        unimplemented!("query::runnner::HybridQuery.execute is not fully implemented")
+        let dp_params: DpMechanism = match config.with_dp {
+            0 => DpMechanism::NoDp,
+            _ => DpMechanism::DiscreteLaplace {
+                epsilon: config.epsilon,
+            },
+        };
+
+        #[cfg(feature = "relaxed-dp")]
+        let padding_params = PaddingParameters::relaxed();
+        #[cfg(not(feature = "relaxed-dp"))]
+        let padding_params = PaddingParameters::default();
+
+        match config.per_user_credit_cap {
+            1 => hybrid_protocol::<_, BA8, BA3, HV, 1, 256>(ctx, indistinguishable_reports, dp_params, padding_params).await,
+            2 | 4 => hybrid_protocol::<_, BA8, BA3, HV, 2, 256>(ctx, indistinguishable_reports, dp_params, padding_params).await,
+            8 => hybrid_protocol::<_, BA8, BA3, HV, 3, 256>(ctx, indistinguishable_reports, dp_params, padding_params).await,
+            16 => hybrid_protocol::<_, BA8, BA3, HV, 4, 256>(ctx, indistinguishable_reports, dp_params, padding_params).await,
+            32 => hybrid_protocol::<_, BA8, BA3, HV, 5, 256>(ctx, indistinguishable_reports, dp_params, padding_params).await,
+            64 => hybrid_protocol::<_, BA8, BA3, HV, 6, 256>(ctx, indistinguishable_reports, dp_params, padding_params).await,
+            128 => hybrid_protocol::<_, BA8, BA3, HV, 7, 256>(ctx, indistinguishable_reports, dp_params, padding_params).await,
+            _ => panic!(
+                "Invalid value specified for per-user cap: {:?}. Must be one of 1, 2, 4, 8, 16, 32, 64, or 128.",
+                config.per_user_credit_cap
+            ),
+        }
     }
 }
 
@@ -219,7 +256,7 @@ mod tests {
     // placeholder until the protocol is complete. can be updated to make sure we
     // get to the unimplemented() call
     #[should_panic(
-        expected = "not implemented: query::runnner::HybridQuery.execute is not fully implemented"
+        expected = "not implemented: protocol::hybrid::hybrid_protocol is not fully implemented"
     )]
     async fn encrypted_hybrid_reports() {
         // While this test currently checks for an unimplemented panic it is