Skip to content

Commit

Permalink
Tolerate empty vectors of shuffle intermediates (#1383)
Browse files Browse the repository at this point in the history
This is more important for the sharded shuffle, which for small inputs
is reasonably likely to produce an empty output on some shard.
  • Loading branch information
andyleiserson authored Oct 30, 2024
1 parent 77565ae commit 6d29275
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 12 deletions.
62 changes: 54 additions & 8 deletions ipa-core/src/helpers/hashing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,7 @@ impl<'a, T: Serializable> SerializeAs<T> for &'a T {

impl MpcMessage for Hash {}

/// Computes Hash of serializable values from an iterator
///
/// ## Panics
/// Panics when Iterator is empty.
pub fn compute_hash<I, T, S>(input: I) -> Hash
fn compute_hash_internal<I, T, S>(input: I) -> (Hash, bool)
where
I: IntoIterator<Item = T>,
T: SerializeAs<S>,
Expand All @@ -74,9 +70,37 @@ where
sha.update(&buf);
}

assert!(!is_empty, "must not provide an empty iterator");
// compute hash
Hash(sha.finalize())
(Hash(sha.finalize()), is_empty)
}

/// Computes Hash of serializable values from an iterator
///
/// ## Panics
/// Panics if an empty input is provided. This can offer defense-in-depth by helping to
/// prevent fail-open bugs when the input should never be empty.
pub fn compute_hash<I, T, S>(input: I) -> Hash
where
I: IntoIterator<Item = T>,
T: SerializeAs<S>,
S: Serializable,
{
let (hash, empty) = compute_hash_internal(input);
assert!(!empty, "must not provide an empty iterator");
hash
}

/// Computes Hash of serializable values from an iterator
///
/// Unlike `compute_hash`, this version accepts empty inputs.
pub fn compute_possibly_empty_hash<I, T, S>(iter: I) -> Hash
where
I: IntoIterator<Item = T>,
T: SerializeAs<S>,
S: Serializable,
{
let (hash, _) = compute_hash_internal(iter);
hash
}

/// This function takes two hashes, combines them together and returns a single field element.
Expand Down Expand Up @@ -128,11 +152,13 @@ where

#[cfg(all(test, unit_test))]
mod test {
use std::iter;

use generic_array::{sequence::GenericSequence, GenericArray};
use rand::{thread_rng, Rng};
use typenum::U8;

use super::{compute_hash, Hash};
use super::{compute_hash, compute_possibly_empty_hash, Hash};
use crate::{
ff::{Fp31, Fp32BitPrime, Serializable},
helpers::hashing::hash_to_field,
Expand Down Expand Up @@ -239,4 +265,24 @@ mod test {
let vec = (0..100).map(|_| rng.gen::<Fp31>()).collect::<Vec<_>>();
assert_eq!(compute_hash(&vec), compute_hash(vec));
}

#[test]
#[should_panic(expected = "must not provide an empty iterator")]
fn empty_reject() {
compute_hash(iter::empty::<Fp31>());
}

#[test]
fn empty_accept() {
// SHA256 hash of zero-length input.
let empty_hash = Hash::deserialize(GenericArray::from_slice(
&hex::decode(b"e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855")
.unwrap(),
))
.unwrap();
assert_eq!(
compute_possibly_empty_hash(iter::empty::<Fp31>()),
empty_hash
);
}
}
28 changes: 24 additions & 4 deletions ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::{
error::Error,
ff::{boolean_array::BooleanArray, Field, Gf32Bit, Serializable},
helpers::{
hashing::{compute_hash, Hash},
hashing::{compute_possibly_empty_hash, Hash},
Direction, TotalRecords,
},
protocol::{
Expand Down Expand Up @@ -348,7 +348,7 @@ where
.into_iter()
.chain(iter::once(tag))
});
compute_hash(iterator.map(|row_entry_iterator| {
compute_possibly_empty_hash(iterator.map(|row_entry_iterator| {
row_entry_iterator
.zip(keys)
.fold(Gf32Bit::ZERO, |acc, (row_entry, key)| {
Expand Down Expand Up @@ -414,9 +414,12 @@ where
{
let row_iterator = rows.into_iter();
let length = row_iterator.len();
if length == 0 {
return Ok(Vec::new());
}
let row_length = keys.len();
// make sure total records is not 0
debug_assert!(length * row_length != 0);
// Make sure `total_records` is not zero.
debug_assert!(row_length != 0);
let tag_ctx = ctx.set_total_records(TotalRecords::specified(length * row_length)?);
let p_ctx = &tag_ctx;

Expand Down Expand Up @@ -569,6 +572,23 @@ mod tests {
});
}

#[test]
fn empty() {
run(|| async {
assert_eq!(
TestWorld::default()
.semi_honest(iter::empty::<BA32>(), |ctx, records| async move {
malicious_shuffle::<_, _, BA64, _>(ctx, records)
.await
.unwrap()
})
.await
.reconstruct(),
Vec::<BA32>::new(),
);
});
}

/// This test checks the correctness of the malicious shuffle
/// when all parties behave honestly
/// and all the MAC keys are `Gf32Bit::ONE`.
Expand Down

0 comments on commit 6d29275

Please sign in to comment.