Skip to content

Commit

Permalink
Merge pull request #1420 from eriktaubeneck/reshard-by-oprf
Browse files Browse the repository at this point in the history
Reshard by oprf
  • Loading branch information
akoshelev authored Nov 9, 2024
2 parents 1f330ac + fce7b13 commit 3655716
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 12 deletions.
16 changes: 13 additions & 3 deletions ipa-core/src/protocol/hybrid/mod.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
pub(crate) mod oprf;
pub(crate) mod step;

use oprf::PRFIndistinguishableHybridReport;

use crate::{
error::Error,
ff::{
boolean::Boolean, boolean_array::BooleanArray, curve_points::RP25519,
ec_prime_field::Fp25519, U128Conversions,
ec_prime_field::Fp25519, Serializable, U128Conversions,
},
helpers::query::DpMechanism,
protocol::{
basics::{BooleanProtocols, Reveal},
context::{DZKPUpgraded, MacUpgraded, ShardedContext, UpgradableContext},
context::{reshard_iter, DZKPUpgraded, MacUpgraded, ShardedContext, UpgradableContext},
hybrid::{
oprf::{compute_prf_for_inputs, BreakdownKey, CONV_CHUNK, PRF_CHUNK},
step::HybridStep as Step,
Expand Down Expand Up @@ -67,6 +69,7 @@ where
PrfSharing<MacUpgraded<C, Fp25519>, PRF_CHUNK, Field = Fp25519> + FromPrss,
Replicated<RP25519, PRF_CHUNK>:
Reveal<MacUpgraded<C, Fp25519>, Output = <RP25519 as Vectorizable<PRF_CHUNK>>::Array>,
PRFIndistinguishableHybridReport<BK, V>: Serializable,
{
if input_rows.is_empty() {
return Ok(vec![Replicated::ZERO; B]);
Expand All @@ -83,7 +86,14 @@ where
// TODO shuffle input rows
let shuffled_input_rows = padded_input_rows;

let _prf_input_rows = compute_prf_for_inputs(ctx.clone(), &shuffled_input_rows).await?;
let prf_input_rows_stream = compute_prf_for_inputs(ctx.clone(), &shuffled_input_rows).await?;

let _sharded_prf_rows = reshard_iter(
ctx.narrow(&Step::ReshardByPrf),
prf_input_rows_stream,
|ctx, _, report| report.shard_picker(ctx.shard_count()),
)
.await?;

unimplemented!("protocol::hybrid::hybrid_protocol is not fully implemented")
}
104 changes: 95 additions & 9 deletions ipa-core/src/protocol/hybrid/oprf.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
use std::iter::zip;

use futures::{stream, StreamExt, TryStreamExt};
use typenum::Const;
use generic_array::GenericArray;
use typenum::{Const, Unsigned, U12};

use crate::{
error::{Error, UnwrapInfallible},
ff::{
boolean::Boolean,
boolean_array::{BooleanArray, BA5, BA64, BA8},
boolean_array::{BooleanArray, BA3, BA5, BA64, BA8},
curve_points::RP25519,
ec_prime_field::Fp25519,
U128Conversions,
Serializable, U128Conversions,
},
helpers::{
stream::{div_round_up, process_slice_by_chunks, Chunk, ChunkData, TryFlattenItersExt},
Expand All @@ -30,12 +31,13 @@ use crate::{
prss::{FromPrss, SharedRandomness},
RecordId,
},
report::hybrid::IndistinguishableHybridReport,
report::hybrid::{IndistinguishableHybridReport, InvalidHybridReportError},
secret_sharing::{
replicated::semi_honest::AdditiveShare as Replicated, BitDecomposed, SharedValue,
TransposeFrom, Vectorizable,
},
seq_join::seq_join,
sharding::ShardIndex,
};

// In theory, we could support (runtime-configured breakdown count) ≤ (compile-time breakdown count)
Expand All @@ -54,7 +56,6 @@ use crate::{

// These could be imported from src/protocl/ipa_prf/mod.rs
// however we've copy/pasted them here with the intention of deleting that file [TODO]

pub trait BreakdownKey<const MAX_BREAKDOWNS: usize>: BooleanArray + U128Conversions {}
impl BreakdownKey<32> for BA5 {}
impl BreakdownKey<256> for BA8 {}
Expand All @@ -76,14 +77,72 @@ pub const PRF_CHUNK: usize = 16;
// multiplications per batch
const CONV_PROOF_CHUNK: usize = 256;

#[derive(Default, Debug)]
#[allow(dead_code)] // needed to mute warning until used in future PRs
#[derive(Debug, Clone, PartialEq)]
pub struct PRFIndistinguishableHybridReport<BK: SharedValue, V: SharedValue> {
// `prf_of_match_key` needs to be a u64 for serialization and deserialization to work
prf_of_match_key: u64,
value: Replicated<V>,
breakdown_key: Replicated<BK>,
}

impl<BK, V> PRFIndistinguishableHybridReport<BK, V>
where
BK: SharedValue,
V: SharedValue,
{
/// ## Panics
/// it doesn't. `ShardIndex` is a u32, expanded into a u64 for the mod operation
/// `prf_of_match_key % shard_count` will always fit into a `ShardIndex`
#[must_use]
pub fn shard_picker(&self, shard_count: ShardIndex) -> ShardIndex {
let shard_count = u64::from(shard_count);
ShardIndex::try_from(self.prf_of_match_key % shard_count)
.expect("Modulo a u32 will fit in a u32")
}
}

impl PRFIndistinguishableHybridReport<BA8, BA3> {
const PRF_MK_SZ: usize = 8;
const V_SZ: usize = <Replicated<BA3> as Serializable>::Size::USIZE;
const BK_SZ: usize = <Replicated<BA8> as Serializable>::Size::USIZE;
}

impl Serializable for PRFIndistinguishableHybridReport<BA8, BA3> {
type Size = U12;
type DeserializationError = InvalidHybridReportError;

fn serialize(&self, buf: &mut GenericArray<u8, Self::Size>) {
buf[..Self::PRF_MK_SZ].copy_from_slice(&self.prf_of_match_key.to_le_bytes());

self.value.serialize(GenericArray::from_mut_slice(
&mut buf[Self::PRF_MK_SZ..Self::PRF_MK_SZ + Self::V_SZ],
));

self.breakdown_key.serialize(GenericArray::from_mut_slice(
&mut buf[Self::PRF_MK_SZ + Self::V_SZ..Self::PRF_MK_SZ + Self::V_SZ + Self::BK_SZ],
));
}

fn deserialize(buf: &GenericArray<u8, Self::Size>) -> Result<Self, Self::DeserializationError> {
let prf_of_match_key = u64::from_le_bytes(buf[..Self::PRF_MK_SZ].try_into().unwrap());

let value = Replicated::<BA3>::deserialize(GenericArray::from_slice(
&buf[Self::PRF_MK_SZ..Self::PRF_MK_SZ + Self::V_SZ],
))
.map_err(|e| InvalidHybridReportError::DeserializationError("value", e.into()))?;

let breakdown_key = Replicated::<BA8>::deserialize_infallible(GenericArray::from_slice(
&buf[Self::PRF_MK_SZ + Self::V_SZ..Self::PRF_MK_SZ + Self::V_SZ + Self::BK_SZ],
));

Ok(Self {
prf_of_match_key,
value,
breakdown_key,
})
}
}

#[tracing::instrument(name = "compute_prf_for_inputs", skip_all)]
pub async fn compute_prf_for_inputs<C, BK, V>(
ctx: C,
Expand Down Expand Up @@ -181,13 +240,22 @@ where

#[cfg(all(test, unit_test, feature = "in-memory-infra"))]
mod test {
use generic_array::GenericArray;
use ipa_step::StepNarrow;
use rand::Rng;

use super::PRFIndistinguishableHybridReport;
use crate::{
ff::boolean_array::{BA3, BA8},
ff::{
boolean_array::{BA3, BA8},
Serializable,
},
protocol::{hybrid::oprf::compute_prf_for_inputs, step::ProtocolStep, Gate},
report::hybrid::{HybridReport, IndistinguishableHybridReport},
secret_sharing::IntoShares,
secret_sharing::{
replicated::{semi_honest::AdditiveShare as Replicated, ReplicatedSecretSharing},
IntoShares,
},
test_executor::run,
test_fixture::{
hybrid::TestHybridRecord, RoundRobinInputDistribution, TestWorld, TestWorldConfig,
Expand Down Expand Up @@ -293,4 +361,22 @@ mod test {
}
});
}

#[test]
fn prf_indistinguishable_serialize_deserialize() {
run(|| async {
let world = TestWorld::default();
let mut rng = world.rng();
let report = PRFIndistinguishableHybridReport::<BA8, BA3> {
prf_of_match_key: rng.gen(),
breakdown_key: Replicated::new(rng.gen(), rng.gen()),
value: Replicated::new(rng.gen(), rng.gen()),
};
let mut buf = GenericArray::default();
report.serialize(&mut buf);
let deserialized_report =
PRFIndistinguishableHybridReport::<BA8, BA3>::deserialize(&buf);
assert_eq!(report, deserialized_report.unwrap());
});
}
}
1 change: 1 addition & 0 deletions ipa-core/src/protocol/hybrid/step.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ pub(crate) enum HybridStep {
PrfKeyGen,
#[step(child = crate::protocol::context::step::MaliciousProtocolStep)]
EvalPrf,
ReshardByPrf,
}
8 changes: 8 additions & 0 deletions ipa-core/src/sharding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,14 @@ impl TryFrom<usize> for ShardIndex {
}
}

impl TryFrom<u64> for ShardIndex {
type Error = TryFromIntError;

fn try_from(value: u64) -> Result<Self, Self::Error> {
u32::try_from(value).map(Self)
}
}

impl TryFrom<u128> for ShardIndex {
type Error = TryFromIntError;

Expand Down

0 comments on commit 3655716

Please sign in to comment.