Skip to content

Commit

Permalink
Merge pull request #1249 from danielmasny/shuffle-function
Browse files Browse the repository at this point in the history
changes to shuffle function
  • Loading branch information
danielmasny authored Sep 6, 2024
2 parents 7c765c0 + 841d448 commit adcd8dd
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 15 deletions.
137 changes: 124 additions & 13 deletions ipa-core/src/protocol/ipa_prf/shuffle/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ use crate::{

/// # Errors
/// Will propagate errors from transport and a few typecasts
pub async fn shuffle<C, I, S>(ctx: C, shares: I) -> Result<Vec<AdditiveShare<S>>, Error>
pub async fn shuffle<C, I, S>(
ctx: C,
shares: I,
) -> Result<(Vec<AdditiveShare<S>>, IntermediateShuffleMessages<S>), Error>
where
C: Context,
I: IntoIterator<Item = AdditiveShare<S>>,
Expand All @@ -32,7 +35,13 @@ where
// This protocol can take a mutable iterator and replace items in the input.
let shares = shares.into_iter();
let Some(shares_len) = NonZeroUsize::new(shares.len()) else {
return Ok(vec![]);
return Ok((
vec![],
IntermediateShuffleMessages {
x1_or_y1: None,
x2_or_y2: None,
},
));
};
let ctx_z = ctx.narrow(&OPRFShuffleStep::GenerateZ);
let zs = generate_random_tables_with_peers(shares_len, &ctx_z);
Expand All @@ -44,12 +53,23 @@ where
}
}

#[allow(dead_code)]
/// This struct stores some intermediate messages during the shuffle.
/// In a maliciously secure shuffle,
/// these messages need to be checked for consistency across helpers.
/// `H1` stores `x1`, `H2` stores `x2` and `H3` stores `y1` and `y2`.
#[derive(Debug, Clone)]
pub struct IntermediateShuffleMessages<S: SharedValue> {
x1_or_y1: Option<Vec<S>>,
x2_or_y2: Option<Vec<S>>,
}

async fn run_h1<C, I, S, Zl, Zr>(
ctx: &C,
batch_size: NonZeroUsize,
shares: I,
(z_31, z_12): (Zl, Zr),
) -> Result<Vec<AdditiveShare<S>>, Error>
) -> Result<(Vec<AdditiveShare<S>>, IntermediateShuffleMessages<S>), Error>
where
C: Context,
I: IntoIterator<Item = AdditiveShare<S>>,
Expand All @@ -76,21 +96,31 @@ where
let (mut rng_perm_l, mut rng_perm_r) = ctx_perm.prss_rng();
x_1.shuffle(&mut rng_perm_r);

let mut x_2 = x_1;
// need to output x_1
// call to clone causes allocation
// ideally in the semi honest setting, we would not clone
let mut x_2 = x_1.clone();
add_single_shares_in_place(&mut x_2, z_31);
x_2.shuffle(&mut rng_perm_l);
send_to_peer(&x_2, ctx, &OPRFShuffleStep::TransferX2, Direction::Right).await?;

let res = combine_single_shares(a_hat, b_hat).collect::<Vec<_>>();
Ok(res)
// we only need to store x_1 in IntermediateShuffleMessage
Ok((
res,
IntermediateShuffleMessages {
x1_or_y1: Some(x_1),
x2_or_y2: None,
},
))
}

async fn run_h2<C, I, S, Zl, Zr>(
ctx: &C,
batch_size: NonZeroUsize,
shares: I,
(z_12, z_23): (Zl, Zr),
) -> Result<Vec<AdditiveShare<S>>, Error>
) -> Result<(Vec<AdditiveShare<S>>, IntermediateShuffleMessages<S>), Error>
where
C: Context,
I: IntoIterator<Item = AdditiveShare<S>>,
Expand Down Expand Up @@ -127,7 +157,10 @@ where
)
.await?;

let mut x_3 = x_2;
// we need to output x_2
// call to clone causes allocation
// ideally in the semi honest setting, we would not clone
let mut x_3 = x_2.clone();
add_single_shares_in_place(&mut x_3, z_23);
x_3.shuffle(&mut rng_perm_r);

Expand All @@ -154,14 +187,21 @@ where

let c_hat = add_single_shares(c_hat_1.iter(), c_hat_2.iter());
let res = combine_single_shares(b_hat, c_hat).collect::<Vec<_>>();
Ok(res)
// we only need to store x_2 in IntermediateShuffleMessage
Ok((
res,
IntermediateShuffleMessages {
x1_or_y1: None,
x2_or_y2: Some(x_2),
},
))
}

async fn run_h3<C, S, Zl, Zr>(
ctx: &C,
batch_size: NonZeroUsize,
(z_23, z_31): (Zl, Zr),
) -> Result<Vec<AdditiveShare<S>>, Error>
) -> Result<(Vec<AdditiveShare<S>>, IntermediateShuffleMessages<S>), Error>
where
C: Context,
S: SharedValue + Add<Output = S>,
Expand All @@ -186,14 +226,20 @@ where
)
.await?;

let mut y_2 = y_1;
// need to output y_1
// call to clone causes allocation
// ideally in the semi honest setting, we would not clone
let mut y_2 = y_1.clone();
add_single_shares_in_place(&mut y_2, z_31);

let ctx_perm = ctx.narrow(&OPRFShuffleStep::ApplyPermutations);
let (mut rng_perm_l, mut rng_perm_r) = ctx_perm.prss_rng();
y_2.shuffle(&mut rng_perm_r);

let mut y_3 = y_2;
// need to output y_2
// call to clone causes allocation
// ideally in the semi honest setting, we would not clone
let mut y_3 = y_2.clone();
add_single_shares_in_place(&mut y_3, z_23);
y_3.shuffle(&mut rng_perm_l);

Expand All @@ -218,7 +264,13 @@ where

let c_hat = add_single_shares(c_hat_1, c_hat_2);
let res = combine_single_shares(c_hat, a_hat).collect::<Vec<_>>();
Ok(res)
Ok((
res,
IntermediateShuffleMessages {
x1_or_y1: Some(y_1),
x2_or_y2: Some(y_2),
},
))
}

fn add_single_shares<A, B, S, L, R>(l: L, r: R) -> impl Iterator<Item = S>
Expand Down Expand Up @@ -343,9 +395,13 @@ where

#[cfg(all(test, unit_test))]
pub mod tests {
use rand::{thread_rng, Rng};

use super::shuffle;
use crate::{
ff::{Gf40Bit, U128Conversions},
secret_sharing::replicated::ReplicatedSecretSharing,
test_executor::run,
test_fixture::{Reconstruct, Runner, TestWorld, TestWorldConfig},
};

Expand All @@ -364,7 +420,7 @@ pub mod tests {
// Stable seed is used to get predictable shuffle results.
let mut actual = TestWorld::new_with(TestWorldConfig::default().with_seed(123))
.semi_honest(records.clone().into_iter(), |ctx, shares| async move {
shuffle(ctx, shares).await.unwrap()
shuffle(ctx, shares).await.unwrap().0
})
.await
.reconstruct();
Expand All @@ -381,4 +437,59 @@ pub mod tests {
"Shuffle should not change the items in the set"
);
}

#[test]
fn check_intermediate_messages() {
const RECORD_AMOUNT: usize = 100;
run(|| async {
let world = TestWorld::default();
let mut rng = thread_rng();
// using Gf40Bit here since it implements cmp such that vec can later be sorted
let mut records = (0..RECORD_AMOUNT)
.map(|_| rng.gen())
.collect::<Vec<Gf40Bit>>();

let [h1, h2, h3] = world
.semi_honest(records.clone().into_iter(), |ctx, records| async move {
shuffle(ctx, records).await
})
.await;

// check consistency
// i.e. x_1 xor y_1 = x_2 xor y_2 = C xor A xor B
let (h1_shares, h1_messages) = h1.unwrap();
let (_, h2_messages) = h2.unwrap();
let (h3_shares, h3_messages) = h3.unwrap();

let mut x1_xor_y1 = h1_messages
.x1_or_y1
.unwrap()
.iter()
.zip(h3_messages.x1_or_y1.unwrap())
.map(|(x1, y1)| x1 + y1)
.collect::<Vec<_>>();
let mut x2_xor_y2 = h2_messages
.x2_or_y2
.unwrap()
.iter()
.zip(h3_messages.x2_or_y2.unwrap())
.map(|(x2, y2)| x2 + y2)
.collect::<Vec<_>>();
let mut a_xor_b_xor_c = h1_shares
.iter()
.zip(h3_shares)
.map(|(h1_share, h3_share)| h1_share.left() + h1_share.right() + h3_share.left())
.collect::<Vec<_>>();

// unshuffle by sorting
records.sort();
x1_xor_y1.sort();
x2_xor_y2.sort();
a_xor_b_xor_c.sort();

assert_eq!(records, a_xor_b_xor_c);
assert_eq!(records, x1_xor_y1);
assert_eq!(records, x2_xor_y2);
});
}
}
4 changes: 2 additions & 2 deletions ipa-core/src/protocol/ipa_prf/shuffle/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ where
.map(|item| oprfreport_to_shuffle_input::<BA112, BK, TV, TS>(&item))
.collect::<Vec<_>>();

let shuffled = shuffle(ctx, shuffle_input).await?;
let (shuffled, _) = shuffle(ctx, shuffle_input).await?;

Ok(shuffled
.into_iter()
Expand All @@ -69,7 +69,7 @@ where
.map(|item| attribution_outputs_to_shuffle_input::<BK, TV, R>(&item))
.collect::<Vec<_>>();

let shuffled = shuffle(ctx, shuffle_input).await?;
let (shuffled, _) = shuffle(ctx, shuffle_input).await?;

Ok(shuffled
.into_iter()
Expand Down

0 comments on commit adcd8dd

Please sign in to comment.