From 21e954d62215cf6940edc700857b877ee958e103 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Thu, 24 Oct 2024 13:12:51 -0700 Subject: [PATCH 1/2] Malicious sharded shuffle --- ipa-core/src/protocol/context/malicious.rs | 29 +- ipa-core/src/protocol/context/semi_honest.rs | 1 + ipa-core/src/protocol/ipa_prf/shuffle/base.rs | 26 +- .../src/protocol/ipa_prf/shuffle/malicious.rs | 367 ++++++++++++++++-- ipa-core/src/protocol/ipa_prf/shuffle/mod.rs | 70 +++- .../src/protocol/ipa_prf/shuffle/sharded.rs | 196 +++++----- ipa-core/src/protocol/ipa_prf/shuffle/step.rs | 33 +- ipa-core/src/protocol/step.rs | 2 + 8 files changed, 539 insertions(+), 185 deletions(-) diff --git a/ipa-core/src/protocol/context/malicious.rs b/ipa-core/src/protocol/context/malicious.rs index 0253a810d..3f723ecc8 100644 --- a/ipa-core/src/protocol/context/malicious.rs +++ b/ipa-core/src/protocol/context/malicious.rs @@ -9,7 +9,10 @@ use ipa_step::{Step, StepNarrow}; use crate::{ error::Error, - helpers::{Gateway, MpcMessage, MpcReceivingEnd, Role, SendingEnd, TotalRecords}, + helpers::{ + Gateway, Message, MpcMessage, MpcReceivingEnd, Role, SendingEnd, ShardReceivingEnd, + TotalRecords, + }, protocol::{ basics::mul::{semi_honest_multiply, step::MaliciousMultiplyStep::RandomnessForValidation}, context::{ @@ -19,7 +22,7 @@ use crate::{ step::UpgradeStep, upgrade::Upgradable, validator::{self, BatchValidator}, - Base, Context as ContextTrait, InstrumentedSequentialSharedRandomness, + Base, Context as ContextTrait, InstrumentedSequentialSharedRandomness, ShardedContext, SpecialAccessToUpgradedContext, UpgradableContext, UpgradedContext, }, prss::{Endpoint as PrssEndpoint, FromPrss}, @@ -30,7 +33,7 @@ use crate::{ semi_honest::AdditiveShare as Replicated, }, seq_join::SeqJoin, - sharding::{NotSharded, ShardBinding}, + sharding::{NotSharded, ShardBinding, ShardConfiguration, ShardIndex, Sharded}, sync::Arc, }; @@ -53,6 +56,26 @@ pub struct Context<'a, B: ShardBinding> { inner: Base<'a, B>, } +impl ShardConfiguration for Context<'_, Sharded> { + fn shard_id(&self) -> ShardIndex { + self.inner.shard_id() + } + + fn shard_count(&self) -> ShardIndex { + self.inner.shard_count() + } +} + +impl ShardedContext for Context<'_, Sharded> { + fn shard_send_channel(&self, dest_shard: ShardIndex) -> SendingEnd { + self.inner.shard_send_channel(dest_shard) + } + + fn shard_recv_channel(&self, origin: ShardIndex) -> ShardReceivingEnd { + self.inner.shard_recv_channel(origin) + } +} + impl<'a> Context<'a, NotSharded> { pub fn new(participant: &'a PrssEndpoint, gateway: &'a Gateway) -> Self { Self::new_with_gate(participant, gateway, Gate::default(), NotSharded) diff --git a/ipa-core/src/protocol/context/semi_honest.rs b/ipa-core/src/protocol/context/semi_honest.rs index bd8c2e260..e1b2f7d8e 100644 --- a/ipa-core/src/protocol/context/semi_honest.rs +++ b/ipa-core/src/protocol/context/semi_honest.rs @@ -37,6 +37,7 @@ use crate::{ pub struct Context<'a, B: ShardBinding> { inner: Base<'a, B>, } + impl ShardConfiguration for Context<'_, Sharded> { fn shard_id(&self) -> ShardIndex { self.inner.shard_id() diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/base.rs b/ipa-core/src/protocol/ipa_prf/shuffle/base.rs index 6b495eff7..12ebaed7d 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/base.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/base.rs @@ -92,7 +92,7 @@ where let mut x_2 = x_1.clone(); add_single_shares_in_place(&mut x_2, z_31); x_2.shuffle(&mut rng_perm_l); - send_to_peer(&x_2, ctx, &OPRFShuffleStep::TransferX2, Direction::Right).await?; + send_to_peer(&x_2, ctx, &OPRFShuffleStep::TransferXY, Direction::Right).await?; let res = combine_single_shares(a_hat, b_hat).collect::>(); // we only need to store x_1 in IntermediateShuffleMessage @@ -130,12 +130,12 @@ where let mut x_2: Vec = Vec::with_capacity(batch_size.get()); future::try_join( - send_to_peer(&y_1, ctx, &OPRFShuffleStep::TransferY1, Direction::Right), + send_to_peer(&y_1, ctx, &OPRFShuffleStep::TransferXY, Direction::Right), receive_from_peer_into( &mut x_2, batch_size, ctx, - &OPRFShuffleStep::TransferX2, + &OPRFShuffleStep::TransferXY, Direction::Left, ), ) @@ -153,17 +153,12 @@ where let mut c_hat_2 = repurpose_allocation(x_3); future::try_join( - send_to_peer( - &c_hat_1, - ctx, - &OPRFShuffleStep::TransferCHat, - Direction::Right, - ), + send_to_peer(&c_hat_1, ctx, &OPRFShuffleStep::TransferC, Direction::Right), receive_from_peer_into( &mut c_hat_2, batch_size, ctx, - &OPRFShuffleStep::TransferCHat, + &OPRFShuffleStep::TransferC, Direction::Right, ), ) @@ -199,7 +194,7 @@ where &mut y_1, batch_size, ctx, - &OPRFShuffleStep::TransferY1, + &OPRFShuffleStep::TransferXY, Direction::Left, ) .await?; @@ -224,17 +219,12 @@ where let c_hat_2: Vec = add_single_shares(y_3.iter(), a_hat.iter()).collect(); let mut c_hat_1 = repurpose_allocation(y_3); future::try_join( - send_to_peer( - &c_hat_2, - ctx, - &OPRFShuffleStep::TransferCHat, - Direction::Left, - ), + send_to_peer(&c_hat_2, ctx, &OPRFShuffleStep::TransferC, Direction::Left), receive_from_peer_into( &mut c_hat_1, batch_size, ctx, - &OPRFShuffleStep::TransferCHat, + &OPRFShuffleStep::TransferC, Direction::Left, ), ) diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs index 40f973abf..72493fc2d 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs @@ -1,6 +1,6 @@ use std::{iter, ops::Add}; -use futures::stream::TryStreamExt; +use futures::{stream::TryStreamExt, StreamExt}; use futures_util::{ future::{try_join, try_join3}, stream::iter, @@ -13,13 +13,14 @@ use crate::{ ff::{boolean_array::BooleanArray, Field, Gf32Bit, Serializable}, helpers::{ hashing::{compute_possibly_empty_hash, Hash}, - Direction, TotalRecords, + Direction, Role, TotalRecords, }, protocol::{ basics::{malicious_reveal, mul::semi_honest_multiply}, - context::Context, + context::{Context, ShardedContext}, ipa_prf::shuffle::{ - shuffle_protocol, + base::shuffle_protocol, + sharded::{h1_shuffle_for_shard, h2_shuffle_for_shard, h3_shuffle_for_shard}, step::{OPRFShuffleStep, VerifyShuffleStep}, IntermediateShuffleMessages, }, @@ -28,9 +29,10 @@ use crate::{ }, secret_sharing::{ replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing}, - SharedValue, SharedValueArray, StdArray, + SharedValue, }, seq_join::seq_join, + sharding::ShardIndex, }; /// This function executes the maliciously secure shuffle protocol on the input: `shares`. @@ -85,6 +87,92 @@ where Ok(truncate_tags(&shuffled_shares)) } +async fn setup_keys(ctx: C, amount_of_keys: usize) -> Result>, Error> +where + C: ShardedContext, +{ + // We reshuffle among the shards, so all the shards need to use the same MAC keys. + // The first shard generates the keys and sends them to all the others. + let key_dist_ctx = ctx.set_total_records(TotalRecords::specified(amount_of_keys).unwrap()); + if ctx.shard_id() == ShardIndex::FIRST { + // generate MAC keys + let keys = (0..amount_of_keys) + .map(|i| ctx.prss().generate(RecordId::from(i))) + .collect::>>(); + + for i in 1..u32::from(ctx.shard_count()) { + let shard = ShardIndex::from(i); + ctx.parallel_join(keys.iter().enumerate().map(|(i, key)| { + let key_dist_ctx = key_dist_ctx.clone(); + async move { + key_dist_ctx + .shard_send_channel::>(shard) + .send(RecordId::from(i), key) + .await + } + })) + .await?; + } + + Ok(keys) + } else { + key_dist_ctx + .shard_recv_channel(ShardIndex::FIRST) + .take(amount_of_keys) + .try_collect() + .await + } +} + +/// Entry point to execute malicious-secure sharded shuffle. +/// ## Errors +/// Failure to communicate over the network, either to other MPC helpers, and/or to other shards +/// will generate a shuffle error, as will detection of data inconsistencies that could indicate +/// a malicious helper. +#[allow(dead_code)] +pub async fn malicious_sharded_shuffle( + ctx: C, + shares: I, +) -> Result>, crate::error::Error> +where + I: IntoIterator>, + I::IntoIter: Send + ExactSizeIterator, + C: ShardedContext, + S: BooleanArray, + B: BooleanArray, + AdditiveShare: crate::protocol::ipa_prf::shuffle::sharded::Shuffleable, +{ + // assert lengths + assert_eq!(S::BITS + 32, B::BITS); + + // prepare keys + let amount_of_keys: usize = (usize::try_from(S::BITS).unwrap() + 31) / 32; + let keys = setup_keys(ctx.narrow(&OPRFShuffleStep::SetupKeys), amount_of_keys).await?; + + // compute and append tags to rows + let shares_and_tags: Vec> = + compute_and_add_tags(ctx.narrow(&OPRFShuffleStep::GenerateTags), &keys, shares).await?; + + let (shuffled_shares, messages) = match ctx.role() { + Role::H1 => h1_shuffle_for_shard(ctx.clone(), shares_and_tags).await, + Role::H2 => h2_shuffle_for_shard(ctx.clone(), shares_and_tags).await, + Role::H3 => h3_shuffle_for_shard(ctx.clone(), shares_and_tags).await, + }?; + + // verify the shuffle + verify_shuffle::<_, S, B>( + ctx.narrow(&OPRFShuffleStep::VerifyShuffle), + &keys, + &shuffled_shares, + messages, + ) + .await?; + + // truncate tags from output_shares + // verify_shuffle ensures that truncate_tags yields the correct rows + Ok(truncate_tags::(&shuffled_shares)) +} + /// This function truncates the tags from the output shares of the shuffle protocol /// /// ## Panics @@ -143,11 +231,7 @@ async fn verify_shuffle( let k_ctx = ctx .narrow(&VerifyShuffleStep::RevealMACKey) .set_total_records(TotalRecords::specified(key_shares.len())?); - let keys = reveal_keys(&k_ctx, key_shares) - .await? - .iter() - .map(Gf32Bit::from_array) - .collect::>(); + let keys = reveal_keys(&k_ctx, key_shares).await?; assert_eq!(messages.role(), ctx.role()); @@ -368,18 +452,19 @@ where async fn reveal_keys( ctx: &C, key_shares: &[AdditiveShare], -) -> Result>, Error> { +) -> Result, Error> { // reveal MAC keys let keys = ctx .parallel_join(key_shares.iter().enumerate().map(|(i, key)| async move { // uses malicious_reveal directly since we malicious_shuffle always needs the malicious_revel - malicious_reveal(ctx.clone(), RecordId::from(i), None, key).await + malicious_reveal(ctx.clone(), RecordId::from(i), None, key) + .await + .map(|v| Gf32Bit::from_array(&v.unwrap())) })) .await? .into_iter() - .flatten() // add a one, since last row element is tag which is not multiplied with a key - .chain(iter::once(StdArray::from_fn(|_| Gf32Bit::ONE))) + .chain(iter::once(Gf32Bit::ONE)) .collect::>(); Ok(keys) @@ -475,7 +560,7 @@ fn concatenate_row_and_tag( #[cfg(all(test, unit_test))] mod tests { - use rand::{distributions::Standard, prelude::Distribution, thread_rng, Rng}; + use rand::{distributions::Standard, prelude::Distribution, Rng}; use super::*; use crate::{ @@ -489,8 +574,11 @@ mod tests { }, protocol::ipa_prf::shuffle::base::shuffle_protocol, secret_sharing::SharedValue, - test_executor::run, - test_fixture::{Reconstruct, Runner, TestWorld, TestWorldConfig}, + sharding::ShardContext, + test_executor::{run, run_random}, + test_fixture::{ + RandomInputDistribution, Reconstruct, Runner, TestWorld, TestWorldConfig, WithShards, + }, }; /// Test the hashing of `BA112` and tag equality. @@ -499,7 +587,7 @@ mod tests { run(|| async { let world = TestWorld::default(); - let mut rng = thread_rng(); + let mut rng = world.rng(); let record = rng.gen::(); let (keys, result) = world @@ -508,8 +596,7 @@ mod tests { let amount_of_keys: usize = (usize::try_from(BA112::BITS).unwrap() + 31) / 32; // // generate MAC keys let keys = (0..amount_of_keys) - .map(|i| ctx.prss().generate_fields(RecordId::from(i))) - .map(|(left, right)| AdditiveShare::new(left, right)) + .map(|i| ctx.prss().generate(RecordId::from(i))) .collect::>>(); // compute and append tags to rows @@ -551,7 +638,7 @@ mod tests { const RECORD_AMOUNT: usize = 10; run(|| async { let world = TestWorld::default(); - let mut rng = thread_rng(); + let mut rng = world.rng(); let mut records = (0..RECORD_AMOUNT) .map(|_| rng.gen()) .collect::>(); @@ -598,7 +685,7 @@ mod tests { const RECORD_AMOUNT: usize = 10; run(|| async { let world = TestWorld::default(); - let mut rng = thread_rng(); + let mut rng = world.rng(); let records = (0..RECORD_AMOUNT) .map(|_| { let entry = rng.gen::<[u8; 4]>(); @@ -612,7 +699,7 @@ mod tests { let _ = world .semi_honest(records.into_iter(), |ctx, rows| async move { // trivial shares of Gf32Bit::ONE - let key_shares = vec![AdditiveShare::new(Gf32Bit::ONE, Gf32Bit::ONE); 1]; + let key_shares = vec![AdditiveShare::new(Gf32Bit::ONE, Gf32Bit::ONE)]; // run shuffle let (shares, messages) = shuffle_protocol(ctx.narrow("shuffle"), rows).await.unwrap(); @@ -636,13 +723,12 @@ mod tests { /// The function concatenates random rows and tags /// and checks whether the concatenation /// is still consistent with the original rows and tags - fn check_concatenate() + fn check_concatenate(rng: &mut impl Rng) where S: BooleanArray, B: BooleanArray, Standard: Distribution, { - let mut rng = thread_rng(); let row = AdditiveShare::::new(rng.gen(), rng.gen()); let tag = AdditiveShare::::new(rng.gen::(), rng.gen::()); let row_and_tag: AdditiveShare = concatenate_row_and_tag(&row, &tag); @@ -670,8 +756,10 @@ mod tests { #[test] fn check_concatenate_for_boolean_arrays() { - check_concatenate::(); - check_concatenate::(); + run_random(|mut rng| async move { + check_concatenate::(&mut rng); + check_concatenate::(&mut rng); + }); } /// Helper function for checking the tags @@ -689,7 +777,7 @@ mod tests { const RECORD_AMOUNT: usize = 10; run(|| async { let world = TestWorld::default(); - let mut rng = thread_rng(); + let mut rng = world.rng(); let records = (0..RECORD_AMOUNT) .map(|_| rng.gen::()) .collect::>(); @@ -772,33 +860,54 @@ mod tests { } #[allow(clippy::ptr_arg)] // to match StreamInterceptor trait - fn interceptor_h1_to_h2(ctx: &MaliciousHelperContext, data: &mut Vec) { + fn interceptor_h1_to_h2( + ctx: &MaliciousHelperContext, + target_shard: ShardContext, + data: &mut Vec, + ) { // H1 runs an additive attack against H2 by // changing x2 - if ctx.gate.as_ref().contains("transfer_x2") && ctx.dest == Role::H2 { + if ctx.gate.as_ref().contains("transfer_x_y") + && ctx.dest == Role::H2 + && ctx.shard == target_shard + { data[0] ^= 1u8; } } #[allow(clippy::ptr_arg)] // to match StreamInterceptor trait - fn interceptor_h2_to_h3(ctx: &MaliciousHelperContext, data: &mut Vec) { + fn interceptor_h2_to_h3( + ctx: &MaliciousHelperContext, + target_shard: ShardContext, + data: &mut Vec, + ) { // H2 runs an additive attack against H3 by // changing y1 - if ctx.gate.as_ref().contains("transfer_y1") && ctx.dest == Role::H3 { + if ctx.gate.as_ref().contains("transfer_x_y") + && ctx.dest == Role::H3 + && ctx.shard == target_shard + { data[0] ^= 1u8; } } #[allow(clippy::ptr_arg)] // to match StreamInterceptor trait - fn interceptor_h3_to_h2(ctx: &MaliciousHelperContext, data: &mut Vec) { + fn interceptor_h3_to_h2( + ctx: &MaliciousHelperContext, + target_shard: ShardContext, + data: &mut Vec, + ) { // H3 runs an additive attack against H2 by // changing c_hat_2 - if ctx.gate.as_ref().contains("transfer_c_hat") && ctx.dest == Role::H2 { + if ctx.gate.as_ref().contains("transfer_c") + && ctx.dest == Role::H2 + && ctx.shard == target_shard + { data[0] ^= 1u8; } } - /// This test checks that the malicious sort fails + /// This test checks that the malicious shuffle fails /// under a simple bit flip attack by H1. /// /// `x2` will be inconsistent which is checked by `H2`. @@ -808,12 +917,14 @@ mod tests { const RECORD_AMOUNT: usize = 10; run(move || async move { - let mut rng = thread_rng(); let mut config = TestWorldConfig::default(); config.stream_interceptor = - MaliciousHelper::new(Role::H1, config.role_assignment(), interceptor_h1_to_h2); + MaliciousHelper::new(Role::H1, config.role_assignment(), move |ctx, data| { + interceptor_h1_to_h2(ctx, None, data); + }); let world = TestWorld::new_with(config); + let mut rng = world.rng(); let records = (0..RECORD_AMOUNT).map(|_| rng.gen()).collect::>(); let [_, h2, _] = world .semi_honest(records.into_iter(), |ctx, shares| async move { @@ -825,7 +936,7 @@ mod tests { }); } - /// This test checks that the malicious sort fails + /// This test checks that the malicious shuffle fails /// under a simple bit flip attack by H2. /// /// `y1` will be inconsistent which is checked by `H1`. @@ -835,12 +946,14 @@ mod tests { const RECORD_AMOUNT: usize = 10; run(move || async move { - let mut rng = thread_rng(); let mut config = TestWorldConfig::default(); config.stream_interceptor = - MaliciousHelper::new(Role::H2, config.role_assignment(), interceptor_h2_to_h3); + MaliciousHelper::new(Role::H2, config.role_assignment(), move |ctx, data| { + interceptor_h2_to_h3(ctx, None, data); + }); let world = TestWorld::new_with(config); + let mut rng = world.rng(); let records = (0..RECORD_AMOUNT).map(|_| rng.gen()).collect::>(); let [h1, _, _] = world .malicious(records.into_iter(), |ctx, shares| async move { @@ -851,7 +964,7 @@ mod tests { }); } - /// This test checks that the malicious sort fails + /// This test checks that the malicious shuffle fails /// under a simple bit flip attack by H3. /// /// `c` from `H2` will be inconsistent @@ -862,12 +975,14 @@ mod tests { const RECORD_AMOUNT: usize = 10; run(move || async move { - let mut rng = thread_rng(); let mut config = TestWorldConfig::default(); config.stream_interceptor = - MaliciousHelper::new(Role::H3, config.role_assignment(), interceptor_h3_to_h2); + MaliciousHelper::new(Role::H3, config.role_assignment(), move |ctx, data| { + interceptor_h3_to_h2(ctx, None, data); + }); let world = TestWorld::new_with(config); + let mut rng = world.rng(); let records = (0..RECORD_AMOUNT).map(|_| rng.gen()).collect::>(); let [h1, h2, _] = world .semi_honest(records.into_iter(), |ctx, shares| async move { @@ -882,4 +997,170 @@ mod tests { let _ = h1.unwrap(); }); } + + #[test] + fn sharded_correctness_small() { + const SHARDS: usize = 3; + const RECORD_AMOUNT: usize = 2; // some shard will have no output + type Distribution = RandomInputDistribution; + run(|| async { + let world = TestWorld::>::with_shards( + TestWorldConfig::default(), + ); + let mut rng = world.rng(); + let mut records = (0..RECORD_AMOUNT).map(|_| rng.gen()).collect::>(); + let sharded_result = world + .semi_honest(records.clone().into_iter(), |ctx, input| async move { + malicious_sharded_shuffle::<_, BA32, BA64, _>(ctx, input) + .await + .unwrap() + }) + .await; + + assert_eq!(sharded_result.len(), SHARDS); + + let mut result = sharded_result + .into_iter() + .flat_map(|v| v.reconstruct()) + .collect::>(); + + // unshuffle by sorting + records.sort_by_key(U128Conversions::as_u128); + result.sort_by_key(U128Conversions::as_u128); + + assert_eq!(records, result); + }); + } + + #[test] + fn sharded_correctness_large() { + const SHARDS: usize = 3; + const RECORD_AMOUNT: usize = 100; // all shards will have output w.h.p. + type Distribution = RandomInputDistribution; + run(|| async { + let world = TestWorld::>::with_shards( + TestWorldConfig::default(), + ); + let mut rng = world.rng(); + let mut records = (0..RECORD_AMOUNT).map(|_| rng.gen()).collect::>(); + + let sharded_result = world + .semi_honest(records.clone().into_iter(), |ctx, input| async move { + malicious_sharded_shuffle::<_, BA32, BA64, _>(ctx, input) + .await + .unwrap() + }) + .await; + + assert_eq!(sharded_result.len(), SHARDS); + + let mut result = sharded_result + .into_iter() + .flat_map(|v| v.reconstruct()) + .collect::>(); + + // unshuffle by sorting + records.sort_by_key(U128Conversions::as_u128); + result.sort_by_key(U128Conversions::as_u128); + + assert_eq!(records, result); + }); + } + + /// This test checks that the sharded malicious shuffle fails + /// under a simple bit flip attack by H1. + /// + /// `x2` will be inconsistent which is checked by `H2`. + #[test] + #[should_panic(expected = "X2 is inconsistent")] + fn sharded_fail_under_bit_flip_attack_on_x2() { + const SHARDS: usize = 3; + const RECORD_AMOUNT: usize = 100; // all shards will have output w.h.p. + type Distribution = RandomInputDistribution; + + run_random(|mut rng| async move { + let target_shard = ShardIndex::from(rng.gen_range(0..u32::try_from(SHARDS).unwrap())); + let mut config = TestWorldConfig::default().with_seed(rng.gen()); + config.stream_interceptor = + MaliciousHelper::new(Role::H1, config.role_assignment(), move |ctx, data| { + interceptor_h1_to_h2(ctx, Some(target_shard), data); + }); + + let world = TestWorld::>::with_shards(config); + let records = (0..RECORD_AMOUNT).map(|_| rng.gen()).collect::>(); + let sharded_results = world + .semi_honest(records.into_iter(), |ctx, shares| async move { + malicious_sharded_shuffle::<_, BA32, BA64, _>(ctx, shares).await + }) + .await; + + assert_eq!(sharded_results.len(), SHARDS); + sharded_results[target_shard][Role::H2].as_ref().unwrap(); + }); + } + + /// This test checks that the sharded malicious shuffle fails + /// under a simple bit flip attack by H2. + /// + /// `y1` will be inconsistent which is checked by `H1`. + #[test] + #[should_panic(expected = "Y1 is inconsistent")] + fn sharded_fail_under_bit_flip_attack_on_y1() { + const SHARDS: usize = 3; + const RECORD_AMOUNT: usize = 100; // all shards will have output w.h.p. + type Distribution = RandomInputDistribution; + + run_random(|mut rng| async move { + let target_shard = ShardIndex::from(rng.gen_range(0..u32::try_from(SHARDS).unwrap())); + let mut config = TestWorldConfig::default().with_seed(rng.gen()); + config.stream_interceptor = + MaliciousHelper::new(Role::H2, config.role_assignment(), move |ctx, data| { + interceptor_h2_to_h3(ctx, Some(target_shard), data); + }); + + let world = TestWorld::>::with_shards(config); + let records = (0..RECORD_AMOUNT).map(|_| rng.gen()).collect::>(); + let sharded_results = world + .semi_honest(records.into_iter(), |ctx, shares| async move { + malicious_sharded_shuffle::<_, BA32, BA64, _>(ctx, shares).await + }) + .await; + + assert_eq!(sharded_results.len(), SHARDS); + sharded_results[target_shard][Role::H1].as_ref().unwrap(); + }); + } + + /// This test checks that the malicious sharded shuffle fails + /// under a simple bit flip attack by H3. + /// + /// `c` from `H2` will be inconsistent + /// which is checked by `H1`. + #[test] + #[should_panic(expected = "C from H2 is inconsistent")] + fn sharded_fail_under_bit_flip_attack_on_c() { + const SHARDS: usize = 3; + const RECORD_AMOUNT: usize = 100; // all shards will have output w.h.p. + type Distribution = RandomInputDistribution; + + run_random(|mut rng| async move { + let target_shard = ShardIndex::from(rng.gen_range(0..u32::try_from(SHARDS).unwrap())); + let mut config = TestWorldConfig::default().with_seed(rng.gen()); + config.stream_interceptor = + MaliciousHelper::new(Role::H3, config.role_assignment(), move |ctx, data| { + interceptor_h3_to_h2(ctx, Some(target_shard), data); + }); + + let world = TestWorld::>::with_shards(config); + let records = (0..RECORD_AMOUNT).map(|_| rng.gen()).collect::>(); + let sharded_results = world + .semi_honest(records.into_iter(), |ctx, shares| async move { + malicious_sharded_shuffle::<_, BA32, BA64, _>(ctx, shares).await + }) + .await; + + assert_eq!(sharded_results.len(), SHARDS); + sharded_results[target_shard][Role::H1].as_ref().unwrap(); + }); + } } diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs b/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs index 5d8517fd5..db84d6df4 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs @@ -18,7 +18,7 @@ use crate::{ protocol::{ context::{Context, MaliciousContext, SemiHonestContext}, ipa_prf::{ - shuffle::{base::shuffle_protocol, malicious::malicious_shuffle}, + shuffle::sharded::{MaliciousShuffleable, ShuffleContext}, OPRFIPAInputRow, }, }, @@ -26,14 +26,17 @@ use crate::{ replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing}, SharedValue, }, - sharding::ShardBinding, + sharding::{ShardBinding, Sharded}, }; -pub mod base; -pub mod malicious; -#[cfg(descriptive_gate)] +mod base; +mod malicious; mod sharded; -pub(crate) mod step; +pub(crate) mod step; // must be pub(crate) for compact gate gen + +use base::shuffle_protocol as base_shuffle; +use malicious::{malicious_sharded_shuffle, malicious_shuffle}; +use sharded::shuffle as sharded_shuffle; /// This struct stores some intermediate messages during the shuffle. /// In a maliciously secure shuffle, @@ -68,8 +71,8 @@ impl IntermediateShuffleMessages { } } -/// Trait used by protocols to invoke either semi-honest or malicious shuffle, depending -/// on the type of context being used. +/// Trait used by protocols to invoke either semi-honest or malicious non-sharded +/// shuffle, depending on the type of context being used. pub trait Shuffle: Context { fn shuffle( self, @@ -99,12 +102,12 @@ impl<'b, T: ShardBinding> Shuffle for SemiHonestContext<'b, T> { for<'a> &'a B: Add + Add<&'a B, Output = B>, Standard: Distribution + Distribution, { - let fut = shuffle_protocol::<_, I, S>(self, shares); + let fut = base_shuffle::<_, I, S>(self, shares); fut.map(|res| res.map(|(output, _intermediates)| output)) } } -impl<'b> Shuffle for MaliciousContext<'b> { +impl<'b, T: ShardBinding> Shuffle for MaliciousContext<'b, T> { fn shuffle( self, shares: I, @@ -122,6 +125,47 @@ impl<'b> Shuffle for MaliciousContext<'b> { } } +/// Trait used by protocols to invoke either semi-honest or malicious sharded shuffle, +/// depending on the type of context being used. +#[allow(dead_code)] +pub trait ShardedShuffle: ShuffleContext { + fn sharded_shuffle(self, shares: I) -> impl Future, Error>> + Send + where + S: MaliciousShuffleable, + I: IntoIterator> + Send, + I::IntoIter: ExactSizeIterator + Send, + for<'a> &'a S: Add + Add<&'a S, Output = S>, + Standard: Distribution; +} + +impl<'b> ShardedShuffle for SemiHonestContext<'b, Sharded> { + fn sharded_shuffle(self, shares: I) -> impl Future, Error>> + Send + where + S: MaliciousShuffleable, + I: IntoIterator> + Send, + I::IntoIter: ExactSizeIterator + Send, + for<'a> &'a S: Add + Add<&'a S, Output = S>, + Standard: Distribution, + { + let fut = sharded_shuffle::<_, S, _>(self, shares.into_iter().map(S::from)); + fut.map(|res| res.map(|(output, _intermediates)| output)) + } +} + +impl<'b> ShardedShuffle for MaliciousContext<'b, Sharded> { + fn sharded_shuffle(self, shares: I) -> impl Future, Error>> + Send + where + S: MaliciousShuffleable, + I: IntoIterator> + Send, + I::IntoIter: ExactSizeIterator + Send, + for<'a> &'a S: Add + Add<&'a S, Output = S>, + Standard: Distribution, + { + let fut = malicious_sharded_shuffle::<_, S::Share, S::ShareAndTag, _>(self, shares); + fut.map(|res| res.map(|vec| vec.into_iter().map(S::from).collect())) + } +} + #[tracing::instrument(name = "shuffle_inputs", skip_all)] pub async fn shuffle_inputs( ctx: C, @@ -183,7 +227,7 @@ where TV: BooleanArray, TS: BooleanArray, { - let mut y = AdditiveShare::new(YS::ZERO, YS::ZERO); + let mut y = ReplicatedSecretSharing::new(YS::ZERO, YS::ZERO); expand_shared_array_in_place(&mut y, &input.match_key, 0); let mut offset = BA64::BITS as usize; @@ -217,7 +261,7 @@ where let mut offset = BA64::BITS as usize; - let is_trigger = AdditiveShare::::new( + let is_trigger = ReplicatedSecretSharing::new( input.left().get(offset).unwrap_or(Boolean::ZERO), input.right().get(offset).unwrap_or(Boolean::ZERO), ); @@ -250,7 +294,7 @@ where BK: BooleanArray, TV: BooleanArray, { - let mut y = AdditiveShare::new(YS::ZERO, YS::ZERO); + let mut y = ReplicatedSecretSharing::new(YS::ZERO, YS::ZERO); expand_shared_array_in_place(&mut y, &input.attributed_breakdown_key_bits, 0); expand_shared_array_in_place( &mut y, diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/sharded.rs b/ipa-core/src/protocol/ipa_prf/shuffle/sharded.rs index 2d1d78695..6a37fd6f3 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/sharded.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/sharded.rs @@ -8,31 +8,36 @@ //! MPC communication, it uses 6 rounds of intra-helper communications to send data between shards. //! In this implementation, this operation is called "resharding". -use std::{borrow::Borrow, future::Future, num::NonZeroUsize, ops::Add}; +use std::{future::Future, num::NonZeroUsize}; use futures::{future::try_join, stream, StreamExt, TryFutureExt}; -use ipa_step::Step; use rand::seq::SliceRandom; use crate::{ - ff::{boolean_array::BA64, U128Conversions}, + ff::{ + boolean_array::{BooleanArray, BA32, BA64}, + Serializable, U128Conversions, + }, helpers::{Direction, Error, Role, TotalRecords}, protocol::{ context::{reshard_iter, ShardedContext}, - ipa_prf::shuffle::IntermediateShuffleMessages, + ipa_prf::shuffle::{step::ShardedShuffleStep as ShuffleStep, IntermediateShuffleMessages}, prss::{FromRandom, SharedRandomness}, RecordId, }, - secret_sharing::{ - replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing}, - Sendable, SharedValue, - }, + secret_sharing::replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing}, seq_join::{assert_send, seq_join}, }; /// This context is only useful for sharded shuffle modules because it implements common operations /// that all shards on all helpers perform to achieve the perfect shuffle. -trait ShuffleContext: ShardedContext { +/// +/// This trait is `pub`, which is required, because it is a supertrait of `pub trait ShardedShuffle`. +/// `mod sharded` is not `pub`, which makes the situation a variant of the sealed trait pattern. +/// Specifically, it prevents types outside `mod shuffle` from implementing `trait ShuffleContext`. +/// Note that this structure does NOT prevent calling these methods outside of `mod shuffle`, but +/// please don't do that. +pub trait ShuffleContext: ShardedContext { /// This sends a single machine word (8 byte value) to one of the helpers specified in /// `direction` parameter. fn send_word( @@ -82,8 +87,7 @@ trait ShuffleContext: ShardedContext { data: I, ) -> impl Future, crate::error::Error>> + Send where - I: IntoIterator, - I::Item: Borrow, + I: IntoIterator, I::IntoIter: ExactSizeIterator + Send, S: ShuffleShare, { @@ -100,7 +104,7 @@ trait ShuffleContext: ShardedContext { Direction::Right => r, }; - item.borrow().clone() + mask + item + mask }), |ctx, record_id, _| ctx.pick_shard(record_id, direction), )) @@ -182,46 +186,6 @@ trait ShuffleContext: ShardedContext { } } -enum ShuffleStep { - /// Depending on the helper position inside the MPC ring, generate Ã, B̃ or both. - PseudoRandomTable, - /// Permute the input according to the PRSS shared between H1 and H2. - Permute12, - /// Permute the input according to the PRSS shared between H2 and H3. - Permute23, - /// Permute the input according to the PRSS shared between H3 and H1. - Permute31, - /// Specific to H1 and H2 interaction - H2 informs H1 about |C|. - Cardinality, - /// Send all the shares from helper on the left to the helper on the right. - LeftToRight, - /// H2 and H3 interaction - Exchange `C_1` and `C_2`. - C, - /// Apply a mask to the given set of shares. Masking values come from PRSS. - Mask, - /// Local per-shard shuffle, where each shard redistributes shares locally according to samples - /// obtained from PRSS. Does not require Shard or MPC communication. - LocalShuffle, -} - -impl Step for ShuffleStep {} - -impl AsRef for ShuffleStep { - fn as_ref(&self) -> &str { - match self { - ShuffleStep::PseudoRandomTable => "PseudoRandomTable", - ShuffleStep::Permute12 => "Permute12", - ShuffleStep::Permute23 => "Permute23", - ShuffleStep::Permute31 => "Permute31", - ShuffleStep::Cardinality => "Cardinality", - ShuffleStep::LeftToRight => "LeftToRight", - ShuffleStep::C => "C", - ShuffleStep::Mask => "Mask", - ShuffleStep::LocalShuffle => "LocalShuffle", - } - } -} - impl ShuffleContext for C {} /// Marker trait for share values that can be shuffled. In simple cases where we shuffle events @@ -234,12 +198,16 @@ impl ShuffleContext for C {} /// /// [`ShuffleShare`] and [`Shuffleable`] are added to bridge the gap. They can be implemented for /// arbitrary structs as long as `Add` operation can be defined on them. -pub trait ShuffleShare: Sendable + Clone + FromRandom + Add {} +pub trait ShuffleShare: BooleanArray + Serializable + FromRandom {} -impl> ShuffleShare for V {} +impl ShuffleShare for V {} /// Trait for shuffle inputs that consists of two values (left and right). -pub trait Shuffleable: Send + 'static { +// The `From` and `Into` bounds are necessary to work with routines that have not been +// updated to use the `Shuffleable` trait. +pub trait Shuffleable: + From> + Into> + Send + 'static +{ type Share: ShuffleShare; fn left(&self) -> Self::Share; @@ -248,7 +216,7 @@ pub trait Shuffleable: Send + 'static { fn new(l: Self::Share, r: Self::Share) -> Self; } -impl Shuffleable for AdditiveShare { +impl Shuffleable for AdditiveShare { type Share = V; fn left(&self) -> Self::Share { @@ -264,8 +232,18 @@ impl Shuffleable for AdditiveShare { } } +/// Trait for inputs to malicious shuffle. +pub trait MaliciousShuffleable: Shuffleable { + /// A type that can hold `::Share` along with a 32-bit MAC. + type ShareAndTag: BooleanArray + FromRandom; +} + +impl MaliciousShuffleable for AdditiveShare { + type ShareAndTag = BA64; +} + /// Sharded shuffle as performed by shards on H1. -async fn h1_shuffle_for_shard( +pub(super) async fn h1_shuffle_for_shard( ctx: C, shares: I, ) -> Result<(Vec, IntermediateShuffleMessages), crate::error::Error> @@ -288,11 +266,11 @@ where // shared with the left helper. let x2: Vec = ctx .narrow(&ShuffleStep::Permute31) - .mask_and_shuffle(Direction::Left, &x1) + .mask_and_shuffle(Direction::Left, x1.iter().copied()) .await?; // X_2 is masked now and cannot reveal anything to the helper on the right. - ctx.narrow(&ShuffleStep::LeftToRight) + ctx.narrow(&ShuffleStep::TransferXY) .send_all(x2, Direction::Right) .await?; @@ -321,7 +299,7 @@ where } /// Sharded shuffle as performed by shards on H2. -async fn h2_shuffle_for_shard( +pub(super) async fn h2_shuffle_for_shard( ctx: C, shares: I, ) -> Result<(Vec, IntermediateShuffleMessages), crate::error::Error> @@ -342,19 +320,19 @@ where // Share y1 to the right. Safe to do because input has been masked with randomness // known to H1 and H2 only. - ctx.narrow(&ShuffleStep::LeftToRight) + ctx.narrow(&ShuffleStep::TransferXY) .send_all(y1, Direction::Right) .await?; let x2 = ctx - .narrow(&ShuffleStep::LeftToRight) + .narrow(&ShuffleStep::TransferXY) .recv_all::(Direction::Left) .await?; // generate X_3 = perm_23(X_2 ⊕ z_23) let x3: Vec = ctx .narrow(&ShuffleStep::Permute23) - .mask_and_shuffle(Direction::Right, &x2) + .mask_and_shuffle(Direction::Right, x2.iter().copied()) .await?; // at this moment we know the cardinality of C, and we let H1 know it, so it can start @@ -371,11 +349,11 @@ where // Knowing b, c_1 and c_2 lets us set our resulting share, according to the paper it is // (b, c_1 + c_2) let send_channel = ctx - .narrow(&ShuffleStep::C) + .narrow(&ShuffleStep::TransferC) .set_total_records(x3_len) .send_channel(ctx.role().peer(Direction::Right)); let recv_channel = ctx - .narrow(&ShuffleStep::C) + .narrow(&ShuffleStep::TransferC) .recv_channel(ctx.role().peer(Direction::Right)); let res = ctx @@ -386,12 +364,12 @@ where .narrow(&ShuffleStep::PseudoRandomTable) .prss() .generate(RecordId::from(i)); - let c1: S::Share = x3 + b.clone(); + let c1: S::Share = x3 + b; try_join( - send_channel.send(record_id, c1.clone()), + send_channel.send(record_id, c1), recv_channel.receive(record_id), ) - .map_ok(|((), c2)| S::new(b, c1 + c2)) + .map_ok(move |((), c2)| S::new(b, c1 + c2)) })) .await?; @@ -400,7 +378,7 @@ where /// Sharded shuffle as performed by shards on H3. Note that in semi-honest setting, H3 does not /// use its input. Adding support for active security will change that. -async fn h3_shuffle_for_shard( +pub(super) async fn h3_shuffle_for_shard( ctx: C, _: I, ) -> Result<(Vec, IntermediateShuffleMessages), crate::error::Error> @@ -412,20 +390,20 @@ where { // Receive y1 from the left let y1 = ctx - .narrow(&ShuffleStep::LeftToRight) + .narrow(&ShuffleStep::TransferXY) .recv_all::(Direction::Left) .await?; // Generate y2 = perm_31(y_1 ⊕ z_31) let y2: Vec = ctx .narrow(&ShuffleStep::Permute31) - .mask_and_shuffle(Direction::Right, &y1) + .mask_and_shuffle(Direction::Right, y1.iter().copied()) .await?; // Generate y3 = perm_23(y_2 ⊕ z_23) let y3: Vec = ctx .narrow(&ShuffleStep::Permute23) - .mask_and_shuffle(Direction::Left, &y2) + .mask_and_shuffle(Direction::Left, y2.iter().copied()) .await?; let Some(y3_len) = NonZeroUsize::new(y3.len()) else { @@ -435,11 +413,11 @@ where // Generate c_2 = y_3 ⊕ a, stream it to H2 and receive c_1 from it at the same time. // Set our share to be (c_1 + c_2, a) let send_channel = ctx - .narrow(&ShuffleStep::C) + .narrow(&ShuffleStep::TransferC) .set_total_records(y3_len) .send_channel(ctx.role().peer(Direction::Left)); let recv_channel = ctx - .narrow(&ShuffleStep::C) + .narrow(&ShuffleStep::TransferC) .recv_channel::(ctx.role().peer(Direction::Left)); let res = ctx .try_join(y3.into_iter().enumerate().map(|(i, y3)| { @@ -449,12 +427,12 @@ where .narrow(&ShuffleStep::PseudoRandomTable) .prss() .generate(RecordId::from(i)); - let c2 = y3 + a.clone(); + let c2 = y3 + a; try_join( - send_channel.send(record_id, c2.clone()), + send_channel.send(record_id, c2), recv_channel.receive(record_id), ) - .map_ok(|((), c1)| S::new(c1 + c2, a)) + .map_ok(move |((), c1)| S::new(c1 + c2, a)) })) .await?; @@ -489,7 +467,10 @@ mod tests { use rand::{thread_rng, Rng}; use crate::{ - ff::{boolean_array::BA8, Gf40Bit, U128Conversions}, + ff::{ + boolean_array::{BA64, BA8}, + U128Conversions, + }, protocol::ipa_prf::shuffle::{ base::test_helpers::{extract_shuffle_results, ExtractedShuffleResults}, sharded::shuffle, @@ -504,11 +485,15 @@ mod tests { async fn sharded_shuffle(input: Vec) -> Vec { let world: TestWorld> = TestWorld::with_shards(TestWorldConfig::default()); - world + let sharded_result = world .semi_honest(input.into_iter(), |ctx, input| async move { shuffle(ctx, input).await.unwrap().0 }) - .await + .await; + + assert_eq!(sharded_result.len(), SHARDS); + + sharded_result .into_iter() .flat_map(|v| v.reconstruct()) .collect::>() @@ -574,33 +559,34 @@ mod tests { type Distribution = RandomInputDistribution; run(|| async { let mut rng = thread_rng(); - // using Gf40Bit here since it implements cmp such that vec can later be sorted - let mut records = (0..RECORD_AMOUNT) - .map(|_| rng.gen()) - .collect::>(); + let mut records = (0..RECORD_AMOUNT).map(|_| rng.gen()).collect::>(); - let results = TestWorld::>::with_shards( + let sharded_results = TestWorld::>::with_shards( TestWorldConfig::default(), ) .semi_honest(records.clone().into_iter(), |ctx, input| async move { shuffle(ctx, input).await.unwrap() }) - .await - .into_iter() - .map(extract_shuffle_results) - .fold(ExtractedShuffleResults::empty(), |mut acc, results| { - let ExtractedShuffleResults { - x1_xor_y1, - x2_xor_y2, - a_xor_b_xor_c, - } = results; - - acc.x1_xor_y1.extend(x1_xor_y1); - acc.x2_xor_y2.extend(x2_xor_y2); - acc.a_xor_b_xor_c.extend(a_xor_b_xor_c); - - acc - }); + .await; + + assert_eq!(sharded_results.len(), SHARDS); + + let results = sharded_results + .into_iter() + .map(extract_shuffle_results) + .fold(ExtractedShuffleResults::empty(), |mut acc, results| { + let ExtractedShuffleResults { + x1_xor_y1, + x2_xor_y2, + a_xor_b_xor_c, + } = results; + + acc.x1_xor_y1.extend(x1_xor_y1); + acc.x2_xor_y2.extend(x2_xor_y2); + acc.a_xor_b_xor_c.extend(a_xor_b_xor_c); + + acc + }); let ExtractedShuffleResults { mut x1_xor_y1, @@ -609,10 +595,10 @@ mod tests { } = results; // unshuffle by sorting - records.sort(); - x1_xor_y1.sort(); - x2_xor_y2.sort(); - a_xor_b_xor_c.sort(); + records.sort_by_key(U128Conversions::as_u128); + x1_xor_y1.sort_by_key(U128Conversions::as_u128); + x2_xor_y2.sort_by_key(U128Conversions::as_u128); + a_xor_b_xor_c.sort_by_key(U128Conversions::as_u128); assert_eq!(records, a_xor_b_xor_c); assert_eq!(records, x1_xor_y1); diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/step.rs b/ipa-core/src/protocol/ipa_prf/shuffle/step.rs index 126996574..1120d14fa 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/step.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/step.rs @@ -1,14 +1,18 @@ use ipa_step_derive::CompactStep; +// Note: the stream interception tests for malicious shuffles require that the +// `TransferXY` and `TransferC` steps have the same name in `OPRFShuffleStep` and +// `ShardedShuffleStep`. + #[derive(CompactStep)] pub(crate) enum OPRFShuffleStep { + SetupKeys, ApplyPermutations, GenerateAHat, GenerateBHat, GenerateZ, - TransferCHat, - TransferX2, - TransferY1, + TransferXY, // Transfer of X2 and Y1 + TransferC, // Exchange of `C_1` and `C_2` GenerateTags, #[step(child = crate::protocol::ipa_prf::shuffle::step::VerifyShuffleStep)] VerifyShuffle, @@ -21,3 +25,26 @@ pub(crate) enum VerifyShuffleStep { HashH2toH1, HashH3toH2, } + +#[derive(CompactStep)] +pub(crate) enum ShardedShuffleStep { + /// Depending on the helper position inside the MPC ring, generate Ã, B̃ or both. + PseudoRandomTable, + /// Permute the input according to the PRSS shared between H1 and H2. + Permute12, + /// Permute the input according to the PRSS shared between H2 and H3. + Permute23, + /// Permute the input according to the PRSS shared between H3 and H1. + Permute31, + /// Specific to H1 and H2 interaction - H2 informs H1 about |C|. + Cardinality, + /// H1 sends X2 to H2. H2 sends Y1 to H3. + TransferXY, + /// H2 and H3 interaction - Exchange `C_1` and `C_2`. + TransferC, + /// Apply a mask to the given set of shares. Masking values come from PRSS. + Mask, + /// Local per-shard shuffle, where each shard redistributes shares locally according to samples + /// obtained from PRSS. Does not require Shard or MPC communication. + LocalShuffle, +} diff --git a/ipa-core/src/protocol/step.rs b/ipa-core/src/protocol/step.rs index cf3658018..ee2ef726a 100644 --- a/ipa-core/src/protocol/step.rs +++ b/ipa-core/src/protocol/step.rs @@ -39,6 +39,8 @@ pub enum DeadCodeStep { FeatureLabelDotProduct, #[step(child = crate::protocol::ipa_prf::boolean_ops::step::MultiplicationStep)] Multiplication, + #[step(child = crate::protocol::ipa_prf::shuffle::step::ShardedShuffleStep)] + ShardedShuffle, } /// Provides a unique per-iteration context in tests. From 77314245f65dab1b06161e904f63e8243b9edeba Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Mon, 4 Nov 2024 13:50:07 -0800 Subject: [PATCH 2/2] Feedback --- ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs index 72493fc2d..258064e4a 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs @@ -100,8 +100,7 @@ where .map(|i| ctx.prss().generate(RecordId::from(i))) .collect::>>(); - for i in 1..u32::from(ctx.shard_count()) { - let shard = ShardIndex::from(i); + for shard in ctx.shard_count().iter().skip(1) { ctx.parallel_join(keys.iter().enumerate().map(|(i, key)| { let key_dist_ctx = key_dist_ctx.clone(); async move {