From 384e76d56eebf0b24f59edf4ff1e29ca5ea73586 Mon Sep 17 00:00:00 2001 From: danielmasny Date: Wed, 4 Sep 2024 15:55:15 -0700 Subject: [PATCH 1/3] shuffle verification + simple test --- ipa-core/src/error.rs | 2 + ipa-core/src/ff/boolean_array.rs | 9 +- ipa-core/src/protocol/ipa_prf/shuffle/base.rs | 152 ++++++++- .../src/protocol/ipa_prf/shuffle/malicious.rs | 314 ++++++++++++++++++ ipa-core/src/protocol/ipa_prf/shuffle/mod.rs | 6 +- ipa-core/src/protocol/ipa_prf/shuffle/step.rs | 4 + 6 files changed, 470 insertions(+), 17 deletions(-) create mode 100644 ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs diff --git a/ipa-core/src/error.rs b/ipa-core/src/error.rs index 0d74da809..5a9352124 100644 --- a/ipa-core/src/error.rs +++ b/ipa-core/src/error.rs @@ -96,6 +96,8 @@ pub enum Error { EpsilonOutOfBounds, #[error("Missing total records in {0}")] MissingTotalRecords(String), + #[error("The verification of the shuffle failed: {0}")] + ShuffleValidationFailed(String), } impl Default for Error { diff --git a/ipa-core/src/ff/boolean_array.rs b/ipa-core/src/ff/boolean_array.rs index b05c31465..ca505099e 100644 --- a/ipa-core/src/ff/boolean_array.rs +++ b/ipa-core/src/ff/boolean_array.rs @@ -9,7 +9,7 @@ use typenum::{U14, U2, U32, U8}; use crate::{ error::LengthError, - ff::{boolean::Boolean, ArrayAccess, Expand, Field, Serializable, U128Conversions}, + ff::{boolean::Boolean, ArrayAccess, Expand, Field, Gf32Bit, Serializable, U128Conversions}, protocol::prss::{FromRandom, FromRandomU128}, secret_sharing::{Block, SharedValue, StdArray, Vectorizable}, }; @@ -32,7 +32,11 @@ macro_rules! store_impl { } pub trait BooleanArray: - SharedValue + ArrayAccess + Expand + FromIterator + SharedValue + + ArrayAccess + + Expand + + FromIterator + + TryInto, Error = crate::error::Error> { } @@ -41,6 +45,7 @@ impl BooleanArray for A where + ArrayAccess + Expand + FromIterator + + TryInto, Error = crate::error::Error> { } diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/base.rs b/ipa-core/src/protocol/ipa_prf/shuffle/base.rs index 5a0cdea00..b0b3257f9 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/base.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/base.rs @@ -18,7 +18,10 @@ use crate::{ /// # Errors /// Will propagate errors from transport and a few typecasts -pub async fn shuffle(ctx: C, shares: I) -> Result>, Error> +pub async fn shuffle( + ctx: C, + shares: I, +) -> Result<(Vec>, IntermediateShuffleMessages), Error> where C: Context, I: IntoIterator>, @@ -32,7 +35,13 @@ where // This protocol can take a mutable iterator and replace items in the input. let shares = shares.into_iter(); let Some(shares_len) = NonZeroUsize::new(shares.len()) else { - return Ok(vec![]); + return Ok(( + vec![], + IntermediateShuffleMessages { + x1_or_y1: None, + x2_or_y2: None, + }, + )); }; let ctx_z = ctx.narrow(&OPRFShuffleStep::GenerateZ); let zs = generate_random_tables_with_peers(shares_len, &ctx_z); @@ -44,12 +53,46 @@ where } } +#[allow(dead_code)] +/// This struct stores some intermediate messages during the shuffle. +/// In a maliciously secure shuffle, +/// these messages need to be checked for consistency across helpers. +/// `H1` stores `x1`, `H2` stores `x2` and `H3` stores `y1` and `y2`. +#[derive(Debug, Clone)] +pub struct IntermediateShuffleMessages { + x1_or_y1: Option>, + x2_or_y2: Option>, +} + +#[allow(dead_code)] +impl IntermediateShuffleMessages { + /// When `IntermediateShuffleMessages` is initialized correctly, + /// this function returns `x1` when `Role = H1` + /// and `y1` when `Role = H3`. + /// + /// ## Panics + /// Panics when `Role = H2`, i.e. `x1_or_y1` is `None`. + pub fn get_x1_or_y1(&self) -> &Vec { + self.x1_or_y1.as_ref().unwrap() + } + + /// When `IntermediateShuffleMessages` is initialized correctly, + /// this function returns `x2` when `Role = H2` + /// and `y2` when `Role = H3`. + /// + /// ## Panics + /// Panics when `Role = H1`, i.e. `x2_or_y2` is `None`. + pub fn get_x2_or_y2(&self) -> &Vec { + self.x2_or_y2.as_ref().unwrap() + } +} + async fn run_h1( ctx: &C, batch_size: NonZeroUsize, shares: I, (z_31, z_12): (Zl, Zr), -) -> Result>, Error> +) -> Result<(Vec>, IntermediateShuffleMessages), Error> where C: Context, I: IntoIterator>, @@ -76,13 +119,21 @@ where let (mut rng_perm_l, mut rng_perm_r) = ctx_perm.prss_rng(); x_1.shuffle(&mut rng_perm_r); - let mut x_2 = x_1; + // need to output x_1 + let mut x_2 = x_1.clone(); add_single_shares_in_place(&mut x_2, z_31); x_2.shuffle(&mut rng_perm_l); send_to_peer(&x_2, ctx, &OPRFShuffleStep::TransferX2, Direction::Right).await?; let res = combine_single_shares(a_hat, b_hat).collect::>(); - Ok(res) + // we only need to store x_1 in IntermediateShuffleMessage + Ok(( + res, + IntermediateShuffleMessages { + x1_or_y1: Some(x_1), + x2_or_y2: None, + }, + )) } async fn run_h2( @@ -90,7 +141,7 @@ async fn run_h2( batch_size: NonZeroUsize, shares: I, (z_12, z_23): (Zl, Zr), -) -> Result>, Error> +) -> Result<(Vec>, IntermediateShuffleMessages), Error> where C: Context, I: IntoIterator>, @@ -127,7 +178,8 @@ where ) .await?; - let mut x_3 = x_2; + // we need to output x_2 + let mut x_3 = x_2.clone(); add_single_shares_in_place(&mut x_3, z_23); x_3.shuffle(&mut rng_perm_r); @@ -154,14 +206,21 @@ where let c_hat = add_single_shares(c_hat_1.iter(), c_hat_2.iter()); let res = combine_single_shares(b_hat, c_hat).collect::>(); - Ok(res) + // we only need to store x_2 in IntermediateShuffleMessage + Ok(( + res, + IntermediateShuffleMessages { + x1_or_y1: None, + x2_or_y2: Some(x_2), + }, + )) } async fn run_h3( ctx: &C, batch_size: NonZeroUsize, (z_23, z_31): (Zl, Zr), -) -> Result>, Error> +) -> Result<(Vec>, IntermediateShuffleMessages), Error> where C: Context, S: SharedValue + Add, @@ -186,14 +245,16 @@ where ) .await?; - let mut y_2 = y_1; + // need to output y_1 + let mut y_2 = y_1.clone(); add_single_shares_in_place(&mut y_2, z_31); let ctx_perm = ctx.narrow(&OPRFShuffleStep::ApplyPermutations); let (mut rng_perm_l, mut rng_perm_r) = ctx_perm.prss_rng(); y_2.shuffle(&mut rng_perm_r); - let mut y_3 = y_2; + // need to output y_2 + let mut y_3 = y_2.clone(); add_single_shares_in_place(&mut y_3, z_23); y_3.shuffle(&mut rng_perm_l); @@ -218,7 +279,13 @@ where let c_hat = add_single_shares(c_hat_1, c_hat_2); let res = combine_single_shares(c_hat, a_hat).collect::>(); - Ok(res) + Ok(( + res, + IntermediateShuffleMessages { + x1_or_y1: Some(y_1), + x2_or_y2: Some(y_2), + }, + )) } fn add_single_shares(l: L, r: R) -> impl Iterator @@ -343,9 +410,13 @@ where #[cfg(all(test, unit_test))] pub mod tests { + use rand::{thread_rng, Rng}; + use super::shuffle; use crate::{ ff::{Gf40Bit, U128Conversions}, + secret_sharing::replicated::ReplicatedSecretSharing, + test_executor::run, test_fixture::{Reconstruct, Runner, TestWorld, TestWorldConfig}, }; @@ -364,7 +435,7 @@ pub mod tests { // Stable seed is used to get predictable shuffle results. let mut actual = TestWorld::new_with(TestWorldConfig::default().with_seed(123)) .semi_honest(records.clone().into_iter(), |ctx, shares| async move { - shuffle(ctx, shares).await.unwrap() + shuffle(ctx, shares).await.unwrap().0 }) .await .reconstruct(); @@ -381,4 +452,59 @@ pub mod tests { "Shuffle should not change the items in the set" ); } + + #[test] + fn check_intermediate_messages() { + const RECORD_AMOUNT: usize = 100; + run(|| async { + let world = TestWorld::default(); + let mut rng = thread_rng(); + // using Gf40Bit here since it implements cmp such that vec can later be sorted + let mut records = (0..RECORD_AMOUNT) + .map(|_| rng.gen()) + .collect::>(); + + let [h1, h2, h3] = world + .semi_honest(records.clone().into_iter(), |ctx, records| async move { + shuffle(ctx, records).await + }) + .await; + + // check consistency + // i.e. x_1 xor y_1 = x_2 xor y_2 = C xor A xor B + let (h1_shares, h1_messages) = h1.unwrap(); + let (_, h2_messages) = h2.unwrap(); + let (h3_shares, h3_messages) = h3.unwrap(); + + let mut x1_xor_y1 = h1_messages + .x1_or_y1 + .unwrap() + .iter() + .zip(h3_messages.x1_or_y1.unwrap()) + .map(|(x1, y1)| x1 + y1) + .collect::>(); + let mut x2_xor_y2 = h2_messages + .x2_or_y2 + .unwrap() + .iter() + .zip(h3_messages.x2_or_y2.unwrap()) + .map(|(x2, y2)| x2 + y2) + .collect::>(); + let mut a_xor_b_xor_c = h1_shares + .iter() + .zip(h3_shares) + .map(|(h1_share, h3_share)| h1_share.left() + h1_share.right() + h3_share.left()) + .collect::>(); + + // unshuffle by sorting + records.sort(); + x1_xor_y1.sort(); + x2_xor_y2.sort(); + a_xor_b_xor_c.sort(); + + assert_eq!(records, a_xor_b_xor_c); + assert_eq!(records, x1_xor_y1); + assert_eq!(records, x2_xor_y2); + }); + } } diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs new file mode 100644 index 000000000..a3bd8272b --- /dev/null +++ b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs @@ -0,0 +1,314 @@ +use std::{borrow::Borrow, iter}; + +use futures_util::future::{try_join, try_join3}; + +use crate::{ + error::Error, + ff::{boolean_array::BooleanArray, Field, Gf32Bit}, + helpers::{ + hashing::{compute_hash, Hash}, + Direction, Role, + }, + protocol::{ + basics::malicious_reveal, + context::Context, + ipa_prf::shuffle::{base::IntermediateShuffleMessages, step::OPRFShuffleStep}, + RecordId, + }, + secret_sharing::{ + replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing}, + SharedValue, StdArray, + }, +}; + +/// This function verifies the `shuffled_shares` and the `IntermediateShuffleMessages`. +/// +/// ## Errors +/// Propagates network errors. +/// Further, returns an error when messages are inconsistent with the MAC tags. +async fn verify_shuffle( + ctx: C, + key_shares: &[AdditiveShare], + shuffled_shares: &[AdditiveShare], + messages: IntermediateShuffleMessages, +) -> Result<(), Error> { + // reveal keys + let k_ctx = ctx.narrow(&OPRFShuffleStep::RevealMACKey).set_total_records(key_shares.len()); + let keys = reveal_keys(&k_ctx, key_shares).await?; + + // verify messages and shares + match ctx.role() { + Role::H1 => h1_verify(ctx, &keys, shuffled_shares, messages.get_x1_or_y1()).await, + Role::H2 => h2_verify(ctx, &keys, shuffled_shares, messages.get_x2_or_y2()).await, + Role::H3 => { + h3_verify( + ctx, + &keys, + shuffled_shares, + messages.get_x1_or_y1(), + messages.get_x2_or_y2(), + ) + .await + } + } +} + +/// This is the verification function run by `H1`. +/// `H1` computes the hash for `x1` and `a_xor_b`. +/// Further, he receives `hash_y1` and `hash_c_h3` from `H3` +/// and `hash_c_h2` from `H2`. +/// +/// ## Errors +/// Propagates network errors. Further it returns an error when +/// `hash_x1 != hash_y1` or `hash_c_h2 != hash_a_xor_b` +/// or `hash_c_h3 != hash_a_xor_b`. +async fn h1_verify( + ctx: C, + keys: &[StdArray], + share_a_and_b: &[AdditiveShare], + x1: &[S], +) -> Result<(), Error> { + // compute hashes + // compute hash for x1 + let hash_x1 = compute_row_hash::(&keys, x1); + // compute hash for A xor B + let hash_a_xor_b = compute_row_hash::( + &keys, + share_a_and_b + .iter() + .map(|share| share.left() + share.right()), + ); + + // setup channels + let h3_ctx = ctx + .narrow(&OPRFShuffleStep::HashesH3toH1) + .set_total_records(2); + let h2_ctx = ctx + .narrow(&OPRFShuffleStep::HashH2toH1) + .set_total_records(1); + let channel_h3 = &h3_ctx.recv_channel::(ctx.role().peer(Direction::Left)); + let channel_h2 = &h2_ctx.recv_channel::(ctx.role().peer(Direction::Right)); + + // receive hashes + let (hashes_h3, hash_h2) = try_join( + h3_ctx.parallel_join( + (0..=1).map(|i| async move { channel_h3.receive(RecordId::from(i)).await }), + ), + channel_h2.receive(RecordId::FIRST), + ) + .await?; + + // check y1 + if hash_x1 != hashes_h3[0] { + return Err(Error::ShuffleValidationFailed(format!( + "Y1 is inconsistent: hash of x1: {:?}, hash of y1: {:?}", + hash_x1, hashes_h3[0] + ))); + } + + // check c from h3 + if hash_a_xor_b != hashes_h3[1] { + return Err(Error::ShuffleValidationFailed(format!( + "C from H3 is inconsistent: hash of a_xor_b: {:?}, hash of C: {:?}", + hash_a_xor_b, hashes_h3[1] + ))); + } + + // check h2 + if hash_a_xor_b != hash_h2 { + return Err(Error::ShuffleValidationFailed(format!( + "C from H2 is inconsistent: hash of a_xor_b: {:?}, hash of C: {:?}", + hash_a_xor_b, hash_h2 + ))); + } + + Ok(()) +} + +/// This is the verification function run by `H2`. +/// `H2` computes the hash for `x2` and `c` +/// and sends the latter to `H1`. +/// Further, he receives `hash_y2` from `H3` +/// +/// ## Errors +/// Propagates network errors. Further it returns an error when +/// `hash_x2 != hash_y2`. +async fn h2_verify( + ctx: C, + keys: &[StdArray], + share_b_and_c: &[AdditiveShare], + x2: &[S], +) -> Result<(), Error> { + // compute hashes + // compute hash for x2 + let hash_x2 = compute_row_hash::(&keys, x2); + // compute hash for C + let hash_c = + compute_row_hash::(&keys, share_b_and_c.iter().map(|share| share.right())); + + // setup channels + let h1_ctx = ctx + .narrow(&OPRFShuffleStep::HashH2toH1) + .set_total_records(1); + let h3_ctx = ctx + .narrow(&OPRFShuffleStep::HashH3toH2) + .set_total_records(1); + let channel_h1 = &h1_ctx.send_channel::(ctx.role().peer(Direction::Left)); + let channel_h3 = &h3_ctx.recv_channel::(ctx.role().peer(Direction::Right)); + + // send and receive hash + let (_, hash_h3) = try_join( + channel_h1.send(RecordId::FIRST, hash_c), + channel_h3.receive(RecordId::FIRST), + ) + .await?; + + // check x2 + if hash_x2 != hash_h3 { + return Err(Error::ShuffleValidationFailed(format!( + "X2 is inconsistent: hash of x2: {:?}, hash of y2: {:?}", + hash_x2, hash_h3 + ))); + } + + Ok(()) +} + +/// This is the verification function run by `H3`. +/// `H3` computes the hash for `y1`, `y2` and `c` +/// and sends `y1`, `c` to `H1` and `y2` to `H2`. +/// +/// ## Errors +/// Propagates network errors. +async fn h3_verify( + ctx: C, + keys: &[StdArray], + share_c_and_a: &[AdditiveShare], + y1: &[S], + y2: &[S], +) -> Result<(), Error> { + // compute hashes + // compute hash for y1 + let hash_y1 = compute_row_hash::(&keys, y1); + // compute hash for y2 + let hash_y2 = compute_row_hash::(&keys, y2); + // compute hash for C + let hash_c = compute_row_hash::(&keys, share_c_and_a.iter().map(|share| share.left())); + + // setup channels + let h1_ctx = ctx + .narrow(&OPRFShuffleStep::HashesH3toH1) + .set_total_records(2); + let h2_ctx = ctx + .narrow(&OPRFShuffleStep::HashH3toH2) + .set_total_records(1); + let channel_h1 = &h1_ctx.send_channel::(ctx.role().peer(Direction::Right)); + let channel_h2 = &h2_ctx.send_channel::(ctx.role().peer(Direction::Left)); + + // send and receive hash + let _ = try_join3( + channel_h1.send(RecordId::FIRST, hash_y1), + channel_h1.send(RecordId::from(1usize), hash_c), + channel_h2.send(RecordId::FIRST, hash_y2), + ) + .await?; + + Ok(()) +} + +/// This function computes for each item in the iterator the inner product with `keys`. +/// It concatenates all inner products and hashes them. +/// +/// ## Panics +/// Panics when conversion from `BooleanArray` to `Vec(keys: &[StdArray], row_iterator: I) -> Hash +where + S: BooleanArray, + B: Borrow, + I: IntoIterator, +{ + let iterator = row_iterator + .into_iter() + .map(|s| (*(s.borrow())).try_into().unwrap()); + compute_hash(iterator.map(|row| { + row.iter() + .zip(keys) + .fold(Gf32Bit::ZERO, |acc, (row_entry, key)| { + acc + *row_entry * *key.first() + }) + })) +} + +/// This function reveals the MAC keys, +/// stores them in a vector +/// and appends a `Gf32Bit::ONE` +/// +/// It uses `parallel_join` and therefore vector elements are a `StdArray` of length `1`. +/// +/// ## Errors +/// Propagates errors from `parallel_join` and `malicious_reveal`. +async fn reveal_keys( + ctx: &C, + key_shares: &[AdditiveShare], +) -> Result>, Error> { + // reveal MAC keys + let mut keys = ctx + .parallel_join(key_shares.iter().enumerate().map(|(i, key)| async move { + malicious_reveal(ctx.clone(), RecordId::from(i), None, key).await + })) + .await? + .into_iter() + .flatten() + .collect::>(); + // add a one, since last row element is tag which is not multiplied with a key + keys.push(StdArray::from_iter(iter::once(Gf32Bit::ONE))); + Ok(keys) +} + + +#[cfg(all(test, unit_test))] +mod tests { + use rand::{Rng, thread_rng}; + use crate::ff::boolean_array::{BA64}; + use crate::ff::Serializable; + use crate::protocol::ipa_prf::shuffle::base::shuffle; + use crate::test_executor::run; + use crate::test_fixture::{Runner, TestWorld}; + use super::*; + + /// This test checks the correctness of the malicious shuffle + /// when all parties behave honestly + /// and all the MAC keys are `Gf32Bit::ONE`. + /// Further, each row consists of a `BA32` and a `BA32` tag. + #[test] + fn check_shuffle_with_simple_mac() { + const RECORD_AMOUNT: usize = 10; + run(|| async { + let world = TestWorld::default(); + let mut rng = thread_rng(); + let records = (0..RECORD_AMOUNT) + .map(|_| { + let entry = rng.gen::<[u8;4]>(); + let mut entry_and_tag = [0u8;8]; + entry_and_tag[0..4].copy_from_slice(&entry); + entry_and_tag[4..8].copy_from_slice(&entry); + BA64::deserialize_from_slice(&entry_and_tag) + }) + .collect::>(); + + let _ = world + .semi_honest( + records.into_iter(), + |ctx, rows, | async move { + // trivial shares of Gf32Bit::ONE + let key_shares = vec![AdditiveShare::new(Gf32Bit::ONE,Gf32Bit::ONE);1]; + // run shuffle + let (shares,messages) = shuffle(ctx.narrow("shuffle"), rows).await.unwrap(); + // verify it + verify_shuffle(ctx.narrow("verify"),&key_shares,&shares, messages).await.unwrap(); + }) + .await; + }); + } +} + diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs b/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs index e0868fb96..2908bf066 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs @@ -22,6 +22,8 @@ use crate::{ }; pub mod base; +#[allow(dead_code)] +pub mod malicious; #[cfg(descriptive_gate)] mod sharded; pub(crate) mod step; @@ -42,7 +44,7 @@ where .map(|item| oprfreport_to_shuffle_input::(&item)) .collect::>(); - let shuffled = shuffle(ctx, shuffle_input).await?; + let (shuffled, _) = shuffle(ctx, shuffle_input).await?; Ok(shuffled .into_iter() @@ -69,7 +71,7 @@ where .map(|item| attribution_outputs_to_shuffle_input::(&item)) .collect::>(); - let shuffled = shuffle(ctx, shuffle_input).await?; + let (shuffled, _) = shuffle(ctx, shuffle_input).await?; Ok(shuffled .into_iter() diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/step.rs b/ipa-core/src/protocol/ipa_prf/shuffle/step.rs index e5014d775..c9de371b3 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/step.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/step.rs @@ -9,4 +9,8 @@ pub(crate) enum OPRFShuffleStep { TransferCHat, TransferX2, TransferY1, + RevealMACKey, + HashesH3toH1, + HashH2toH1, + HashH3toH2, } From e05fb5e739414e93ec8169f15ba475894beb538f Mon Sep 17 00:00:00 2001 From: danielmasny Date: Wed, 4 Sep 2024 16:14:34 -0700 Subject: [PATCH 2/3] clean up shuffle verification --- .../src/protocol/ipa_prf/shuffle/malicious.rs | 97 ++++++++++--------- 1 file changed, 51 insertions(+), 46 deletions(-) diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs index a3bd8272b..a8f368b35 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs @@ -7,7 +7,7 @@ use crate::{ ff::{boolean_array::BooleanArray, Field, Gf32Bit}, helpers::{ hashing::{compute_hash, Hash}, - Direction, Role, + Direction, Role, TotalRecords, }, protocol::{ basics::malicious_reveal, @@ -33,7 +33,9 @@ async fn verify_shuffle( messages: IntermediateShuffleMessages, ) -> Result<(), Error> { // reveal keys - let k_ctx = ctx.narrow(&OPRFShuffleStep::RevealMACKey).set_total_records(key_shares.len()); + let k_ctx = ctx + .narrow(&OPRFShuffleStep::RevealMACKey) + .set_total_records(TotalRecords::specified(key_shares.len())?); let keys = reveal_keys(&k_ctx, key_shares).await?; // verify messages and shares @@ -70,10 +72,10 @@ async fn h1_verify( ) -> Result<(), Error> { // compute hashes // compute hash for x1 - let hash_x1 = compute_row_hash::(&keys, x1); + let hash_x1 = compute_row_hash::(keys, x1); // compute hash for A xor B let hash_a_xor_b = compute_row_hash::( - &keys, + keys, share_a_and_b .iter() .map(|share| share.left() + share.right()), @@ -82,17 +84,17 @@ async fn h1_verify( // setup channels let h3_ctx = ctx .narrow(&OPRFShuffleStep::HashesH3toH1) - .set_total_records(2); + .set_total_records(TotalRecords::specified(2)?); let h2_ctx = ctx .narrow(&OPRFShuffleStep::HashH2toH1) - .set_total_records(1); + .set_total_records(TotalRecords::specified(1)?); let channel_h3 = &h3_ctx.recv_channel::(ctx.role().peer(Direction::Left)); let channel_h2 = &h2_ctx.recv_channel::(ctx.role().peer(Direction::Right)); // receive hashes let (hashes_h3, hash_h2) = try_join( h3_ctx.parallel_join( - (0..=1).map(|i| async move { channel_h3.receive(RecordId::from(i)).await }), + (0usize..=1).map(|i| async move { channel_h3.receive(RecordId::from(i)).await }), ), channel_h2.receive(RecordId::FIRST), ) @@ -101,24 +103,23 @@ async fn h1_verify( // check y1 if hash_x1 != hashes_h3[0] { return Err(Error::ShuffleValidationFailed(format!( - "Y1 is inconsistent: hash of x1: {:?}, hash of y1: {:?}", - hash_x1, hashes_h3[0] + "Y1 is inconsistent: hash of x1: {hash_x1:?}, hash of y1: {:?}", + hashes_h3[0] ))); } // check c from h3 if hash_a_xor_b != hashes_h3[1] { return Err(Error::ShuffleValidationFailed(format!( - "C from H3 is inconsistent: hash of a_xor_b: {:?}, hash of C: {:?}", - hash_a_xor_b, hashes_h3[1] + "C from H3 is inconsistent: hash of a_xor_b: {hash_a_xor_b:?}, hash of C: {:?}", + hashes_h3[1] ))); } // check h2 if hash_a_xor_b != hash_h2 { return Err(Error::ShuffleValidationFailed(format!( - "C from H2 is inconsistent: hash of a_xor_b: {:?}, hash of C: {:?}", - hash_a_xor_b, hash_h2 + "C from H2 is inconsistent: hash of a_xor_b: {hash_a_xor_b:?}, hash of C: {hash_h2:?}" ))); } @@ -141,23 +142,25 @@ async fn h2_verify( ) -> Result<(), Error> { // compute hashes // compute hash for x2 - let hash_x2 = compute_row_hash::(&keys, x2); + let hash_x2 = compute_row_hash::(keys, x2); // compute hash for C - let hash_c = - compute_row_hash::(&keys, share_b_and_c.iter().map(|share| share.right())); + let hash_c = compute_row_hash::( + keys, + share_b_and_c.iter().map(ReplicatedSecretSharing::right), + ); // setup channels let h1_ctx = ctx .narrow(&OPRFShuffleStep::HashH2toH1) - .set_total_records(1); + .set_total_records(TotalRecords::specified(1)?); let h3_ctx = ctx .narrow(&OPRFShuffleStep::HashH3toH2) - .set_total_records(1); + .set_total_records(TotalRecords::specified(1)?); let channel_h1 = &h1_ctx.send_channel::(ctx.role().peer(Direction::Left)); let channel_h3 = &h3_ctx.recv_channel::(ctx.role().peer(Direction::Right)); // send and receive hash - let (_, hash_h3) = try_join( + let ((), hash_h3) = try_join( channel_h1.send(RecordId::FIRST, hash_c), channel_h3.receive(RecordId::FIRST), ) @@ -166,8 +169,7 @@ async fn h2_verify( // check x2 if hash_x2 != hash_h3 { return Err(Error::ShuffleValidationFailed(format!( - "X2 is inconsistent: hash of x2: {:?}, hash of y2: {:?}", - hash_x2, hash_h3 + "X2 is inconsistent: hash of x2: {hash_x2:?}, hash of y2: {hash_h3:?}" ))); } @@ -189,19 +191,22 @@ async fn h3_verify( ) -> Result<(), Error> { // compute hashes // compute hash for y1 - let hash_y1 = compute_row_hash::(&keys, y1); + let hash_y1 = compute_row_hash::(keys, y1); // compute hash for y2 - let hash_y2 = compute_row_hash::(&keys, y2); + let hash_y2 = compute_row_hash::(keys, y2); // compute hash for C - let hash_c = compute_row_hash::(&keys, share_c_and_a.iter().map(|share| share.left())); + let hash_c = compute_row_hash::( + keys, + share_c_and_a.iter().map(ReplicatedSecretSharing::left), + ); // setup channels let h1_ctx = ctx .narrow(&OPRFShuffleStep::HashesH3toH1) - .set_total_records(2); + .set_total_records(TotalRecords::specified(2)?); let h2_ctx = ctx .narrow(&OPRFShuffleStep::HashH3toH2) - .set_total_records(1); + .set_total_records(TotalRecords::specified(1)?); let channel_h1 = &h1_ctx.send_channel::(ctx.role().peer(Direction::Right)); let channel_h2 = &h2_ctx.send_channel::(ctx.role().peer(Direction::Left)); @@ -261,20 +266,21 @@ async fn reveal_keys( .flatten() .collect::>(); // add a one, since last row element is tag which is not multiplied with a key - keys.push(StdArray::from_iter(iter::once(Gf32Bit::ONE))); + keys.push(iter::once(Gf32Bit::ONE).collect()); Ok(keys) } - #[cfg(all(test, unit_test))] mod tests { - use rand::{Rng, thread_rng}; - use crate::ff::boolean_array::{BA64}; - use crate::ff::Serializable; - use crate::protocol::ipa_prf::shuffle::base::shuffle; - use crate::test_executor::run; - use crate::test_fixture::{Runner, TestWorld}; + use rand::{thread_rng, Rng}; + use super::*; + use crate::{ + ff::{boolean_array::BA64, Serializable}, + protocol::ipa_prf::shuffle::base::shuffle, + test_executor::run, + test_fixture::{Runner, TestWorld}, + }; /// This test checks the correctness of the malicious shuffle /// when all parties behave honestly @@ -288,8 +294,8 @@ mod tests { let mut rng = thread_rng(); let records = (0..RECORD_AMOUNT) .map(|_| { - let entry = rng.gen::<[u8;4]>(); - let mut entry_and_tag = [0u8;8]; + let entry = rng.gen::<[u8; 4]>(); + let mut entry_and_tag = [0u8; 8]; entry_and_tag[0..4].copy_from_slice(&entry); entry_and_tag[4..8].copy_from_slice(&entry); BA64::deserialize_from_slice(&entry_and_tag) @@ -297,18 +303,17 @@ mod tests { .collect::>(); let _ = world - .semi_honest( - records.into_iter(), - |ctx, rows, | async move { - // trivial shares of Gf32Bit::ONE - let key_shares = vec![AdditiveShare::new(Gf32Bit::ONE,Gf32Bit::ONE);1]; - // run shuffle - let (shares,messages) = shuffle(ctx.narrow("shuffle"), rows).await.unwrap(); - // verify it - verify_shuffle(ctx.narrow("verify"),&key_shares,&shares, messages).await.unwrap(); + .semi_honest(records.into_iter(), |ctx, rows| async move { + // trivial shares of Gf32Bit::ONE + let key_shares = vec![AdditiveShare::new(Gf32Bit::ONE, Gf32Bit::ONE); 1]; + // run shuffle + let (shares, messages) = shuffle(ctx.narrow("shuffle"), rows).await.unwrap(); + // verify it + verify_shuffle(ctx.narrow("verify"), &key_shares, &shares, messages) + .await + .unwrap(); }) .await; }); } } - From b7194fe55f9e5a0ddc63e7a0becc3a92448899a3 Mon Sep 17 00:00:00 2001 From: danielmasny Date: Fri, 6 Sep 2024 15:28:38 -0700 Subject: [PATCH 3/3] improving shuffle verification using Alex's suggestions --- ipa-core/src/protocol/ipa_prf/shuffle/base.rs | 18 ++++- .../src/protocol/ipa_prf/shuffle/malicious.rs | 73 ++++++++----------- 2 files changed, 46 insertions(+), 45 deletions(-) diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/base.rs b/ipa-core/src/protocol/ipa_prf/shuffle/base.rs index 8020bfebf..a34477d69 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/base.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/base.rs @@ -72,8 +72,8 @@ impl IntermediateShuffleMessages { /// /// ## Panics /// Panics when `Role = H2`, i.e. `x1_or_y1` is `None`. - pub fn get_x1_or_y1(&self) -> &Vec { - self.x1_or_y1.as_ref().unwrap() + pub fn get_x1_or_y1(self) -> Vec { + self.x1_or_y1.unwrap() } /// When `IntermediateShuffleMessages` is initialized correctly, @@ -82,8 +82,18 @@ impl IntermediateShuffleMessages { /// /// ## Panics /// Panics when `Role = H1`, i.e. `x2_or_y2` is `None`. - pub fn get_x2_or_y2(&self) -> &Vec { - self.x2_or_y2.as_ref().unwrap() + pub fn get_x2_or_y2(self) -> Vec { + self.x2_or_y2.unwrap() + } + + /// When `IntermediateShuffleMessages` is initialized correctly, + /// this function returns `y1` and `y2` when `Role = H3`. + /// + /// ## Panics + /// Panics when `Role = H1`, i.e. `x2_or_y2` is `None` or + /// when `Role = H2`, i.e. `x1_or_y1` is `None`. + pub fn get_both_x_or_ys(self) -> (Vec, Vec) { + (self.x1_or_y1.unwrap(), self.x2_or_y2.unwrap()) } } diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs index a8f368b35..a58fcae58 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs @@ -1,4 +1,4 @@ -use std::{borrow::Borrow, iter}; +use std::iter; use futures_util::future::{try_join, try_join3}; @@ -17,7 +17,7 @@ use crate::{ }, secret_sharing::{ replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing}, - SharedValue, StdArray, + SharedValue, SharedValueArray, StdArray, }, }; @@ -43,14 +43,8 @@ async fn verify_shuffle( Role::H1 => h1_verify(ctx, &keys, shuffled_shares, messages.get_x1_or_y1()).await, Role::H2 => h2_verify(ctx, &keys, shuffled_shares, messages.get_x2_or_y2()).await, Role::H3 => { - h3_verify( - ctx, - &keys, - shuffled_shares, - messages.get_x1_or_y1(), - messages.get_x2_or_y2(), - ) - .await + let (y1, y2) = messages.get_both_x_or_ys(); + h3_verify(ctx, &keys, shuffled_shares, y1, y2).await } } } @@ -68,13 +62,13 @@ async fn h1_verify( ctx: C, keys: &[StdArray], share_a_and_b: &[AdditiveShare], - x1: &[S], + x1: Vec, ) -> Result<(), Error> { // compute hashes // compute hash for x1 - let hash_x1 = compute_row_hash::(keys, x1); + let hash_x1 = compute_row_hash(keys, x1); // compute hash for A xor B - let hash_a_xor_b = compute_row_hash::( + let hash_a_xor_b = compute_row_hash( keys, share_a_and_b .iter() @@ -87,32 +81,29 @@ async fn h1_verify( .set_total_records(TotalRecords::specified(2)?); let h2_ctx = ctx .narrow(&OPRFShuffleStep::HashH2toH1) - .set_total_records(TotalRecords::specified(1)?); + .set_total_records(TotalRecords::ONE); let channel_h3 = &h3_ctx.recv_channel::(ctx.role().peer(Direction::Left)); let channel_h2 = &h2_ctx.recv_channel::(ctx.role().peer(Direction::Right)); // receive hashes - let (hashes_h3, hash_h2) = try_join( - h3_ctx.parallel_join( - (0usize..=1).map(|i| async move { channel_h3.receive(RecordId::from(i)).await }), - ), + let (hash_y1, hash_h3, hash_h2) = try_join3( + channel_h3.receive(RecordId::FIRST), + channel_h3.receive(RecordId::from(1usize)), channel_h2.receive(RecordId::FIRST), ) .await?; // check y1 - if hash_x1 != hashes_h3[0] { + if hash_x1 != hash_y1 { return Err(Error::ShuffleValidationFailed(format!( - "Y1 is inconsistent: hash of x1: {hash_x1:?}, hash of y1: {:?}", - hashes_h3[0] + "Y1 is inconsistent: hash of x1: {hash_x1:?}, hash of y1: {hash_y1:?}" ))); } // check c from h3 - if hash_a_xor_b != hashes_h3[1] { + if hash_a_xor_b != hash_h3 { return Err(Error::ShuffleValidationFailed(format!( - "C from H3 is inconsistent: hash of a_xor_b: {hash_a_xor_b:?}, hash of C: {:?}", - hashes_h3[1] + "C from H3 is inconsistent: hash of a_xor_b: {hash_a_xor_b:?}, hash of C: {hash_h3:?}" ))); } @@ -138,13 +129,13 @@ async fn h2_verify( ctx: C, keys: &[StdArray], share_b_and_c: &[AdditiveShare], - x2: &[S], + x2: Vec, ) -> Result<(), Error> { // compute hashes // compute hash for x2 - let hash_x2 = compute_row_hash::(keys, x2); + let hash_x2 = compute_row_hash(keys, x2); // compute hash for C - let hash_c = compute_row_hash::( + let hash_c = compute_row_hash( keys, share_b_and_c.iter().map(ReplicatedSecretSharing::right), ); @@ -186,16 +177,16 @@ async fn h3_verify( ctx: C, keys: &[StdArray], share_c_and_a: &[AdditiveShare], - y1: &[S], - y2: &[S], + y1: Vec, + y2: Vec, ) -> Result<(), Error> { // compute hashes // compute hash for y1 - let hash_y1 = compute_row_hash::(keys, y1); + let hash_y1 = compute_row_hash(keys, y1); // compute hash for y2 - let hash_y2 = compute_row_hash::(keys, y2); + let hash_y2 = compute_row_hash(keys, y2); // compute hash for C - let hash_c = compute_row_hash::( + let hash_c = compute_row_hash( keys, share_c_and_a.iter().map(ReplicatedSecretSharing::left), ); @@ -226,20 +217,19 @@ async fn h3_verify( /// /// ## Panics /// Panics when conversion from `BooleanArray` to `Vec(keys: &[StdArray], row_iterator: I) -> Hash +fn compute_row_hash(keys: &[StdArray], row_iterator: I) -> Hash where S: BooleanArray, - B: Borrow, - I: IntoIterator, + I: IntoIterator, { let iterator = row_iterator .into_iter() - .map(|s| (*(s.borrow())).try_into().unwrap()); + .map(|row| >>::try_into(row).unwrap()); compute_hash(iterator.map(|row| { - row.iter() + row.into_iter() .zip(keys) .fold(Gf32Bit::ZERO, |acc, (row_entry, key)| { - acc + *row_entry * *key.first() + acc + row_entry * *key.first() }) })) } @@ -257,16 +247,17 @@ async fn reveal_keys( key_shares: &[AdditiveShare], ) -> Result>, Error> { // reveal MAC keys - let mut keys = ctx + let keys = ctx .parallel_join(key_shares.iter().enumerate().map(|(i, key)| async move { malicious_reveal(ctx.clone(), RecordId::from(i), None, key).await })) .await? .into_iter() .flatten() + // add a one, since last row element is tag which is not multiplied with a key + .chain(iter::once(StdArray::from_fn(|_| Gf32Bit::ONE))) .collect::>(); - // add a one, since last row element is tag which is not multiplied with a key - keys.push(iter::once(Gf32Bit::ONE).collect()); + Ok(keys) }