Merge pull request #804 from martinthomson/bucket-fixes
Restructure the bucketing code
martinthomson authored Oct 16, 2023
2 parents 2bbfeec + 05626b1 commit 380e5d3
Showing 5 changed files with 261 additions and 178 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -106,7 +106,7 @@ cfg_aliases = "0.1.1"
command-fds = "0.2.2"
hex = "0.4"
permutation = "0.4.1"
proptest = "1.0.0"
proptest = "1"
tempfile = "3"
tokio-rustls = { version = "0.24.0", features = ["dangerous_configuration"] }

12 changes: 12 additions & 0 deletions src/ff/prime_field.rs
@@ -132,6 +132,18 @@ macro_rules! field_impl {
}
}

impl std::iter::Sum for $field {
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
iter.fold(Self::ZERO, |a, b| a + b)
}
}

impl<'a> std::iter::Sum<&'a $field> for $field {
fn sum<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
iter.fold(Self::ZERO, |a, b| a + *b)
}
}

impl TryFrom<u128> for $field {
type Error = crate::error::Error;

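The `Sum` implementations added above let an iterator of field elements, or of references to them, be summed directly with `Iterator::sum`. A minimal sketch of what this enables, written as it might appear in one of this crate's unit tests (the `Field` trait, `Fp32BitPrime`, and `truncate_from` are taken from the existing test code in this diff; the test name is illustrative):

#[test]
fn sum_field_elements() {
    use crate::ff::{Field, Fp32BitPrime};

    let values = [
        Fp32BitPrime::truncate_from(3_u32),
        Fp32BitPrime::truncate_from(7_u32),
    ];

    // Summing owned elements uses `impl Sum for $field`.
    let owned: Fp32BitPrime = values.into_iter().sum();
    // Summing borrowed elements uses `impl Sum<&'a $field> for $field`.
    let borrowed: Fp32BitPrime = values.iter().sum();

    assert_eq!(owned, borrowed);
    assert_eq!(owned, Fp32BitPrime::truncate_from(10_u32));
}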
234 changes: 234 additions & 0 deletions src/protocol/prf_sharding/bucket.rs
@@ -0,0 +1,234 @@
use embed_doc_image::embed_doc_image;

use crate::{
error::Error,
ff::{GaloisField, PrimeField, Serializable},
protocol::{
basics::SecureMul, context::UpgradedContext, prf_sharding::BinaryTreeDepthStep,
step::BitOpStep, RecordId,
},
secret_sharing::{
replicated::malicious::ExtendableField, BitDecomposed, Linear as LinearSecretSharing,
},
};

#[embed_doc_image("tree-aggregation", "images/tree_aggregation.png")]
/// This function moves a single value to the correct bucket using a tree aggregation approach.
///
/// Here is how it works:
/// The combined value, [`value`], forms the root of a binary tree as follows:
/// ![Tree propagation][tree-aggregation]
///
/// This value is propagated through the tree, with each subsequent iteration doubling the number of multiplications.
/// In the first round, r=BK-1, multiply the most significant bit, [`bd_key`]_r, by the value to get [`bd_key`]_r.[`value`]. From that,
/// produce [`row_contribution`]_r,0 = [`value`] - [`bd_key`]_r.[`value`] and [`row_contribution`]_r,1 = [`bd_key`]_r.[`value`].
/// This takes the most significant bit of `bd_key` and places the value in one of the two child nodes of the binary tree.
/// At each successive round, the next most significant bit pushes the contributions one level further down the tree:
/// [`row_contribution`]_r+1,q,0 = [`row_contribution`]_r,q - [`bd_key`]_r+1.[`row_contribution`]_r,q and [`row_contribution`]_r+1,q,1 = [`bd_key`]_r+1.[`row_contribution`]_r,q.
/// The work of each iteration therefore doubles relative to the one preceding.
///
/// If a malicious entity sends an out-of-range breakdown key (i.e. greater than the max count) to this function, we need to do some
/// extra processing to ensure the contribution doesn't end up in the wrong bucket. However, this requires extra multiplications.
/// This would potentially not be needed in IPA (as the breakdown key is provided by the report collector, so a bad value only spoils their own result), but it is useful for PAM.
/// This can be enabled by passing `robust` as `true`.
pub async fn move_single_value_to_bucket<BK, C, S, F>(
ctx: C,
record_id: RecordId,
bd_key: BitDecomposed<S>,
value: S,
breakdown_count: usize,
robust: bool,
) -> Result<Vec<S>, Error>
where
BK: GaloisField,
C: UpgradedContext<F, Share = S>,
S: LinearSecretSharing<F> + Serializable + SecureMul<C>,
F: PrimeField + ExtendableField,
{
let mut step: usize = 1 << BK::BITS;

assert!(
breakdown_count <= 1 << BK::BITS,
"Asking for more buckets ({breakdown_count}) than bits in the key ({}) allow",
BK::BITS
);
assert!(
breakdown_count <= 128,
"Our step implementation (BitOpStep) cannot go past 64"
);
let mut row_contribution = vec![value; breakdown_count];

for (tree_depth, bit_of_bdkey) in bd_key.iter().enumerate().rev() {
let span = step >> 1;
if !robust && span > breakdown_count {
step = span;
continue;
}

let depth_c = ctx.narrow(&BinaryTreeDepthStep::from(tree_depth));
let mut futures = Vec::with_capacity(breakdown_count / step);

for (i, tree_index) in (0..breakdown_count).step_by(step).enumerate() {
let bit_c = depth_c.narrow(&BitOpStep::from(i));

if robust || tree_index + span < breakdown_count {
futures.push(row_contribution[tree_index].multiply(bit_of_bdkey, bit_c, record_id));
}
}
let contributions = ctx.parallel_join(futures).await?;

for (index, bdbit_contribution) in contributions.into_iter().enumerate() {
let left_index = index * step;
let right_index = left_index + span;

row_contribution[left_index] -= &bdbit_contribution;
if right_index < breakdown_count {
row_contribution[right_index] = bdbit_contribution;
}
}
step = span;
}
Ok(row_contribution)
}

#[cfg(all(test, unit_test))]
pub mod tests {
use rand::thread_rng;

use crate::{
ff::{Field, Fp32BitPrime, Gf5Bit, Gf8Bit},
protocol::{
context::{Context, UpgradableContext, Validator},
prf_sharding::bucket::move_single_value_to_bucket,
RecordId,
},
rand::Rng,
secret_sharing::SharedValue,
test_executor::run,
test_fixture::{get_bits, Reconstruct, Runner, TestWorld},
};

const MAX_BREAKDOWN_COUNT: usize = 1 << Gf5Bit::BITS;
const VALUE: u32 = 10;

async fn move_to_bucket(count: usize, breakdown_key: usize, robust: bool) -> Vec<Fp32BitPrime> {
let breakdown_key_bits =
get_bits::<Fp32BitPrime>(breakdown_key.try_into().unwrap(), Gf5Bit::BITS);
let value = Fp32BitPrime::truncate_from(VALUE);

TestWorld::default()
.semi_honest(
(breakdown_key_bits, value),
|ctx, (breakdown_key_share, value_share)| async move {
let validator = ctx.validator();
let ctx = validator.context();
move_single_value_to_bucket::<Gf5Bit, _, _, Fp32BitPrime>(
ctx.set_total_records(1),
RecordId::from(0),
breakdown_key_share,
value_share,
count,
robust,
)
.await
.unwrap()
},
)
.await
.reconstruct()
}

#[test]
fn semi_honest_move_in_range() {
run(|| async move {
let mut rng = thread_rng();
let count = rng.gen_range(1..MAX_BREAKDOWN_COUNT);
let breakdown_key = rng.gen_range(0..count);
let mut expected = vec![Fp32BitPrime::ZERO; count];
expected[breakdown_key] = Fp32BitPrime::truncate_from(VALUE);

let result = move_to_bucket(count, breakdown_key, false).await;
assert_eq!(result, expected, "expected value at index {breakdown_key}");
});
}

#[test]
fn semi_honest_move_in_range_robust() {
run(|| async move {
let mut rng = thread_rng();
let count = rng.gen_range(1..MAX_BREAKDOWN_COUNT);
let breakdown_key = rng.gen_range(0..count);
let mut expected = vec![Fp32BitPrime::ZERO; count];
expected[breakdown_key] = Fp32BitPrime::truncate_from(VALUE);

let result = move_to_bucket(count, breakdown_key, true).await;
assert_eq!(result, expected, "expected value at index {breakdown_key}");
});
}

#[test]
fn semi_honest_move_out_of_range() {
run(move || async move {
let mut rng: rand::rngs::ThreadRng = thread_rng();
let count = rng.gen_range(2..MAX_BREAKDOWN_COUNT - 1);
let breakdown_key = rng.gen_range(count..MAX_BREAKDOWN_COUNT);

let result = move_to_bucket(count, breakdown_key, false).await;
assert_eq!(result.len(), count);
assert_eq!(
result.into_iter().sum::<Fp32BitPrime>(),
Fp32BitPrime::truncate_from(VALUE)
);
});
}

#[test]
fn semi_honest_move_out_of_range_robust() {
run(move || async move {
let mut rng: rand::rngs::ThreadRng = thread_rng();
let count = rng.gen_range(2..MAX_BREAKDOWN_COUNT - 1);
let breakdown_key = rng.gen_range(count..MAX_BREAKDOWN_COUNT);

let result = move_to_bucket(count, breakdown_key, true).await;
assert_eq!(result.len(), count);
assert!(result.into_iter().all(|x| x == Fp32BitPrime::ZERO));
});
}

#[test]
#[should_panic]
fn move_out_of_range_too_many_buckets_type() {
run(move || async move {
_ = move_to_bucket(MAX_BREAKDOWN_COUNT + 1, 0, false).await;
});
}

#[test]
#[should_panic]
fn move_out_of_range_too_many_buckets_steps() {
run(move || async move {
let breakdown_key_bits = get_bits::<Fp32BitPrime>(0, Gf8Bit::BITS);
let value = Fp32BitPrime::truncate_from(VALUE);

_ = TestWorld::default()
.semi_honest(
(breakdown_key_bits, value),
|ctx, (breakdown_key_share, value_share)| async move {
let validator = ctx.validator();
let ctx = validator.context();
move_single_value_to_bucket::<Gf8Bit, _, _, Fp32BitPrime>(
ctx.set_total_records(1),
RecordId::from(0),
breakdown_key_share,
value_share,
129,
false,
)
.await
.unwrap()
},
)
.await;
});
}
}
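For intuition, here is a minimal clear-text sketch of the routing that `move_single_value_to_bucket` performs on secret shares. It is not part of this commit: the bit slice stands in for `BitDecomposed` (least-significant bit first, as in the protocol), the conditional stands in for the secure multiplication of a key bit by the contribution, and an in-range breakdown key is assumed.

// Clear-text illustration of the tree aggregation: walk the breakdown key
// from its most significant bit, halving the span each round, until the
// value lands in its bucket and every other bucket holds zero.
fn route_to_bucket_clear(value: u32, bd_key_bits: &[bool], breakdown_count: usize) -> Vec<u32> {
    assert!(breakdown_count <= 1 << bd_key_bits.len());
    let mut step = 1_usize << bd_key_bits.len();
    let mut row_contribution = vec![value; breakdown_count];

    for &bit in bd_key_bits.iter().rev() {
        let span = step >> 1;
        for left_index in (0..breakdown_count).step_by(step) {
            let right_index = left_index + span;
            // In the protocol this is the secure multiplication bit * contribution.
            let moved = if bit { row_contribution[left_index] } else { 0 };
            row_contribution[left_index] -= moved;
            if right_index < breakdown_count {
                row_contribution[right_index] = moved;
            }
        }
        step = span;
    }
    row_contribution
}

#[test]
fn clear_text_routing() {
    // [true, false, true] is least-significant bit first, so it encodes key 5:
    // the value 10 ends up in bucket 5 and every other bucket is zero.
    let buckets = route_to_bucket_clear(10, &[true, false, true], 8);
    assert_eq!(buckets[5], 10);
    assert_eq!(buckets.iter().sum::<u32>(), 10);
}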
10 changes: 5 additions & 5 deletions src/protocol/prf_sharding/feature_label_dot_product.rs
@@ -261,10 +261,10 @@ where
let num_user_rows = rows_for_user.len();
let contexts = ctx_for_row_number[..num_user_rows - 1].to_owned();
let record_ids = record_id_for_row_depth[..num_user_rows].to_owned();
record_id_for_row_depth[..num_user_rows]
.iter_mut()
.for_each(|count| *count += 1);

for count in &mut record_id_for_row_depth[..num_user_rows] {
*count += 1;
}
#[allow(clippy::async_yields_async)]
// this is ok, because seq join wants a stream of futures
async move {
@@ -273,15 +273,15 @@
}));

// Execute all of the async futures (sequentially), and flatten the result
let flattenned_stream = seq_join(sh_ctx.active_work(), stream_of_per_user_circuits)
let flattened_stream = seq_join(sh_ctx.active_work(), stream_of_per_user_circuits)
.flat_map(|x| stream_iter(x.unwrap()));

// modulus convert feature vector bits from shares in `Z_2` to shares in `Z_p`
let converted_feature_vector_bits = convert_bits(
prime_field_ctx
.narrow(&Step::ModulusConvertFeatureVectorBits)
.set_total_records(num_outputs),
flattenned_stream,
flattened_stream,
0..FV::BITS,
);

