Skip to content

Commit

Permalink
Merge pull request #1078 from akoshelev/compact-gate-followup
Browse files Browse the repository at this point in the history
Compact gate follow up
  • Loading branch information
benjaminsavage authored May 20, 2024
2 parents 98fc9ff + f1d5dcb commit 7d76bd8
Show file tree
Hide file tree
Showing 17 changed files with 59 additions and 93 deletions.
Empty file.
9 changes: 0 additions & 9 deletions ipa-core/src/protocol/basics/mul/step.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,3 @@ pub(crate) enum MaliciousMultiplyStep {
RandomnessForValidation,
ReshareRx,
}

// This is a dummy step that is used to narrow (but never executed) the semi-honest
// context. Semi-honest implementations of `UpgradedContext::upgrade()` and subsequent
// `UpgradeToMalicious::upgrade()` narrows but these will end up in
// `UpgradedContext::upgrade_one()` or `UpgradedContext::upgrade_sparse()` which both
// return Ok() and never trigger communications.
#[derive(CompactStep)]
#[step(name = "upgrade")]
pub(crate) struct UpgradeStep;
16 changes: 10 additions & 6 deletions ipa-core/src/protocol/boolean/and.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@ use std::iter::zip;
use crate::{
error::Error,
ff::boolean::Boolean,
protocol::{basics::SecureMul, boolean::step::BoolAndStep, context::Context, RecordId},
protocol::{
basics::SecureMul,
boolean::{step::EightBitStep, NBitStep},
context::Context,
RecordId,
},
secret_sharing::{replicated::semi_honest::AdditiveShare, BitDecomposed, FieldSimd},
};

const MAX_BITS: usize = 8;

/// Matrix bitwise AND for use with vectors of bit-decomposed values. Supports up to 8 bits of input
/// that is enough to support both WALR and PRF IPA use cases.
///
Expand Down Expand Up @@ -40,14 +43,15 @@ where
let b = b.into_iter();
assert_eq!(a.len(), b.len());
assert!(
a.len() <= MAX_BITS,
"Up to {MAX_BITS} values are supported, but was given a value of {len} bits",
a.len() <= usize::try_from(EightBitStep::BITS).unwrap(),
"Up to {max_bits} bit values are supported, but was given a value of {len} bits",
max_bits = EightBitStep::BITS,
len = a.len()
);

BitDecomposed::try_from(
ctx.parallel_join(zip(a.iter(), b).enumerate().map(|(i, (a, b))| {
let ctx = ctx.narrow(&BoolAndStep::Bit(i));
let ctx = ctx.narrow(&EightBitStep::Bit(i));
a.multiply(b, ctx, record_id)
}))
.await?,
Expand Down
34 changes: 12 additions & 22 deletions ipa-core/src/protocol/boolean/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,37 +17,27 @@ pub(crate) mod step;
///
/// This is a temporary solution for narrowing contexts until the infra is
/// updated with a new step scheme.
pub trait BitStep: Step + From<usize> {
fn max_bit_depth() -> u32;
pub trait NBitStep: Step + From<usize> {
const BITS: u32;
}

impl BitStep for EightBitStep {
fn max_bit_depth() -> u32 {
8
}
impl NBitStep for EightBitStep {
const BITS: u32 = 8;
}

impl BitStep for SixteenBitStep {
fn max_bit_depth() -> u32 {
16
}
impl NBitStep for SixteenBitStep {
const BITS: u32 = 16;
}

impl BitStep for ThirtyTwoBitStep {
fn max_bit_depth() -> u32 {
32
}
impl NBitStep for ThirtyTwoBitStep {
const BITS: u32 = 32;
}

impl BitStep for TwoHundredFiftySixBitOpStep {
fn max_bit_depth() -> u32 {
256
}
impl NBitStep for TwoHundredFiftySixBitOpStep {
const BITS: u32 = 256;
}

#[cfg(test)]
impl BitStep for crate::protocol::boolean::step::DefaultBitStep {
fn max_bit_depth() -> u32 {
256
}
impl NBitStep for crate::protocol::boolean::step::DefaultBitStep {
const BITS: u32 = 256;
}
6 changes: 0 additions & 6 deletions ipa-core/src/protocol/boolean/step.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
use ipa_step_derive::CompactStep;

#[derive(CompactStep)]
pub(crate) enum BoolAndStep {
#[step(count = 8)] // keep in sync with MAX_BITS defined inside and.rs
Bit(usize),
}

#[derive(CompactStep)]
pub enum EightBitStep {
#[step(count = 8)]
Expand Down
8 changes: 4 additions & 4 deletions ipa-core/src/protocol/ipa_prf/aggregation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::{
},
protocol::{
basics::{BooleanArrayMul, BooleanProtocols},
boolean::{step::SixteenBitStep, BitStep},
boolean::{step::SixteenBitStep, NBitStep},
context::{Context, UpgradedSemiHonestContext},
ipa_prf::{
aggregation::step::{AggregateValuesStep, AggregationStep as Step},
Expand Down Expand Up @@ -266,12 +266,12 @@ where
let record_id = RecordId::from(i);
if a.len() < usize::try_from(OV::BITS).unwrap() {
assert!(
OV::BITS <= SixteenBitStep::max_bit_depth(),
OV::BITS <= SixteenBitStep::BITS,
"SixteenBitStep not large enough to accomodate this sum"
);
// If we have enough output bits, add and keep the carry.
let (mut sum, carry) = integer_add::<_, SixteenBitStep, B>(
ctx.narrow(&AggregateValuesStep::OverflowingAdd),
ctx.narrow(&AggregateValuesStep::Add),
record_id,
&a,
&b,
Expand All @@ -281,7 +281,7 @@ where
Ok(sum)
} else {
assert!(
OV::BITS <= SixteenBitStep::max_bit_depth(),
OV::BITS <= SixteenBitStep::BITS,
"SixteenBitStep not large enough to accomodate this sum"
);
integer_sat_add::<_, SixteenBitStep, B>(
Expand Down
17 changes: 4 additions & 13 deletions ipa-core/src/protocol/ipa_prf/aggregation/step.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,19 @@ pub(crate) enum AggregationStep {
}

#[derive(CompactStep)]
pub enum BucketStep {
/// should be equal to `MAX_BREAKDOWNS`
#[step(count = 512, child = crate::protocol::boolean::step::BoolAndStep)]
Bit(usize),
}

impl From<u32> for BucketStep {
fn from(v: u32) -> Self {
Self::Bit(usize::try_from(v).unwrap())
}
}
#[step(count = 512, child = crate::protocol::boolean::step::EightBitStep, name = "b")]
pub struct BucketStep(usize);

impl From<usize> for BucketStep {
fn from(v: usize) -> Self {
Self::Bit(v)
Self(v)
}
}

#[derive(CompactStep)]
pub(crate) enum AggregateValuesStep {
#[step(child = crate::protocol::boolean::step::SixteenBitStep)]
OverflowingAdd,
Add,
#[step(child = crate::protocol::ipa_prf::boolean_ops::step::SaturatedAdditionStep)]
SaturatingAdd,
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
helpers::repeat_n,
protocol::{
basics::{BooleanProtocols, SecureMul},
boolean::{or::bool_or, BitStep},
boolean::{or::bool_or, NBitStep},
context::{Context, UpgradedSemiHonestContext},
Gate, RecordId,
},
Expand Down Expand Up @@ -38,7 +38,7 @@ pub async fn integer_add<C, S, const N: usize>(
>
where
C: Context,
S: BitStep,
S: NBitStep,
Boolean: FieldSimd<N>,
AdditiveShare<Boolean, N>: BooleanProtocols<C, N>,
Gate: StepNarrow<S>,
Expand All @@ -61,7 +61,7 @@ pub async fn integer_sat_add<'a, SH, S, const N: usize>(
) -> Result<BitDecomposed<AdditiveShare<Boolean, N>>, Error>
where
SH: ShardBinding,
S: BitStep,
S: NBitStep,
Boolean: FieldSimd<N>,
AdditiveShare<Boolean, N>: BooleanProtocols<UpgradedSemiHonestContext<'a, SH, Boolean>, N>,
Gate: StepNarrow<S>,
Expand Down Expand Up @@ -100,7 +100,7 @@ async fn addition_circuit<C, S, const N: usize>(
) -> Result<BitDecomposed<AdditiveShare<Boolean, N>>, Error>
where
C: Context,
S: BitStep,
S: NBitStep,
Boolean: FieldSimd<N>,
AdditiveShare<Boolean, N>: BooleanProtocols<C, N>,
Gate: StepNarrow<S>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::{
ff::{boolean::Boolean, ArrayAccessRef, CustomArray, Field},
protocol::{
basics::{select, BooleanArrayMul, BooleanProtocols, SecureMul, ShareKnownValue},
boolean::BitStep,
boolean::NBitStep,
context::{Context, SemiHonestContext},
Gate, RecordId,
},
Expand All @@ -36,7 +36,7 @@ pub async fn compare_geq<C, S>(
) -> Result<AdditiveShare<Boolean>, Error>
where
C: Context,
S: BitStep,
S: NBitStep,
AdditiveShare<Boolean>: BooleanProtocols<C>,
Gate: StepNarrow<S>,
{
Expand All @@ -60,7 +60,7 @@ pub async fn compare_gt<C, S, const N: usize>(
) -> Result<AdditiveShare<Boolean, N>, Error>
where
C: Context,
S: BitStep,
S: NBitStep,
Boolean: FieldSimd<N>,
AdditiveShare<Boolean, N>: BooleanProtocols<C, N>,
Gate: StepNarrow<S>,
Expand All @@ -85,7 +85,7 @@ pub async fn integer_sub<C, S>(
) -> Result<BitDecomposed<AdditiveShare<Boolean>>, Error>
where
C: Context,
S: BitStep,
S: NBitStep,
AdditiveShare<Boolean>: BooleanProtocols<C>,
Gate: StepNarrow<S>,
{
Expand All @@ -108,7 +108,7 @@ pub async fn integer_sat_sub<S, St>(
) -> Result<AdditiveShare<S>, Error>
where
S: SharedValue + CustomArray<Element = Boolean>,
St: BitStep,
St: NBitStep,
for<'a> AdditiveShare<S>: BooleanArrayMul<SemiHonestContext<'a>>,
Gate: StepNarrow<St>,
{
Expand Down Expand Up @@ -154,7 +154,7 @@ async fn subtraction_circuit<C, S, const N: usize>(
) -> Result<BitDecomposed<AdditiveShare<Boolean, N>>, Error>
where
C: Context,
S: BitStep,
S: NBitStep,
Boolean: FieldSimd<N>,
AdditiveShare<Boolean, N>: BooleanProtocols<C, N>,
Gate: StepNarrow<S>,
Expand Down
2 changes: 1 addition & 1 deletion ipa-core/src/protocol/ipa_prf/boolean_ops/step.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use ipa_step_derive::CompactStep;

/// FIXME: This step is not generic enough to be used in the `saturated_addition` protocol.
/// It constraints the input to be at most 2 bytes and it will panic in runtime if it is greater
/// It constrains the input to be at most 2 bytes and it will panic in runtime if it is greater
/// than that. The issue is that compact gate requires concrete type to be put as child.
/// If we ever see it being an issue, we should make a few implementations of this similar to what
/// we've done for bit steps
Expand Down
18 changes: 9 additions & 9 deletions ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use crate::{
boolean::{
or::or,
step::{EightBitStep, ThirtyTwoBitStep},
BitStep,
NBitStep,
},
context::{
Context, SemiHonestContext, UpgradableContext, UpgradedSemiHonestContext, Validator,
Expand All @@ -38,8 +38,8 @@ use crate::{
},
prf_sharding::step::{
AttributionPerRowStep as PerRowStep, AttributionStep as Step,
AttributionWindowStep as WindowStep, AttributionZeroTriggerStep as ZeroStep,
UserNthRowStep,
AttributionWindowStep as WindowStep,
AttributionZeroOutTriggerStep as ZeroOutTriggerStep, UserNthRowStep,
},
AGG_CHUNK,
},
Expand Down Expand Up @@ -197,7 +197,7 @@ where
.await?;

assert!(
TV::BITS <= EightBitStep::max_bit_depth(),
TV::BITS <= EightBitStep::BITS,
"EightBitStep not large enough to accomodate this sum"
);
let (updated_sum, overflow_bit) = integer_add::<_, EightBitStep, 1>(
Expand All @@ -209,7 +209,7 @@ where
.await?;

assert!(
TV::BITS <= EightBitStep::max_bit_depth(),
TV::BITS <= EightBitStep::BITS,
"EightBitStep not large enough to accomodate this subtraction"
);
let (overflow_bit_and_prev_row_not_saturated, difference_to_cap) = try_join(
Expand Down Expand Up @@ -616,11 +616,11 @@ where
let (did_trigger_get_attributed, is_trigger_within_window) = try_join(
is_trigger_bit.multiply(
ever_encountered_a_source_event,
ctx.narrow(&ZeroStep::DidTriggerGetAttributed),
ctx.narrow(&ZeroOutTriggerStep::DidTriggerGetAttributed),
record_id,
),
is_trigger_event_within_attribution_window(
ctx.narrow(&ZeroStep::CheckAttributionWindow),
ctx.narrow(&ZeroOutTriggerStep::CheckAttributionWindow),
record_id,
attribution_window_seconds,
trigger_event_timestamp,
Expand All @@ -631,7 +631,7 @@ where

// save 1 multiplication if there is no attribution window
let zero_out_flag = if attribution_window_seconds.is_some() {
let c = ctx.narrow(&ZeroStep::AttributedEventCheckFlag);
let c = ctx.narrow(&ZeroOutTriggerStep::AttributedEventCheckFlag);
did_trigger_get_attributed
.multiply(&is_trigger_within_window, c, record_id)
.await?
Expand Down Expand Up @@ -667,7 +667,7 @@ where
{
if let Some(attribution_window_seconds) = attribution_window_seconds {
assert!(
TS::BITS <= ThirtyTwoBitStep::max_bit_depth(),
TS::BITS <= ThirtyTwoBitStep::BITS,
"ThirtyTwoBitStep is not large enough to accomodate this subtraction"
);
let time_delta_bits = integer_sub::<_, ThirtyTwoBitStep>(
Expand Down
5 changes: 2 additions & 3 deletions ipa-core/src/protocol/ipa_prf/prf_sharding/step.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ pub(crate) enum AttributionStep {
#[step(child = UserNthRowStep)]
BinaryValidator,
PrimeFieldValidator,
ModulusConvertBreakdownKeyBitsAndTriggerValues,
#[step(child = crate::protocol::ipa_prf::aggregation::step::AggregationStep)]
Aggregate,
}
Expand All @@ -26,7 +25,7 @@ pub(crate) enum AttributionStep {
pub(crate) enum AttributionPerRowStep {
EverEncounteredSourceEvent,
AttributedBreakdownKey,
#[step(child = AttributionZeroTriggerStep)]
#[step(child = AttributionZeroOutTriggerStep)]
AttributedTriggerValue,
SourceEventTimestamp,
#[step(child = crate::protocol::boolean::step::EightBitStep)]
Expand All @@ -39,7 +38,7 @@ pub(crate) enum AttributionPerRowStep {
}

#[derive(CompactStep)]
pub(crate) enum AttributionZeroTriggerStep {
pub(crate) enum AttributionZeroOutTriggerStep {
DidTriggerGetAttributed,
#[step(child = AttributionWindowStep)]
CheckAttributionWindow,
Expand Down
4 changes: 2 additions & 2 deletions ipa-core/src/protocol/ipa_prf/quicksort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::{
helpers::stream::{process_stream_by_chunks, ChunkBuffer, TryFlattenItersExt},
protocol::{
basics::Reveal,
boolean::{step::ThirtyTwoBitStep, BitStep},
boolean::{step::ThirtyTwoBitStep, NBitStep},
context::{Context, SemiHonestContext},
ipa_prf::{
boolean_ops::comparison_and_subtraction_sequential::compare_gt,
Expand Down Expand Up @@ -168,7 +168,7 @@ where
.map(Ok);

assert!(
K::BITS <= ThirtyTwoBitStep::max_bit_depth(),
K::BITS <= ThirtyTwoBitStep::BITS,
"ThirtyTwoBitStep is not large enough to accommodate this sort"
);
let comp: BitVec<usize, Lsb0> = seq_join(
Expand Down
Loading

0 comments on commit 7d76bd8

Please sign in to comment.