From 2f1c63b30e25c4a748d25a0a175e11d7eba8658b Mon Sep 17 00:00:00 2001
From: Ben Savage
Date: Thu, 28 Sep 2023 00:34:07 +0800
Subject: [PATCH] martinthomson's comments

---
 src/protocol/boolean/saturating_sum.rs |   6 +-
 src/protocol/prf_sharding/mod.rs       | 316 ++++++++++++-------------
 2 files changed, 158 insertions(+), 164 deletions(-)

diff --git a/src/protocol/boolean/saturating_sum.rs b/src/protocol/boolean/saturating_sum.rs
index da868d112..c3ae098b3 100644
--- a/src/protocol/boolean/saturating_sum.rs
+++ b/src/protocol/boolean/saturating_sum.rs
@@ -120,7 +120,7 @@ impl<S: LinearSecretSharing<Gf2>> SaturatingSum<S> {
     /// This can be computed "for free" in Gf2
     ///
     /// The `carry_out` bit can be efficiently computed with just a single multiplication as:
-    /// `c_(i+1) = c_i ⊕ ((x_i ⊕ c_i) & (y_i ⊕ c_i))`
     ///
     /// Returns `sum_bit`
     ///
@@ -185,11 +185,11 @@ where
     SB: LinearSecretSharing<Gf2> + BasicProtocols<C, Gf2>,
 {
     // compute difference bit as not_y XOR x XOR carry_in
-    let difference_bit = SB::share_known_value(&ctx, Gf2::ONE) - y + x + carry_in;
+    let difference_bit = SB::share_known_value(&ctx, Gf2::ONE) + y + x + carry_in;
     if compute_carry_out {
         let x_xor_carry_in = x.clone() + carry_in;
         let y_xor_carry_in = y.clone() + carry_in;
-        let not_y_xor_carry_in = SB::share_known_value(&ctx, Gf2::ONE) - &y_xor_carry_in;
+        let not_y_xor_carry_in = SB::share_known_value(&ctx, Gf2::ONE) + &y_xor_carry_in;

         *carry_in = x_xor_carry_in
             .multiply(&not_y_xor_carry_in, ctx, record_id)
diff --git a/src/protocol/prf_sharding/mod.rs b/src/protocol/prf_sharding/mod.rs
index b3815dc4e..ce2da899d 100644
--- a/src/protocol/prf_sharding/mod.rs
+++ b/src/protocol/prf_sharding/mod.rs
@@ -1,9 +1,9 @@
-use std::iter::repeat;
+use std::iter::{repeat, zip};

 use futures_util::future::try_join;
 use ipa_macros::Step;

-use super::{boolean::saturating_sum::SaturatingSum, step::BitOpStep};
+use super::{basics::if_else, boolean::saturating_sum::SaturatingSum, step::BitOpStep};
 use crate::{
     error::Error,
     ff::{Field, GaloisField, Gf2},
@@ -15,7 +15,7 @@ use crate::{
     },
     secret_sharing::{
         replicated::{semi_honest::AdditiveShare as Replicated, ReplicatedSecretSharing},
-        BitDecomposed,
+        BitDecomposed, SharedValue,
     },
     seq_join::seq_try_join_all,
 };
@@ -34,6 +34,135 @@ struct InputsRequiredFromPrevRow {
     difference_to_cap: BitDecomposed<Replicated<Gf2>>,
 }

+impl InputsRequiredFromPrevRow {
+    ///
+    /// This function contains the main logic for the per-user attribution circuit.
+    /// Multiple rows of data about a single user are processed in-order from oldest to newest.
+    ///
+    /// Summary:
+    /// - Last touch attribution
+    ///     - Every trigger event which is preceded by a source event is attributed
+    ///     - Trigger events are attributed to the `breakdown_key` of the most recent preceding source event
+    /// - Per user capping
+    ///     - A cumulative sum of "Attributed Trigger Value" is maintained
+    ///     - Bitwise addition is used, and a single bit indicates if the sum is "saturated"
+    ///     - The only available values for "cap" are powers of 2 (i.e. 1, 2, 4, 8, 16, 32, ...)
+    ///     - Prior to the cumulative sum reaching saturation, attributed trigger values are passed along
+    ///     - The row which puts the cumulative sum over the cap is "capped" to the delta between the cumulative sum of the last row and the cap
+    ///     - All subsequent rows contribute zero
+    /// - Outputs
+    ///     - If a user has `N` input rows, they will generate `N-1` output rows.
+    ///       (The first row cannot possibly contribute any value to the output)
+    ///     - Each output row has two main values:
+    ///         - `capped_attributed_trigger_value` - the value to contribute to the output (bitwise secret-shared),
+    ///         - `attributed_breakdown_key` - the breakdown to which this contribution applies (bitwise secret-shared),
+    ///     - Additional output:
+    ///         - `did_trigger_get_attributed` - a secret-shared bit indicating if this row corresponds to a trigger event
+    ///           which was attributed. Might be able to reveal this (after a shuffle and the addition of dummies) to minimize
+    ///           the amount of processing work that must be done in the Aggregation stage.
+    pub async fn compute_row_with_previous<C, BK, TV>(
+        &mut self,
+        ctx: C,
+        record_id: RecordId,
+        input_row: &PrfShardedIpaInputRow<BK, TV>,
+        num_saturating_sum_bits: usize,
+    ) -> Result<CappedAttributionOutputs, Error>
+    where
+        C: UpgradedContext<Gf2, Share = Replicated<Gf2>>,
+        BK: GaloisField,
+        TV: GaloisField,
+    {
+        let bd_key = BitDecomposed::decompose(BK::BITS, |i| {
+            input_row.breakdown_key.map(|v| Gf2::truncate_from(v[i]))
+        });
+        let tv = BitDecomposed::decompose(TV::BITS, |i| {
+            input_row.trigger_value.map(|v| Gf2::truncate_from(v[i]))
+        });
+        assert_eq!(self.saturating_sum.sum.len(), num_saturating_sum_bits);
+
+        let share_of_one = Replicated::share_known_value(&ctx, Gf2::ONE);
+        let is_source_event = &share_of_one - &input_row.is_trigger_bit;
+
+        let (ever_encountered_a_source_event, attributed_breakdown_key_bits) = try_join(
+            or(
+                ctx.narrow(&Step::EverEncounteredSourceEvent),
+                record_id,
+                &is_source_event,
+                &self.ever_encountered_a_source_event,
+            ),
+            breakdown_key_of_most_recent_source_event(
+                ctx.narrow(&Step::AttributedBreakdownKey),
+                record_id,
+                &input_row.is_trigger_bit,
+                &self.attributed_breakdown_key_bits,
+                &bd_key,
+            ),
+        )
+        .await?;
+
+        let did_trigger_get_attributed = input_row
+            .is_trigger_bit
+            .multiply(
+                &ever_encountered_a_source_event,
+                ctx.narrow(&Step::DidTriggerGetAttributed),
+                record_id,
+            )
+            .await?;
+
+        let attributed_trigger_value = zero_out_trigger_value_unless_attributed(
+            ctx.narrow(&Step::AttributedTriggerValue),
+            record_id,
+            &did_trigger_get_attributed,
+            &tv,
+        )
+        .await?;
+
+        let updated_sum = self
+            .saturating_sum
+            .add(
+                ctx.narrow(&Step::ComputeSaturatingSum),
+                record_id,
+                &attributed_trigger_value,
+            )
+            .await?;
+
+        let (is_saturated_and_prev_row_not_saturated, difference_to_cap) = try_join(
+            updated_sum.is_saturated.multiply(
+                &(share_of_one - &self.saturating_sum.is_saturated),
+                ctx.narrow(&Step::IsSaturatedAndPrevRowNotSaturated),
+                record_id,
+            ),
+            updated_sum.truncated_delta_to_saturation_point(
+                ctx.narrow(&Step::ComputeDifferenceToCap),
+                record_id,
+                TV::BITS,
+            ),
+        )
+        .await?;
+
+        let capped_attributed_trigger_value = compute_capped_trigger_value(
+            ctx,
+            record_id,
+            &updated_sum.is_saturated,
+            &is_saturated_and_prev_row_not_saturated,
+            &self.difference_to_cap,
+            &attributed_trigger_value,
+        )
+        .await?;
+
+        self.ever_encountered_a_source_event = ever_encountered_a_source_event;
+        self.attributed_breakdown_key_bits = attributed_breakdown_key_bits.clone();
+        self.saturating_sum = updated_sum;
+        self.difference_to_cap = difference_to_cap;
+
+        let outputs_for_aggregation = CappedAttributionOutputs {
+            did_trigger_get_attributed,
+            attributed_breakdown_key_bits,
+            capped_attributed_trigger_value,
+        };
+        Ok(outputs_for_aggregation)
+    }
+}
+
 #[derive(Debug)]
 pub struct CappedAttributionOutputs {
     pub did_trigger_get_attributed: Replicated<Gf2>,
@@ -201,7 +330,7 @@ where
 {
     assert!(!rows_for_user.is_empty());
     if rows_for_user.len() == 1 {
-        return Ok(vec![]);
+        return Ok(Vec::new());
     }
     let first_row = &rows_for_user[0];
     let mut prev_row_inputs = initialize_new_device_attribution_variables(
@@ -215,14 +344,14 @@ where
         let ctx_for_this_row_depth = ctx_for_row_number[i].clone(); // no context was created for row 0
         let record_id_for_this_row_depth = record_id_for_each_depth[i + 1]; // skip row 0

-        let capped_attribution_outputs = compute_row_with_previous(
-            ctx_for_this_row_depth,
-            record_id_for_this_row_depth,
-            row,
-            &mut prev_row_inputs,
-            num_saturating_sum_bits,
-        )
-        .await?;
+        let capped_attribution_outputs = prev_row_inputs
+            .compute_row_with_previous(
+                ctx_for_this_row_depth,
+                record_id_for_this_row_depth,
+                row,
+                num_saturating_sum_bits,
+            )
+            .await?;

         output.push(capped_attribution_outputs);
     }
@@ -284,12 +413,7 @@ where
             .enumerate()
             .map(|(i, (cur_bit, prev_bit))| {
                 let c = ctx.narrow(&BitOpStep::from(i));
-                async move {
-                    let maybe_diff = is_trigger_bit
-                        .multiply(&(prev_bit.clone() - cur_bit), c, record_id)
-                        .await?;
-                    Ok::<_, Error>(maybe_diff + cur_bit)
-                }
+                async move { if_else(c, record_id, is_trigger_bit, prev_bit, cur_bit).await }
             }),
         )
         .await?,
@@ -365,23 +489,27 @@ where
     let narrowed_ctx1 = ctx.narrow(&Step::ComputedCappedAttributedTriggerValueNotSaturatedCase);
     let narrowed_ctx2 = ctx.narrow(&Step::ComputedCappedAttributedTriggerValueJustSaturatedCase);
-    let one = &Replicated::share_known_value(&narrowed_ctx1, Gf2::ONE);
+    let zero = &Replicated::share_known_value(&narrowed_ctx1, Gf2::ZERO);

     Ok(BitDecomposed::new(
         ctx.parallel_join(
-            attributed_trigger_value
-                .iter()
-                .zip(prev_row_diff_to_cap.iter())
+            zip(attributed_trigger_value.iter(), prev_row_diff_to_cap.iter())
                 .enumerate()
                 .map(|(i, (bit, prev_bit))| {
                     let c1 = narrowed_ctx1.narrow(&BitOpStep::from(i));
                     let c2 = narrowed_ctx2.narrow(&BitOpStep::from(i));
                     async move {
-                        let not_saturated_case =
-                            (one - is_saturated).multiply(bit, c1, record_id).await?;
-                        let just_saturated_case = is_saturated_and_prev_row_not_saturated
-                            .multiply(prev_bit, c2, record_id)
-                            .await?;
+                        let (not_saturated_case, just_saturated_case) = try_join(
+                            if_else(c1, record_id, is_saturated, zero, bit),
+                            if_else(
+                                c2,
+                                record_id,
+                                is_saturated_and_prev_row_not_saturated,
+                                prev_bit,
+                                zero,
+                            ),
+                        )
+                        .await?;
                         Ok::<_, Error>(not_saturated_case + &just_saturated_case)
                     }
                 }),
@@ -390,140 +518,6 @@ where
     ))
 }

-///
-/// This function contains the main logic for the per-user attribution circuit.
-/// Multiple rows of data about a single user are processed in-order from oldest to newest.
-///
-/// Summary:
-/// - Last touch attribution
-///     - Every trigger event which is preceded by a source event is attributed
-///     - Trigger events are attributed to the `breakdown_key` of the most recent preceding source event
-/// - Per user capping
-///     - A cumulative sum of "Attributed Trigger Value" is maintained
-///     - Bitwise addition is used, and a single bit indicates if the sum is "saturated"
-///     - The only available values for "cap" are powers of 2 (i.e. 1, 2, 4, 8, 16, 32, ...)
-///     - Prior to the cumulative sum reaching saturation, attributed trigger values are passed along
-///     - The row which puts the cumulative sum over the cap is "capped" to the delta between the cumulative sum of the last row and the cap
-///     - All subsequent rows contribute zero
-/// - Outputs
-///     - If a user has `N` input rows, they will generate `N-1` output rows.
-///       (The first row cannot possibly contribute any value to the output)
-///     - Each output row has two main values:
-///         - `capped_attributed_trigger_value` - the value to contribute to the output (bitwise secret-shared),
-///         - `attributed_breakdown_key` - the breakdown to which this contribution applies (bitwise secret-shared),
-///     - Additional output:
-///         - `did_trigger_get_attributed` - a secret-shared bit indicating if this row corresponds to a trigger event
-///           which was attributed. Might be able to reveal this (after a shuffle and the addition of dummies) to minimize
-///           the amount of processing work that must be done in the Aggregation stage.
-async fn compute_row_with_previous<C, BK, TV>(
-    ctx: C,
-    record_id: RecordId,
-    input_row: &PrfShardedIpaInputRow<BK, TV>,
-    inputs_required_from_previous_row: &mut InputsRequiredFromPrevRow,
-    num_saturating_sum_bits: usize,
-) -> Result<CappedAttributionOutputs, Error>
-where
-    C: UpgradedContext<Gf2, Share = Replicated<Gf2>>,
-    BK: GaloisField,
-    TV: GaloisField,
-{
-    let bd_key = BitDecomposed::decompose(BK::BITS, |i| {
-        input_row.breakdown_key.map(|v| Gf2::truncate_from(v[i]))
-    });
-    let tv = BitDecomposed::decompose(TV::BITS, |i| {
-        input_row.trigger_value.map(|v| Gf2::truncate_from(v[i]))
-    });
-    assert_eq!(
-        inputs_required_from_previous_row.saturating_sum.sum.len(),
-        num_saturating_sum_bits
-    );
-
-    let share_of_one = Replicated::share_known_value(&ctx, Gf2::ONE);
-    let is_source_event = &share_of_one - &input_row.is_trigger_bit;
-
-    let (ever_encountered_a_source_event, attributed_breakdown_key_bits) = try_join(
-        or(
-            ctx.narrow(&Step::EverEncounteredSourceEvent),
-            record_id,
-            &is_source_event,
-            &inputs_required_from_previous_row.ever_encountered_a_source_event,
-        ),
-        breakdown_key_of_most_recent_source_event(
-            ctx.narrow(&Step::AttributedBreakdownKey),
-            record_id,
-            &input_row.is_trigger_bit,
-            &inputs_required_from_previous_row.attributed_breakdown_key_bits,
-            &bd_key,
-        ),
-    )
-    .await?;
-
-    let did_trigger_get_attributed = input_row
-        .is_trigger_bit
-        .multiply(
-            &ever_encountered_a_source_event,
-            ctx.narrow(&Step::DidTriggerGetAttributed),
-            record_id,
-        )
-        .await?;
-
-    let attributed_trigger_value = zero_out_trigger_value_unless_attributed(
-        ctx.narrow(&Step::AttributedTriggerValue),
-        record_id,
-        &did_trigger_get_attributed,
-        &tv,
-    )
-    .await?;
-
-    let updated_sum = inputs_required_from_previous_row
-        .saturating_sum
-        .add(
-            ctx.narrow(&Step::ComputeSaturatingSum),
-            record_id,
-            &attributed_trigger_value,
-        )
-        .await?;
-
-    let (is_saturated_and_prev_row_not_saturated, difference_to_cap) = try_join(
-        updated_sum.is_saturated.multiply(
-            &(share_of_one
-                - &inputs_required_from_previous_row
-                    .saturating_sum
-                    .is_saturated),
-            ctx.narrow(&Step::IsSaturatedAndPrevRowNotSaturated),
-            record_id,
-        ),
-        updated_sum.truncated_delta_to_saturation_point(
-            ctx.narrow(&Step::ComputeDifferenceToCap),
-            record_id,
-            TV::BITS,
-        ),
-    )
-    .await?;
-
-    let capped_attributed_trigger_value = compute_capped_trigger_value(
-        ctx,
-        record_id,
-        &updated_sum.is_saturated,
-        &is_saturated_and_prev_row_not_saturated,
-        &inputs_required_from_previous_row.difference_to_cap,
-        &attributed_trigger_value,
-    )
-    .await?;
-
-    *inputs_required_from_previous_row = InputsRequiredFromPrevRow {
-        ever_encountered_a_source_event,
-        attributed_breakdown_key_bits: BitDecomposed::new(attributed_breakdown_key_bits.clone()),
-        saturating_sum: updated_sum,
-        difference_to_cap,
-    };
-    let outputs_for_aggregation = CappedAttributionOutputs {
-        did_trigger_get_attributed,
-        attributed_breakdown_key_bits,
-        capped_attributed_trigger_value,
-    };
-    Ok(outputs_for_aggregation)
-}
-
 #[cfg(all(test, unit_test))]
 pub mod tests {
     use super::{attribution_and_capping, CappedAttributionOutputs, PrfShardedIpaInputRow};
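
The sign flips in the second saturating_sum.rs hunk (`- y` becomes `+ y`, and `- &y_xor_carry_in` becomes `+ &y_xor_carry_in`) are equivalent rewrites because Gf2 has characteristic 2: subtraction and addition are the same operation, so `ONE + y` still computes NOT y. Similarly, the carry recurrence documented in that file, `c_(i+1) = c_i ⊕ ((x_i ⊕ c_i) & (y_i ⊕ c_i))`, is the standard full-adder majority carry rearranged so that it costs only a single multiplication (AND) per bit. The following is a minimal standalone sketch over plain bools (illustrative only, not part of the patch and independent of the ipa crate) that exhaustively checks that identity:

    // Exhaustive check that the single-multiplication carry recurrence
    //     c' = c ^ ((x ^ c) & (y ^ c))
    // matches the usual majority carry (x & y) | (x & c) | (y & c) of a full adder.
    fn main() {
        for x in [false, true] {
            for y in [false, true] {
                for c in [false, true] {
                    let single_mult_carry = c ^ ((x ^ c) & (y ^ c));
                    let majority_carry = (x & y) | (x & c) | (y & c);
                    assert_eq!(single_mult_carry, majority_carry, "x={x} y={y} c={c}");
                }
            }
        }
        println!("carry identity holds for all 8 bit combinations");
    }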