Skip to content

Commit

Permalink
Merge pull request #1432 from akoshelev/mal-shuffle-bound-fix
Browse files Browse the repository at this point in the history
Remove SharedValue trait bound from MaliciousShuffleable
  • Loading branch information
akoshelev authored Nov 18, 2024
2 parents 4e0608e + dba4f6a commit b6b9223
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 28 deletions.
95 changes: 73 additions & 22 deletions ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,44 @@ use crate::{
sharding::ShardIndex,
};

/// Container for left and right shares with tags attached to them.
/// Looks like an additive share, but it is not because it does not need
/// many traits that additive shares require to implement
#[derive(Clone, Debug, Default)]
struct Pair<S: ShuffleShare> {
left: S,
right: S,
}

impl<S: ShuffleShare> Shuffleable for Pair<S> {
type Share = S;

fn left(&self) -> Self::Share {
self.left.clone()
}

fn right(&self) -> Self::Share {
self.right.clone()
}

fn new(l: Self::Share, r: Self::Share) -> Self {
Self { left: l, right: r }
}
}

impl<S: ShuffleShare + SharedValue> From<AdditiveShare<S>> for Pair<S> {
fn from(value: AdditiveShare<S>) -> Self {
let (l, r) = value.as_tuple();
Shuffleable::new(l, r)
}
}

impl<S: ShuffleShare + SharedValue> From<Pair<S>> for AdditiveShare<S> {
fn from(value: Pair<S>) -> Self {
ReplicatedSecretSharing::new(value.left, value.right)
}
}

/// This function executes the maliciously secure shuffle protocol on the input: `shares`.
///
/// ## Errors
Expand All @@ -66,7 +104,7 @@ where
.collect::<Vec<AdditiveShare<Gf32Bit>>>();

// compute and append tags to rows
let shares_and_tags: Vec<AdditiveShare<S::ShareAndTag>> =
let shares_and_tags: Vec<Pair<S::ShareAndTag>> =
compute_and_add_tags(ctx.narrow(&OPRFShuffleStep::GenerateTags), &keys, shares).await?;

// shuffle
Expand All @@ -76,7 +114,7 @@ where
verify_shuffle::<_, S>(
ctx.narrow(&OPRFShuffleStep::VerifyShuffle),
&keys,
&shuffled_shares,
shuffled_shares.as_slice(),
messages,
)
.await?;
Expand Down Expand Up @@ -144,7 +182,7 @@ where
let keys = setup_keys(ctx.narrow(&OPRFShuffleStep::SetupKeys), amount_of_keys).await?;

// compute and append tags to rows
let shares_and_tags: Vec<AdditiveShare<S::ShareAndTag>> =
let shares_and_tags: Vec<Pair<S::ShareAndTag>> =
compute_and_add_tags(ctx.narrow(&OPRFShuffleStep::GenerateTags), &keys, shares).await?;

let (shuffled_shares, messages) = match ctx.role() {
Expand All @@ -171,16 +209,16 @@ where
///
/// ## Panics
/// Panics when `S::Bits > B::Bits`.
fn truncate_tags<S>(shares_and_tags: &[AdditiveShare<S::ShareAndTag>]) -> Vec<S>
fn truncate_tags<S>(shares_and_tags: &[Pair<S::ShareAndTag>]) -> Vec<S>
where
S: MaliciousShuffleable,
{
shares_and_tags
.iter()
.map(|row_with_tag| {
Shuffleable::new(
split_row_and_tag::<S>(ReplicatedSecretSharing::left(row_with_tag)).0,
split_row_and_tag::<S>(ReplicatedSecretSharing::right(row_with_tag)).0,
split_row_and_tag::<S>(&row_with_tag.left).0,
split_row_and_tag::<S>(&row_with_tag.right).0,
)
})
.collect()
Expand All @@ -192,7 +230,9 @@ where
/// When `row_with_tag` does not have the correct format,
/// i.e. deserialization returns an error,
/// the output row and tag will be zero.
fn split_row_and_tag<S: MaliciousShuffleable>(row_with_tag: S::ShareAndTag) -> (S::Share, Gf32Bit) {
fn split_row_and_tag<S: MaliciousShuffleable>(
row_with_tag: &S::ShareAndTag,
) -> (S::Share, Gf32Bit) {
let mut buf = GenericArray::default();
row_with_tag.serialize(&mut buf);
(
Expand All @@ -211,7 +251,7 @@ fn split_row_and_tag<S: MaliciousShuffleable>(row_with_tag: S::ShareAndTag) -> (
async fn verify_shuffle<C: Context, S: MaliciousShuffleable>(
ctx: C,
key_shares: &[AdditiveShare<Gf32Bit>],
shuffled_shares: &[AdditiveShare<S::ShareAndTag>],
shuffled_shares: &[Pair<S::ShareAndTag>],
messages: IntermediateShuffleMessages<S::ShareAndTag>,
) -> Result<(), Error> {
// reveal keys
Expand Down Expand Up @@ -248,7 +288,7 @@ async fn verify_shuffle<C: Context, S: MaliciousShuffleable>(
async fn h1_verify<C: Context, S: MaliciousShuffleable>(
ctx: C,
keys: &[Gf32Bit],
share_a_and_b: &[AdditiveShare<S::ShareAndTag>],
share_a_and_b: &[Pair<S::ShareAndTag>],
x1: Vec<S::ShareAndTag>,
) -> Result<(), Error> {
// compute hashes
Expand All @@ -257,9 +297,9 @@ async fn h1_verify<C: Context, S: MaliciousShuffleable>(
// compute hash for A xor B
let hash_a_xor_b = compute_and_hash_tags::<S, _>(
keys,
share_a_and_b.iter().map(|share| {
ReplicatedSecretSharing::left(share) + ReplicatedSecretSharing::right(share)
}),
share_a_and_b
.iter()
.map(|share| Shuffleable::left(share) + Shuffleable::right(share)),
);

// setup channels
Expand Down Expand Up @@ -315,7 +355,7 @@ async fn h1_verify<C: Context, S: MaliciousShuffleable>(
async fn h2_verify<C: Context, S: MaliciousShuffleable>(
ctx: C,
keys: &[Gf32Bit],
share_b_and_c: &[AdditiveShare<S::ShareAndTag>],
share_b_and_c: &[Pair<S::ShareAndTag>],
x2: Vec<S::ShareAndTag>,
) -> Result<(), Error> {
// compute hashes
Expand Down Expand Up @@ -360,7 +400,7 @@ async fn h2_verify<C: Context, S: MaliciousShuffleable>(
async fn h3_verify<C: Context, S: MaliciousShuffleable>(
ctx: C,
keys: &[Gf32Bit],
share_c_and_a: &[AdditiveShare<S::ShareAndTag>],
share_c_and_a: &[Pair<S::ShareAndTag>],
y1: Vec<S::ShareAndTag>,
y2: Vec<S::ShareAndTag>,
) -> Result<(), Error> {
Expand Down Expand Up @@ -406,7 +446,7 @@ where
let iterator = row_iterator.into_iter().map(|row_with_tag| {
// when split_row_and_tags returns the default value, the verification will fail
// except 2^-security_parameter, i.e. 2^-32
let (row, tag) = split_row_and_tag::<S>(row_with_tag);
let (row, tag) = split_row_and_tag::<S>(&row_with_tag);
<S::Share as TryInto<Vec<Gf32Bit>>>::try_into(row)
.unwrap()
.into_iter()
Expand Down Expand Up @@ -471,7 +511,7 @@ async fn compute_and_add_tags<C, S>(
ctx: C,
keys: &[AdditiveShare<Gf32Bit>],
rows: Vec<S>,
) -> Result<Vec<AdditiveShare<S::ShareAndTag>>, Error>
) -> Result<Vec<Pair<S::ShareAndTag>>, Error>
where
C: Context,
S: MaliciousShuffleable,
Expand Down Expand Up @@ -538,7 +578,7 @@ where
fn concatenate_row_and_tag<S: MaliciousShuffleable>(
row: &S,
tag: &AdditiveShare<Gf32Bit>,
) -> AdditiveShare<S::ShareAndTag> {
) -> Pair<S::ShareAndTag> {
let mut row_left = GenericArray::default();
let mut row_right = GenericArray::default();
let mut tag_left = GenericArray::default();
Expand Down Expand Up @@ -601,7 +641,10 @@ mod tests {
vec![record],
)
.await
.unwrap();
.unwrap()
.into_iter()
.map(AdditiveShare::from)
.collect();

(keys, shares_and_tags)
})
Expand Down Expand Up @@ -702,7 +745,11 @@ mod tests {
verify_shuffle::<_, AdditiveShare<BA32>>(
ctx.narrow("verify"),
&key_shares,
&shares,
shares
.into_iter()
.map(Pair::from)
.collect::<Vec<_>>()
.as_slice(),
messages,
)
.await
Expand All @@ -727,21 +774,21 @@ mod tests {
{
let row = <S as Shuffleable>::new(rng.gen(), rng.gen());
let tag = AdditiveShare::<Gf32Bit>::new(rng.gen::<Gf32Bit>(), rng.gen::<Gf32Bit>());
let row_and_tag: AdditiveShare<S::ShareAndTag> = concatenate_row_and_tag(&row, &tag);
let row_and_tag: Pair<S::ShareAndTag> = concatenate_row_and_tag(&row, &tag);

let mut buf = GenericArray::default();
let mut buf_row = GenericArray::default();
let mut buf_tag = GenericArray::default();

// check left shares
ReplicatedSecretSharing::left(&row_and_tag).serialize(&mut buf);
Shuffleable::left(&row_and_tag).serialize(&mut buf);
Shuffleable::left(&row).serialize(&mut buf_row);
assert_eq!(buf[0..S::TAG_OFFSET], buf_row[..]);
ReplicatedSecretSharing::left(&tag).serialize(&mut buf_tag);
assert_eq!(buf[S::TAG_OFFSET..], buf_tag[..]);

// check right shares
ReplicatedSecretSharing::right(&row_and_tag).serialize(&mut buf);
Shuffleable::right(&row_and_tag).serialize(&mut buf);
Shuffleable::right(&row).serialize(&mut buf_row);
assert_eq!(buf[0..S::TAG_OFFSET], buf_row[..]);
ReplicatedSecretSharing::right(&tag).serialize(&mut buf_tag);
Expand All @@ -766,6 +813,7 @@ mod tests {
where
S: MaliciousShuffleable,
S::Share: IntoShares<S>,
S::ShareAndTag: SharedValue,
Standard: Distribution<S::Share>,
{
const RECORD_AMOUNT: usize = 10;
Expand Down Expand Up @@ -814,6 +862,9 @@ mod tests {
)
.await
.unwrap()
.into_iter()
.map(AdditiveShare::from)
.collect::<Vec<_>>()
},
)
.await
Expand Down
8 changes: 2 additions & 6 deletions ipa-core/src/protocol/ipa_prf/shuffle/sharded.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ pub trait MaliciousShuffleable:
///
/// Having an alias here makes it easier to reference in the code, because the
/// shuffle routines have an `S: MaliciousShuffleable` type parameter.
type ShareAndTag: ShuffleShare + SharedValue;
type ShareAndTag: ShuffleShare;

/// Same as `Self::MaliciousShare::TAG_OFFSET`.
///
Expand Down Expand Up @@ -316,11 +316,7 @@ where
/// automatically.
pub trait MaliciousShuffleShare: TryInto<Vec<Gf32Bit>, Error = LengthError> {
/// A type that can hold `<Self as Shuffleable>::Share` along with a 32-bit MAC.
///
/// The `SharedValue` bound is required because some of the malicious shuffle
/// routines use `AdditiveShare<ShareAndTag>`. It might be possible to refactor
/// those routines to avoid the `SharedValue` bound.
type ShareAndTag: ShuffleShare + SharedValue;
type ShareAndTag: ShuffleShare;

/// The offset to the MAC in `ShareAndTag`.
const TAG_OFFSET: usize;
Expand Down

0 comments on commit b6b9223

Please sign in to comment.