Skip to content

Commit

Permalink
Merge pull request #813 from richajaindce/minor_fixes
Browse files Browse the repository at this point in the history
Minor changes/fixes in OPRF
  • Loading branch information
benjaminsavage authored Oct 25, 2023
2 parents 9fdad64 + 0c5f464 commit 090d94f
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 14 deletions.
17 changes: 16 additions & 1 deletion src/ff/galois_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::{

use bitvec::prelude::{bitarr, BitArr, Lsb0};
use generic_array::GenericArray;
use typenum::{Unsigned, U1, U3, U4, U5};
use typenum::{Unsigned, U1, U2, U3, U4, U5};

use crate::{
ff::{Field, Serializable},
Expand All @@ -25,6 +25,7 @@ pub trait GaloisField:

// Bit store type definitions
type U8_1 = BitArr!(for 8, in u8, Lsb0);
type U8_2 = BitArr!(for 9, in u8, Lsb0);
type U8_3 = BitArr!(for 24, in u8, Lsb0);
type U8_4 = BitArr!(for 32, in u8, Lsb0);
type U8_5 = BitArr!(for 40, in u8, Lsb0);
Expand All @@ -33,6 +34,10 @@ impl Block for U8_1 {
type Size = U1;
}

impl Block for U8_2 {
type Size = U2;
}

impl Block for U8_3 {
type Size = U3;
}
Expand Down Expand Up @@ -575,6 +580,16 @@ bit_array_impl!(
0b1_0001_1011_u128
);

bit_array_impl!(
bit_array_9,
Gf9Bit,
U8_2,
9,
bitarr!(const u8, Lsb0; 1, 0, 0, 0, 0, 0, 0, 0, 0),
// x^9 + x^4 + x^3 + x + 1
0b10_0001_1011_u128
);

bit_array_impl!(
bit_array_5,
Gf5Bit,
Expand Down
4 changes: 3 additions & 1 deletion src/ff/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ mod prime_field;
use std::ops::{Add, AddAssign, Sub, SubAssign};

pub use field::{Field, FieldType};
pub use galois_field::{GaloisField, Gf2, Gf20Bit, Gf32Bit, Gf3Bit, Gf40Bit, Gf5Bit, Gf8Bit};
pub use galois_field::{
GaloisField, Gf2, Gf20Bit, Gf32Bit, Gf3Bit, Gf40Bit, Gf5Bit, Gf8Bit, Gf9Bit,
};
use generic_array::{ArrayLength, GenericArray};
#[cfg(any(test, feature = "weak-field"))]
pub use prime_field::Fp31;
Expand Down
49 changes: 37 additions & 12 deletions src/protocol/prf_sharding/bucket.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,42 @@
use embed_doc_image::embed_doc_image;
use ipa_macros::Step;

use crate::{
error::Error,
ff::{GaloisField, PrimeField, Serializable},
protocol::{
basics::SecureMul, context::UpgradedContext, prf_sharding::BinaryTreeDepthStep,
step::BitOpStep, RecordId,
basics::SecureMul, context::UpgradedContext, prf_sharding::BinaryTreeDepthStep, RecordId,
},
secret_sharing::{
replicated::malicious::ExtendableField, BitDecomposed, Linear as LinearSecretSharing,
},
};

#[derive(Step)]
pub enum BucketStep {
#[dynamic(256)]
Bit(usize),
}

impl TryFrom<u32> for BucketStep {
type Error = String;

fn try_from(v: u32) -> Result<Self, Self::Error> {
let val = usize::try_from(v);
let val = match val {
Ok(val) => Self::Bit(val),
Err(error) => panic!("{error:?}"),
};
Ok(val)
}
}

impl From<usize> for BucketStep {
fn from(v: usize) -> Self {
Self::Bit(v)
}
}

#[embed_doc_image("tree-aggregation", "images/tree_aggregation.png")]
/// This function moves a single value to a correct bucket using tree aggregation approach
///
Expand Down Expand Up @@ -53,8 +78,8 @@ where
BK::BITS
);
assert!(
breakdown_count <= 128,
"Our step implementation (BitOpStep) cannot go past 64"
breakdown_count <= 512,
"Our step implementation (BucketStep) cannot go past 256"
);
let mut row_contribution = vec![value; breakdown_count];

Expand All @@ -69,7 +94,7 @@ where
let mut futures = Vec::with_capacity(breakdown_count / step);

for (i, tree_index) in (0..breakdown_count).step_by(step).enumerate() {
let bit_c = depth_c.narrow(&BitOpStep::from(i));
let bit_c = depth_c.narrow(&BucketStep::from(i));

if robust || tree_index + span < breakdown_count {
futures.push(row_contribution[tree_index].multiply(bit_of_bdkey, bit_c, record_id));
Expand All @@ -96,7 +121,7 @@ pub mod tests {
use rand::thread_rng;

use crate::{
ff::{Field, Fp32BitPrime, Gf5Bit, Gf8Bit},
ff::{Field, Fp32BitPrime, Gf8Bit, Gf9Bit},
protocol::{
context::{Context, UpgradableContext, Validator},
prf_sharding::bucket::move_single_value_to_bucket,
Expand All @@ -108,12 +133,12 @@ pub mod tests {
test_fixture::{get_bits, Reconstruct, Runner, TestWorld},
};

const MAX_BREAKDOWN_COUNT: usize = 1 << Gf5Bit::BITS;
const MAX_BREAKDOWN_COUNT: usize = 256;
const VALUE: u32 = 10;

async fn move_to_bucket(count: usize, breakdown_key: usize, robust: bool) -> Vec<Fp32BitPrime> {
let breakdown_key_bits =
get_bits::<Fp32BitPrime>(breakdown_key.try_into().unwrap(), Gf5Bit::BITS);
get_bits::<Fp32BitPrime>(breakdown_key.try_into().unwrap(), Gf8Bit::BITS);
let value = Fp32BitPrime::truncate_from(VALUE);

TestWorld::default()
Expand All @@ -122,7 +147,7 @@ pub mod tests {
|ctx, (breakdown_key_share, value_share)| async move {
let validator = ctx.validator();
let ctx = validator.context();
move_single_value_to_bucket::<Gf5Bit, _, _, Fp32BitPrime>(
move_single_value_to_bucket::<Gf8Bit, _, _, Fp32BitPrime>(
ctx.set_total_records(1),
RecordId::from(0),
breakdown_key_share,
Expand Down Expand Up @@ -207,7 +232,7 @@ pub mod tests {
#[should_panic]
fn move_out_of_range_too_many_buckets_steps() {
run(move || async move {
let breakdown_key_bits = get_bits::<Fp32BitPrime>(0, Gf8Bit::BITS);
let breakdown_key_bits = get_bits::<Fp32BitPrime>(0, Gf9Bit::BITS);
let value = Fp32BitPrime::truncate_from(VALUE);

_ = TestWorld::default()
Expand All @@ -216,12 +241,12 @@ pub mod tests {
|ctx, (breakdown_key_share, value_share)| async move {
let validator = ctx.validator();
let ctx = validator.context();
move_single_value_to_bucket::<Gf8Bit, _, _, Fp32BitPrime>(
move_single_value_to_bucket::<Gf9Bit, _, _, Fp32BitPrime>(
ctx.set_total_records(1),
RecordId::from(0),
breakdown_key_share,
value_share,
129,
513,
false,
)
.await
Expand Down
30 changes: 30 additions & 0 deletions src/protocol/prf_sharding/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,13 @@ where
F: PrimeField + ExtendableField,
{
let num_records = user_level_attributions.len();

// in case no attributable conversion is found, return 0.
// as anyways the helpers know that no attributions resulted.
if num_records == 0 {
return Ok(vec![S::ZERO; 1 << BK::BITS]);
}

let (bk_vec, tv_vec): (Vec<_>, Vec<_>) = user_level_attributions
.into_iter()
.map(|row| {
Expand Down Expand Up @@ -1244,4 +1251,27 @@ pub mod tests {
assert_eq!(result, &expected);
});
}

#[test]
fn semi_honest_aggregation_empty_input() {
run(|| async move {
let world = TestWorld::default();

let records: Vec<PreAggregationTestInputInBits> = vec![];

let expected = [0_u128; 32];

let result: Vec<_> = world
.semi_honest(records.into_iter(), |ctx, input_rows| async move {
let validator = ctx.validator();
let ctx = validator.context();
do_aggregation::<_, Gf5Bit, Gf3Bit, Fp32BitPrime, _>(ctx, input_rows)
.await
.unwrap()
})
.await
.reconstruct();
assert_eq!(result, &expected);
});
}
}

0 comments on commit 090d94f

Please sign in to comment.