Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compact gate follow up #1078

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
9 changes: 0 additions & 9 deletions ipa-core/src/protocol/basics/mul/step.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,3 @@ pub(crate) enum MaliciousMultiplyStep {
RandomnessForValidation,
ReshareRx,
}

// This is a dummy step that is used to narrow (but never executed) the semi-honest
// context. Semi-honest implementations of `UpgradedContext::upgrade()` and subsequent
// `UpgradeToMalicious::upgrade()` narrows but these will end up in
// `UpgradedContext::upgrade_one()` or `UpgradedContext::upgrade_sparse()` which both
// return Ok() and never trigger communications.
#[derive(CompactStep)]
#[step(name = "upgrade")]
pub(crate) struct UpgradeStep;
16 changes: 10 additions & 6 deletions ipa-core/src/protocol/boolean/and.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@
use crate::{
error::Error,
ff::boolean::Boolean,
protocol::{basics::SecureMul, boolean::step::BoolAndStep, context::Context, RecordId},
protocol::{
basics::SecureMul,
boolean::{step::EightBitStep, NBitStep},
context::Context,
RecordId,
},
secret_sharing::{replicated::semi_honest::AdditiveShare, BitDecomposed, FieldSimd},
};

const MAX_BITS: usize = 8;

/// Matrix bitwise AND for use with vectors of bit-decomposed values. Supports up to 8 bits of input
/// that is enough to support both WALR and PRF IPA use cases.
///
Expand Down Expand Up @@ -40,14 +43,15 @@
let b = b.into_iter();
assert_eq!(a.len(), b.len());
assert!(
a.len() <= MAX_BITS,
"Up to {MAX_BITS} values are supported, but was given a value of {len} bits",
a.len() <= usize::try_from(EightBitStep::BITS).unwrap(),
"Up to {max_bits} bit values are supported, but was given a value of {len} bits",
max_bits = EightBitStep::BITS,

Check warning on line 48 in ipa-core/src/protocol/boolean/and.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/protocol/boolean/and.rs#L47-L48

Added lines #L47 - L48 were not covered by tests
len = a.len()
);

BitDecomposed::try_from(
ctx.parallel_join(zip(a.iter(), b).enumerate().map(|(i, (a, b))| {
let ctx = ctx.narrow(&BoolAndStep::Bit(i));
let ctx = ctx.narrow(&EightBitStep::Bit(i));
a.multiply(b, ctx, record_id)
}))
.await?,
Expand Down
34 changes: 12 additions & 22 deletions ipa-core/src/protocol/boolean/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,37 +17,27 @@ pub(crate) mod step;
///
/// This is a temporary solution for narrowing contexts until the infra is
/// updated with a new step scheme.
pub trait BitStep: Step + From<usize> {
fn max_bit_depth() -> u32;
pub trait NBitStep: Step + From<usize> {
const BITS: u32;
}

impl BitStep for EightBitStep {
fn max_bit_depth() -> u32 {
8
}
impl NBitStep for EightBitStep {
const BITS: u32 = 8;
}

impl BitStep for SixteenBitStep {
fn max_bit_depth() -> u32 {
16
}
impl NBitStep for SixteenBitStep {
const BITS: u32 = 16;
}

impl BitStep for ThirtyTwoBitStep {
fn max_bit_depth() -> u32 {
32
}
impl NBitStep for ThirtyTwoBitStep {
const BITS: u32 = 32;
}

impl BitStep for TwoHundredFiftySixBitOpStep {
fn max_bit_depth() -> u32 {
256
}
impl NBitStep for TwoHundredFiftySixBitOpStep {
const BITS: u32 = 256;
}

#[cfg(test)]
impl BitStep for crate::protocol::boolean::step::DefaultBitStep {
fn max_bit_depth() -> u32 {
256
}
impl NBitStep for crate::protocol::boolean::step::DefaultBitStep {
const BITS: u32 = 256;
}
6 changes: 0 additions & 6 deletions ipa-core/src/protocol/boolean/step.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
use ipa_step_derive::CompactStep;

#[derive(CompactStep)]
pub(crate) enum BoolAndStep {
#[step(count = 8)] // keep in sync with MAX_BITS defined inside and.rs
Bit(usize),
}

#[derive(CompactStep)]
pub enum EightBitStep {
#[step(count = 8)]
Expand Down
8 changes: 4 additions & 4 deletions ipa-core/src/protocol/ipa_prf/aggregation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::{
},
protocol::{
basics::{BooleanArrayMul, BooleanProtocols},
boolean::{step::SixteenBitStep, BitStep},
boolean::{step::SixteenBitStep, NBitStep},
context::{Context, UpgradedSemiHonestContext},
ipa_prf::{
aggregation::step::{AggregateValuesStep, AggregationStep as Step},
Expand Down Expand Up @@ -266,12 +266,12 @@ where
let record_id = RecordId::from(i);
if a.len() < usize::try_from(OV::BITS).unwrap() {
assert!(
OV::BITS <= SixteenBitStep::max_bit_depth(),
OV::BITS <= SixteenBitStep::BITS,
"SixteenBitStep not large enough to accomodate this sum"
);
// If we have enough output bits, add and keep the carry.
let (mut sum, carry) = integer_add::<_, SixteenBitStep, B>(
ctx.narrow(&AggregateValuesStep::OverflowingAdd),
ctx.narrow(&AggregateValuesStep::Add),
record_id,
&a,
&b,
Expand All @@ -281,7 +281,7 @@ where
Ok(sum)
} else {
assert!(
OV::BITS <= SixteenBitStep::max_bit_depth(),
OV::BITS <= SixteenBitStep::BITS,
"SixteenBitStep not large enough to accomodate this sum"
);
integer_sat_add::<_, SixteenBitStep, B>(
Expand Down
17 changes: 4 additions & 13 deletions ipa-core/src/protocol/ipa_prf/aggregation/step.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,19 @@ pub(crate) enum AggregationStep {
}

#[derive(CompactStep)]
pub enum BucketStep {
/// should be equal to `MAX_BREAKDOWNS`
#[step(count = 512, child = crate::protocol::boolean::step::BoolAndStep)]
Bit(usize),
}

impl From<u32> for BucketStep {
fn from(v: u32) -> Self {
Self::Bit(usize::try_from(v).unwrap())
}
}
#[step(count = 512, child = crate::protocol::boolean::step::EightBitStep, name = "b")]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should keep the comment about keeping this in sync with MAX_BREAKDOWNS.

pub struct BucketStep(usize);

impl From<usize> for BucketStep {
fn from(v: usize) -> Self {
Self::Bit(v)
Self(v)
}
}

#[derive(CompactStep)]
pub(crate) enum AggregateValuesStep {
#[step(child = crate::protocol::boolean::step::SixteenBitStep)]
OverflowingAdd,
Add,
#[step(child = crate::protocol::ipa_prf::boolean_ops::step::SaturatedAdditionStep)]
SaturatingAdd,
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
helpers::repeat_n,
protocol::{
basics::{BooleanProtocols, SecureMul},
boolean::{or::bool_or, BitStep},
boolean::{or::bool_or, NBitStep},
context::{Context, UpgradedSemiHonestContext},
Gate, RecordId,
},
Expand Down Expand Up @@ -38,7 +38,7 @@ pub async fn integer_add<C, S, const N: usize>(
>
where
C: Context,
S: BitStep,
S: NBitStep,
Boolean: FieldSimd<N>,
AdditiveShare<Boolean, N>: BooleanProtocols<C, N>,
Gate: StepNarrow<S>,
Expand All @@ -61,7 +61,7 @@ pub async fn integer_sat_add<'a, SH, S, const N: usize>(
) -> Result<BitDecomposed<AdditiveShare<Boolean, N>>, Error>
where
SH: ShardBinding,
S: BitStep,
S: NBitStep,
Boolean: FieldSimd<N>,
AdditiveShare<Boolean, N>: BooleanProtocols<UpgradedSemiHonestContext<'a, SH, Boolean>, N>,
Gate: StepNarrow<S>,
Expand Down Expand Up @@ -100,7 +100,7 @@ async fn addition_circuit<C, S, const N: usize>(
) -> Result<BitDecomposed<AdditiveShare<Boolean, N>>, Error>
where
C: Context,
S: BitStep,
S: NBitStep,
Boolean: FieldSimd<N>,
AdditiveShare<Boolean, N>: BooleanProtocols<C, N>,
Gate: StepNarrow<S>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::{
ff::{boolean::Boolean, ArrayAccessRef, CustomArray, Field},
protocol::{
basics::{select, BooleanArrayMul, BooleanProtocols, SecureMul, ShareKnownValue},
boolean::BitStep,
boolean::NBitStep,
context::{Context, SemiHonestContext},
Gate, RecordId,
},
Expand All @@ -36,7 +36,7 @@ pub async fn compare_geq<C, S>(
) -> Result<AdditiveShare<Boolean>, Error>
where
C: Context,
S: BitStep,
S: NBitStep,
AdditiveShare<Boolean>: BooleanProtocols<C>,
Gate: StepNarrow<S>,
{
Expand All @@ -60,7 +60,7 @@ pub async fn compare_gt<C, S, const N: usize>(
) -> Result<AdditiveShare<Boolean, N>, Error>
where
C: Context,
S: BitStep,
S: NBitStep,
Boolean: FieldSimd<N>,
AdditiveShare<Boolean, N>: BooleanProtocols<C, N>,
Gate: StepNarrow<S>,
Expand All @@ -85,7 +85,7 @@ pub async fn integer_sub<C, S>(
) -> Result<BitDecomposed<AdditiveShare<Boolean>>, Error>
where
C: Context,
S: BitStep,
S: NBitStep,
AdditiveShare<Boolean>: BooleanProtocols<C>,
Gate: StepNarrow<S>,
{
Expand All @@ -108,7 +108,7 @@ pub async fn integer_sat_sub<S, St>(
) -> Result<AdditiveShare<S>, Error>
where
S: SharedValue + CustomArray<Element = Boolean>,
St: BitStep,
St: NBitStep,
for<'a> AdditiveShare<S>: BooleanArrayMul<SemiHonestContext<'a>>,
Gate: StepNarrow<St>,
{
Expand Down Expand Up @@ -154,7 +154,7 @@ async fn subtraction_circuit<C, S, const N: usize>(
) -> Result<BitDecomposed<AdditiveShare<Boolean, N>>, Error>
where
C: Context,
S: BitStep,
S: NBitStep,
Boolean: FieldSimd<N>,
AdditiveShare<Boolean, N>: BooleanProtocols<C, N>,
Gate: StepNarrow<S>,
Expand Down
2 changes: 1 addition & 1 deletion ipa-core/src/protocol/ipa_prf/boolean_ops/step.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use ipa_step_derive::CompactStep;

/// FIXME: This step is not generic enough to be used in the `saturated_addition` protocol.
/// It constraints the input to be at most 2 bytes and it will panic in runtime if it is greater
/// It constrains the input to be at most 2 bytes and it will panic in runtime if it is greater
/// than that. The issue is that compact gate requires concrete type to be put as child.
/// If we ever see it being an issue, we should make a few implementations of this similar to what
/// we've done for bit steps
Expand Down
18 changes: 9 additions & 9 deletions ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use crate::{
boolean::{
or::or,
step::{EightBitStep, ThirtyTwoBitStep},
BitStep,
NBitStep,
},
context::{
Context, SemiHonestContext, UpgradableContext, UpgradedSemiHonestContext, Validator,
Expand All @@ -38,8 +38,8 @@ use crate::{
},
prf_sharding::step::{
AttributionPerRowStep as PerRowStep, AttributionStep as Step,
AttributionWindowStep as WindowStep, AttributionZeroTriggerStep as ZeroStep,
UserNthRowStep,
AttributionWindowStep as WindowStep,
AttributionZeroOutTriggerStep as ZeroOutTriggerStep, UserNthRowStep,
},
AGG_CHUNK,
},
Expand Down Expand Up @@ -197,7 +197,7 @@ where
.await?;

assert!(
TV::BITS <= EightBitStep::max_bit_depth(),
TV::BITS <= EightBitStep::BITS,
"EightBitStep not large enough to accomodate this sum"
);
let (updated_sum, overflow_bit) = integer_add::<_, EightBitStep, 1>(
Expand All @@ -209,7 +209,7 @@ where
.await?;

assert!(
TV::BITS <= EightBitStep::max_bit_depth(),
TV::BITS <= EightBitStep::BITS,
"EightBitStep not large enough to accomodate this subtraction"
);
let (overflow_bit_and_prev_row_not_saturated, difference_to_cap) = try_join(
Expand Down Expand Up @@ -616,11 +616,11 @@ where
let (did_trigger_get_attributed, is_trigger_within_window) = try_join(
is_trigger_bit.multiply(
ever_encountered_a_source_event,
ctx.narrow(&ZeroStep::DidTriggerGetAttributed),
ctx.narrow(&ZeroOutTriggerStep::DidTriggerGetAttributed),
record_id,
),
is_trigger_event_within_attribution_window(
ctx.narrow(&ZeroStep::CheckAttributionWindow),
ctx.narrow(&ZeroOutTriggerStep::CheckAttributionWindow),
record_id,
attribution_window_seconds,
trigger_event_timestamp,
Expand All @@ -631,7 +631,7 @@ where

// save 1 multiplication if there is no attribution window
let zero_out_flag = if attribution_window_seconds.is_some() {
let c = ctx.narrow(&ZeroStep::AttributedEventCheckFlag);
let c = ctx.narrow(&ZeroOutTriggerStep::AttributedEventCheckFlag);
did_trigger_get_attributed
.multiply(&is_trigger_within_window, c, record_id)
.await?
Expand Down Expand Up @@ -667,7 +667,7 @@ where
{
if let Some(attribution_window_seconds) = attribution_window_seconds {
assert!(
TS::BITS <= ThirtyTwoBitStep::max_bit_depth(),
TS::BITS <= ThirtyTwoBitStep::BITS,
"ThirtyTwoBitStep is not large enough to accomodate this subtraction"
);
let time_delta_bits = integer_sub::<_, ThirtyTwoBitStep>(
Expand Down
5 changes: 2 additions & 3 deletions ipa-core/src/protocol/ipa_prf/prf_sharding/step.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ pub(crate) enum AttributionStep {
#[step(child = UserNthRowStep)]
BinaryValidator,
PrimeFieldValidator,
ModulusConvertBreakdownKeyBitsAndTriggerValues,
#[step(child = crate::protocol::ipa_prf::aggregation::step::AggregationStep)]
Aggregate,
}
Expand All @@ -26,7 +25,7 @@ pub(crate) enum AttributionStep {
pub(crate) enum AttributionPerRowStep {
EverEncounteredSourceEvent,
AttributedBreakdownKey,
#[step(child = AttributionZeroTriggerStep)]
#[step(child = AttributionZeroOutTriggerStep)]
AttributedTriggerValue,
SourceEventTimestamp,
#[step(child = crate::protocol::boolean::step::EightBitStep)]
Expand All @@ -39,7 +38,7 @@ pub(crate) enum AttributionPerRowStep {
}

#[derive(CompactStep)]
pub(crate) enum AttributionZeroTriggerStep {
pub(crate) enum AttributionZeroOutTriggerStep {
DidTriggerGetAttributed,
#[step(child = AttributionWindowStep)]
CheckAttributionWindow,
Expand Down
4 changes: 2 additions & 2 deletions ipa-core/src/protocol/ipa_prf/quicksort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::{
helpers::stream::{process_stream_by_chunks, ChunkBuffer, TryFlattenItersExt},
protocol::{
basics::Reveal,
boolean::{step::ThirtyTwoBitStep, BitStep},
boolean::{step::ThirtyTwoBitStep, NBitStep},
context::{Context, SemiHonestContext},
ipa_prf::{
boolean_ops::comparison_and_subtraction_sequential::compare_gt,
Expand Down Expand Up @@ -168,7 +168,7 @@ where
.map(Ok);

assert!(
K::BITS <= ThirtyTwoBitStep::max_bit_depth(),
K::BITS <= ThirtyTwoBitStep::BITS,
"ThirtyTwoBitStep is not large enough to accommodate this sort"
);
let comp: BitVec<usize, Lsb0> = seq_join(
Expand Down
Loading