Skip to content

Commit

Permalink
Merge pull request #871 from private-attribution/boolean_array_dot_pr…
Browse files Browse the repository at this point in the history
…oduct

Migrating the feature label dot product to use BooleanArray
  • Loading branch information
benjaminsavage authored Dec 5, 2023
2 parents df322a2 + de967a1 commit 9fd865c
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 95 deletions.
28 changes: 18 additions & 10 deletions ipa-core/src/ff/galois_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use bitvec::prelude::{bitarr, BitArr, Lsb0};
use generic_array::GenericArray;
use typenum::{Unsigned, U1, U2, U3, U4, U5};

use super::ArrayAccess;
use crate::{
ff::{Field, Serializable},
secret_sharing::{Block, SharedValue},
Expand Down Expand Up @@ -165,6 +166,23 @@ macro_rules! bit_array_impl {
}
}

impl ArrayAccess for $name {
type Output = bool;

fn get(&self, index: usize) -> Option<Self::Output> {
if index < usize::try_from(<$name>::BITS).unwrap() {
Some(self.0[index].into())
} else {
None
}
}

fn set(&mut self, index: usize, e: Self::Output) {
debug_assert!(index < usize::try_from(<$name>::BITS).unwrap());
self.0.set(index, bool::from(e));
}
}

impl GaloisField for $name {
const POLYNOMIAL: u128 = $polynomial;
}
Expand Down Expand Up @@ -613,16 +631,6 @@ bit_array_impl!(
0b10_0001_1011_u128,
);

bit_array_impl!(
bit_array_5,
Gf5Bit,
U8_1,
5,
bitarr!(const u8, Lsb0; 1, 0, 0, 0, 0),
// x^5 + x^4 + x^3 + x^2 + x + 1
0b111_111_u128,
);

bit_array_impl!(
bit_array_3,
Gf3Bit,
Expand Down
4 changes: 1 addition & 3 deletions ipa-core/src/ff/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@ mod prime_field;
use std::ops::{Add, AddAssign, Sub, SubAssign};

pub use field::{Field, FieldType};
pub use galois_field::{
GaloisField, Gf2, Gf20Bit, Gf32Bit, Gf3Bit, Gf40Bit, Gf5Bit, Gf8Bit, Gf9Bit,
};
pub use galois_field::{GaloisField, Gf2, Gf20Bit, Gf32Bit, Gf3Bit, Gf40Bit, Gf8Bit, Gf9Bit};
use generic_array::{ArrayLength, GenericArray};
#[cfg(any(test, feature = "weak-field"))]
pub use prime_field::Fp31;
Expand Down
37 changes: 22 additions & 15 deletions ipa-core/src/protocol/ipa/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use typenum::Unsigned;

use crate::{
error::Error,
ff::{Field, GaloisField, Gf2, PrimeField, Serializable},
ff::{ArrayAccess, Field, Gf2, PrimeField, Serializable},
helpers::{query::IpaQueryConfig, Role},
protocol::{
attribution::secure_attribution,
Expand All @@ -35,7 +35,7 @@ use crate::{
semi_honest::AdditiveShare as Replicated,
ReplicatedSecretSharing,
},
BitDecomposed, Linear as LinearSecretSharing, LinearRefOps,
BitDecomposed, Linear as LinearSecretSharing, LinearRefOps, SharedValue,
},
};

Expand All @@ -60,15 +60,15 @@ pub(crate) enum IPAInputRowResharableStep {

#[derive(Debug)]
#[cfg_attr(test, derive(Clone, PartialEq, Eq))]
pub struct IPAInputRow<F: Field, MK: GaloisField, BK: GaloisField> {
pub struct IPAInputRow<F: Field, MK: SharedValue, BK: SharedValue> {
pub timestamp: Replicated<F>,
pub mk_shares: Replicated<MK>,
pub is_trigger_bit: Replicated<F>,
pub breakdown_key: Replicated<BK>,
pub trigger_value: Replicated<F>,
}

impl<F: Field, MK: GaloisField, BK: GaloisField> Serializable for IPAInputRow<F, MK, BK>
impl<F: Field, MK: SharedValue, BK: SharedValue> Serializable for IPAInputRow<F, MK, BK>
where
Replicated<BK>: Serializable,
Replicated<MK>: Serializable,
Expand Down Expand Up @@ -163,7 +163,7 @@ where
}
}

impl<F: Field, MK: GaloisField, BK: GaloisField> IPAInputRow<F, MK, BK>
impl<F: Field, MK: SharedValue, BK: SharedValue> IPAInputRow<F, MK, BK>
where
IPAInputRow<F, MK, BK>: Serializable,
{
Expand Down Expand Up @@ -323,8 +323,8 @@ where
+ 'static,
for<'r> &'r SB: LinearRefOps<'r, SB, Gf2>,
F: PrimeField + ExtendableField,
MK: GaloisField,
BK: GaloisField,
MK: SharedValue + ArrayAccess<Output = bool>,
BK: SharedValue + ArrayAccess<Output = bool>,
ShuffledPermutationWrapper<S, C::UpgradedContext<F>>: DowngradeMalicious<Target = Vec<u32>>,
for<'u> UpgradeContext<'u, C::UpgradedContext<F>, F, RecordId>: UpgradeToMalicious<'u, BitConversionTriple<Replicated<F>>, BitConversionTriple<S>>
+ UpgradeToMalicious<
Expand Down Expand Up @@ -413,16 +413,16 @@ fn get_gf2_match_key_bits<F, MK, BK>(
) -> Vec<BitDecomposed<Replicated<Gf2>>>
where
F: PrimeField,
MK: GaloisField,
BK: GaloisField,
MK: SharedValue + ArrayAccess<Output = bool>,
BK: SharedValue + ArrayAccess<Output = bool>,
{
input_rows
.iter()
.map(|row| {
BitDecomposed::decompose(MK::BITS, |i| {
Replicated::new(
Gf2::truncate_from(row.mk_shares.left()[i]),
Gf2::truncate_from(row.mk_shares.right()[i]),
Gf2::truncate_from(row.mk_shares.left().get(i.try_into().unwrap()).unwrap()),
Gf2::truncate_from(row.mk_shares.right().get(i.try_into().unwrap()).unwrap()),
)
})
})
Expand All @@ -434,16 +434,23 @@ fn get_gf2_breakdown_key_bits<F, MK, BK>(
) -> Vec<BitDecomposed<Replicated<Gf2>>>
where
F: PrimeField,
MK: GaloisField,
BK: GaloisField,
MK: SharedValue + ArrayAccess<Output = bool>,
BK: SharedValue + ArrayAccess<Output = bool>,
{
input_rows
.iter()
.map(|row| {
BitDecomposed::decompose(BK::BITS, |i| {
Replicated::new(
Gf2::truncate_from(row.breakdown_key.left()[i]),
Gf2::truncate_from(row.breakdown_key.right()[i]),
Gf2::truncate_from(
row.breakdown_key.left().get(i.try_into().unwrap()).unwrap(),
),
Gf2::truncate_from(
row.breakdown_key
.right()
.get(i.try_into().unwrap())
.unwrap(),
),
)
})
})
Expand Down
Loading

0 comments on commit 9fd865c

Please sign in to comment.