From fce7b13813cf57aba8721bc807d5530d358f1270 Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Thu, 7 Nov 2024 10:48:16 -0800 Subject: [PATCH] implement Serializable and shard_picker for PRFIndistinguishableHybridReport, add reshard by oprf --- ipa-core/src/protocol/hybrid/mod.rs | 16 ++++- ipa-core/src/protocol/hybrid/oprf.rs | 104 ++++++++++++++++++++++++--- ipa-core/src/protocol/hybrid/step.rs | 1 + ipa-core/src/sharding.rs | 8 +++ 4 files changed, 117 insertions(+), 12 deletions(-) diff --git a/ipa-core/src/protocol/hybrid/mod.rs b/ipa-core/src/protocol/hybrid/mod.rs index f8aa1dcee..fae51e5c0 100644 --- a/ipa-core/src/protocol/hybrid/mod.rs +++ b/ipa-core/src/protocol/hybrid/mod.rs @@ -1,16 +1,18 @@ pub(crate) mod oprf; pub(crate) mod step; +use oprf::PRFIndistinguishableHybridReport; + use crate::{ error::Error, ff::{ boolean::Boolean, boolean_array::BooleanArray, curve_points::RP25519, - ec_prime_field::Fp25519, U128Conversions, + ec_prime_field::Fp25519, Serializable, U128Conversions, }, helpers::query::DpMechanism, protocol::{ basics::{BooleanProtocols, Reveal}, - context::{DZKPUpgraded, MacUpgraded, ShardedContext, UpgradableContext}, + context::{reshard_iter, DZKPUpgraded, MacUpgraded, ShardedContext, UpgradableContext}, hybrid::{ oprf::{compute_prf_for_inputs, BreakdownKey, CONV_CHUNK, PRF_CHUNK}, step::HybridStep as Step, @@ -67,6 +69,7 @@ where PrfSharing, PRF_CHUNK, Field = Fp25519> + FromPrss, Replicated: Reveal, Output = >::Array>, + PRFIndistinguishableHybridReport: Serializable, { if input_rows.is_empty() { return Ok(vec![Replicated::ZERO; B]); @@ -83,7 +86,14 @@ where // TODO shuffle input rows let shuffled_input_rows = padded_input_rows; - let _prf_input_rows = compute_prf_for_inputs(ctx.clone(), &shuffled_input_rows).await?; + let prf_input_rows_stream = compute_prf_for_inputs(ctx.clone(), &shuffled_input_rows).await?; + + let _sharded_prf_rows = reshard_iter( + ctx.narrow(&Step::ReshardByPrf), + prf_input_rows_stream, + |ctx, _, report| report.shard_picker(ctx.shard_count()), + ) + .await?; unimplemented!("protocol::hybrid::hybrid_protocol is not fully implemented") } diff --git a/ipa-core/src/protocol/hybrid/oprf.rs b/ipa-core/src/protocol/hybrid/oprf.rs index 074b41db0..4c3705371 100644 --- a/ipa-core/src/protocol/hybrid/oprf.rs +++ b/ipa-core/src/protocol/hybrid/oprf.rs @@ -1,16 +1,17 @@ use std::iter::zip; use futures::{stream, StreamExt, TryStreamExt}; -use typenum::Const; +use generic_array::GenericArray; +use typenum::{Const, Unsigned, U12}; use crate::{ error::{Error, UnwrapInfallible}, ff::{ boolean::Boolean, - boolean_array::{BooleanArray, BA5, BA64, BA8}, + boolean_array::{BooleanArray, BA3, BA5, BA64, BA8}, curve_points::RP25519, ec_prime_field::Fp25519, - U128Conversions, + Serializable, U128Conversions, }, helpers::{ stream::{div_round_up, process_slice_by_chunks, Chunk, ChunkData, TryFlattenItersExt}, @@ -30,12 +31,13 @@ use crate::{ prss::{FromPrss, SharedRandomness}, RecordId, }, - report::hybrid::IndistinguishableHybridReport, + report::hybrid::{IndistinguishableHybridReport, InvalidHybridReportError}, secret_sharing::{ replicated::semi_honest::AdditiveShare as Replicated, BitDecomposed, SharedValue, TransposeFrom, Vectorizable, }, seq_join::seq_join, + sharding::ShardIndex, }; // In theory, we could support (runtime-configured breakdown count) ≤ (compile-time breakdown count) @@ -54,7 +56,6 @@ use crate::{ // These could be imported from src/protocl/ipa_prf/mod.rs // however we've copy/pasted them here with the intention of deleting that file [TODO] - pub trait BreakdownKey: BooleanArray + U128Conversions {} impl BreakdownKey<32> for BA5 {} impl BreakdownKey<256> for BA8 {} @@ -76,14 +77,72 @@ pub const PRF_CHUNK: usize = 16; // multiplications per batch const CONV_PROOF_CHUNK: usize = 256; -#[derive(Default, Debug)] -#[allow(dead_code)] // needed to mute warning until used in future PRs +#[derive(Debug, Clone, PartialEq)] pub struct PRFIndistinguishableHybridReport { + // `prf_of_match_key` needs to be a u64 for serialization and deserialization to work prf_of_match_key: u64, value: Replicated, breakdown_key: Replicated, } +impl PRFIndistinguishableHybridReport +where + BK: SharedValue, + V: SharedValue, +{ + /// ## Panics + /// it doesn't. `ShardIndex` is a u32, expanded into a u64 for the mod operation + /// `prf_of_match_key % shard_count` will always fit into a `ShardIndex` + #[must_use] + pub fn shard_picker(&self, shard_count: ShardIndex) -> ShardIndex { + let shard_count = u64::from(shard_count); + ShardIndex::try_from(self.prf_of_match_key % shard_count) + .expect("Modulo a u32 will fit in a u32") + } +} + +impl PRFIndistinguishableHybridReport { + const PRF_MK_SZ: usize = 8; + const V_SZ: usize = as Serializable>::Size::USIZE; + const BK_SZ: usize = as Serializable>::Size::USIZE; +} + +impl Serializable for PRFIndistinguishableHybridReport { + type Size = U12; + type DeserializationError = InvalidHybridReportError; + + fn serialize(&self, buf: &mut GenericArray) { + buf[..Self::PRF_MK_SZ].copy_from_slice(&self.prf_of_match_key.to_le_bytes()); + + self.value.serialize(GenericArray::from_mut_slice( + &mut buf[Self::PRF_MK_SZ..Self::PRF_MK_SZ + Self::V_SZ], + )); + + self.breakdown_key.serialize(GenericArray::from_mut_slice( + &mut buf[Self::PRF_MK_SZ + Self::V_SZ..Self::PRF_MK_SZ + Self::V_SZ + Self::BK_SZ], + )); + } + + fn deserialize(buf: &GenericArray) -> Result { + let prf_of_match_key = u64::from_le_bytes(buf[..Self::PRF_MK_SZ].try_into().unwrap()); + + let value = Replicated::::deserialize(GenericArray::from_slice( + &buf[Self::PRF_MK_SZ..Self::PRF_MK_SZ + Self::V_SZ], + )) + .map_err(|e| InvalidHybridReportError::DeserializationError("value", e.into()))?; + + let breakdown_key = Replicated::::deserialize_infallible(GenericArray::from_slice( + &buf[Self::PRF_MK_SZ + Self::V_SZ..Self::PRF_MK_SZ + Self::V_SZ + Self::BK_SZ], + )); + + Ok(Self { + prf_of_match_key, + value, + breakdown_key, + }) + } +} + #[tracing::instrument(name = "compute_prf_for_inputs", skip_all)] pub async fn compute_prf_for_inputs( ctx: C, @@ -181,13 +240,22 @@ where #[cfg(all(test, unit_test, feature = "in-memory-infra"))] mod test { + use generic_array::GenericArray; use ipa_step::StepNarrow; + use rand::Rng; + use super::PRFIndistinguishableHybridReport; use crate::{ - ff::boolean_array::{BA3, BA8}, + ff::{ + boolean_array::{BA3, BA8}, + Serializable, + }, protocol::{hybrid::oprf::compute_prf_for_inputs, step::ProtocolStep, Gate}, report::hybrid::{HybridReport, IndistinguishableHybridReport}, - secret_sharing::IntoShares, + secret_sharing::{ + replicated::{semi_honest::AdditiveShare as Replicated, ReplicatedSecretSharing}, + IntoShares, + }, test_executor::run, test_fixture::{ hybrid::TestHybridRecord, RoundRobinInputDistribution, TestWorld, TestWorldConfig, @@ -293,4 +361,22 @@ mod test { } }); } + + #[test] + fn prf_indistinguishable_serialize_deserialize() { + run(|| async { + let world = TestWorld::default(); + let mut rng = world.rng(); + let report = PRFIndistinguishableHybridReport:: { + prf_of_match_key: rng.gen(), + breakdown_key: Replicated::new(rng.gen(), rng.gen()), + value: Replicated::new(rng.gen(), rng.gen()), + }; + let mut buf = GenericArray::default(); + report.serialize(&mut buf); + let deserialized_report = + PRFIndistinguishableHybridReport::::deserialize(&buf); + assert_eq!(report, deserialized_report.unwrap()); + }); + } } diff --git a/ipa-core/src/protocol/hybrid/step.rs b/ipa-core/src/protocol/hybrid/step.rs index d54e20145..44695797e 100644 --- a/ipa-core/src/protocol/hybrid/step.rs +++ b/ipa-core/src/protocol/hybrid/step.rs @@ -12,4 +12,5 @@ pub(crate) enum HybridStep { PrfKeyGen, #[step(child = crate::protocol::context::step::MaliciousProtocolStep)] EvalPrf, + ReshardByPrf, } diff --git a/ipa-core/src/sharding.rs b/ipa-core/src/sharding.rs index e3c76312b..180dea7d7 100644 --- a/ipa-core/src/sharding.rs +++ b/ipa-core/src/sharding.rs @@ -90,6 +90,14 @@ impl TryFrom for ShardIndex { } } +impl TryFrom for ShardIndex { + type Error = TryFromIntError; + + fn try_from(value: u64) -> Result { + u32::try_from(value).map(Self) + } +} + impl TryFrom for ShardIndex { type Error = TryFromIntError;