diff --git a/ipa-core/src/bin/report_collector.rs b/ipa-core/src/bin/report_collector.rs index 2fcbd28fa..0df54c325 100644 --- a/ipa-core/src/bin/report_collector.rs +++ b/ipa-core/src/bin/report_collector.rs @@ -290,11 +290,12 @@ async fn ipa( let mut key_registries = KeyRegistries::default(); let actual = match query_style { IpaQueryStyle::Oprf => { - playbook_oprf_ipa::( + playbook_oprf_ipa::( input_rows, &helper_clients, query_id, ipa_query_config, + key_registries.init_from(network), ) .await } diff --git a/ipa-core/src/cli/playbook/ipa.rs b/ipa-core/src/cli/playbook/ipa.rs index 8459d946e..eb9d96a2d 100644 --- a/ipa-core/src/cli/playbook/ipa.rs +++ b/ipa-core/src/cli/playbook/ipa.rs @@ -22,7 +22,10 @@ use crate::{ hpke::PublicKeyRegistry, ipa_test_input, net::MpcHelperClient, - protocol::{ipa::IPAInputRow, BreakdownKey, MatchKey, QueryId, Timestamp, TriggerValue}, + protocol::{ + ipa::IPAInputRow, ipa_prf::OPRFIPAInputRow, BreakdownKey, MatchKey, QueryId, Timestamp, + TriggerValue, + }, query::QueryStatus, report::{KeyIdentifier, OprfReport, Report}, secret_sharing::{replicated::semi_honest::AdditiveShare, IntoShares}, @@ -103,31 +106,59 @@ where run_query_and_validate::(inputs, query_size, clients, query_id, query_config).await } -pub async fn playbook_oprf_ipa( +pub async fn playbook_oprf_ipa( records: Vec, clients: &[MpcHelperClient; 3], query_id: QueryId, query_config: IpaQueryConfig, + encryption: Option<(KeyIdentifier, [&KR; 3])>, ) -> IpaQueryResult where F: PrimeField, AdditiveShare: Serializable, + KR: PublicKeyRegistry, { let mut buffers: [_; 3] = std::array::from_fn(|_| Vec::new()); let query_size = records.len(); - let sz = as Serializable>::Size::USIZE; - for buffer in &mut buffers { - buffer.resize(query_size * sz, 0u8); - } + if !query_config.plaintext_match_keys { + if let Some((key_id, key_registries)) = encryption { + const ESTIMATED_AVERAGE_REPORT_SIZE: usize = 80; // TODO: confirm/adjust + for buffer in &mut buffers { + buffer.reserve(query_size * ESTIMATED_AVERAGE_REPORT_SIZE); + } - let shares: [Vec>; 3] = - records.iter().cloned().share(); - zip(&mut buffers, shares).for_each(|(buf, shares)| { - for (share, chunk) in zip(shares, buf.chunks_mut(sz)) { - share.serialize(GenericArray::from_mut_slice(chunk)); + let mut rng = StdRng::from_entropy(); + let shares: [Vec>; 3] = + records.iter().cloned().share(); + zip(&mut buffers, shares).zip(key_registries).for_each( + |((buf, shares), key_registry)| { + for share in shares { + share + .delimited_encrypt_to(key_id, key_registry, &mut rng, buf) + .unwrap(); + } + }, + ); + } else { + panic!("match key encryption was requested, but one or more helpers is missing a public key") + } + } else { + let sz = + as Serializable>::Size::USIZE; + for buffer in &mut buffers { + buffer.resize(query_size * sz, 0u8); } - }); + + let shares: [Vec>; 3] = + records.iter().cloned().share(); + + zip(&mut buffers, shares).for_each(|(buf, shares)| { + for (share, chunk) in zip(shares, buf.chunks_mut(sz)) { + share.serialize(GenericArray::from_mut_slice(chunk)); + } + }); + } let inputs = buffers.map(BodyStream::from); tracing::info!("Starting query for OPRF"); diff --git a/ipa-core/src/hpke/mod.rs b/ipa-core/src/hpke/mod.rs index d2f47cd63..48d4c1303 100644 --- a/ipa-core/src/hpke/mod.rs +++ b/ipa-core/src/hpke/mod.rs @@ -29,9 +29,12 @@ type IpaKem = hpke::kem::X25519HkdfSha256; type IpaAead = hpke::aead::AesGcm128; type IpaKdf = hpke::kdf::HkdfSha256; +pub type EncapsulationSize = <::EncappedKey as Serializable>::OutputSize; +pub type 
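// Sketch (not part of the diff): a simplified model of the branch that
// `playbook_oprf_ipa` now takes when building the three per-helper buffers.
// When `plaintext_match_keys` is false, each share is encrypted and written
// with length-delimited framing (a little-endian u16 length prefix, as
// `delimited_encrypt_to` does); otherwise fixed-size plaintext shares are
// serialized back to back. `encrypt_share` is a stand-in for the real
// HPKE-based encryption, and the sizes are illustrative, not ipa-core's.
fn build_helper_buffer(
    shares: &[Vec<u8>], // one already-encoded share per record
    plaintext_match_keys: bool,
    encrypt_share: impl Fn(&[u8]) -> Vec<u8>,
) -> Vec<u8> {
    let mut buf = Vec::new();
    for share in shares {
        if plaintext_match_keys {
            // fixed-size records, no framing needed
            buf.extend_from_slice(share);
        } else {
            // length-delimited encrypted records
            let ct = encrypt_share(share);
            let len = u16::try_from(ct.len()).expect("record fits in u16");
            buf.extend_from_slice(&len.to_le_bytes());
            buf.extend_from_slice(&ct);
        }
    }
    buf
}

fn main() {
    let shares = vec![vec![1u8; 8], vec![2u8; 8]];
    // "encryption" stand-in: identity; the real code calls delimited_encrypt_to
    let buf = build_helper_buffer(&shares, false, |s: &[u8]| s.to_vec());
    assert_eq!(&buf[..2], 8u16.to_le_bytes().as_slice());
    assert_eq!(buf.len(), 2 * (2 + 8));
}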
TagSize = as Serializable>::OutputSize; + pub type IpaPublicKey = ::PublicKey; pub type IpaPrivateKey = ::PrivateKey; -pub type IpaEncappedKey = ::EncappedKey; +pub type IpaEncapsulatedKey = ::EncappedKey; pub use hpke::{Deserializable, Serializable}; diff --git a/ipa-core/src/hpke/registry.rs b/ipa-core/src/hpke/registry.rs index 338553a8d..87bbbc0f6 100644 --- a/ipa-core/src/hpke/registry.rs +++ b/ipa-core/src/hpke/registry.rs @@ -128,7 +128,7 @@ mod tests { use rand_core::{CryptoRng, RngCore, SeedableRng}; use super::*; - use crate::hpke::{IpaAead, IpaEncappedKey, IpaKdf, IpaKem}; + use crate::hpke::{IpaAead, IpaEncapsulatedKey, IpaKdf, IpaKem}; const INFO_STR: &[u8] = b"This is an INFO string."; const AAD: &[u8] = b"This is AAD."; @@ -137,7 +137,7 @@ mod tests { pk: &IpaPublicKey, pt: &[u8], r: &mut R, - ) -> (IpaEncappedKey, Vec) { + ) -> (IpaEncapsulatedKey, Vec) { let (encapsulated_key, mut encryption_context) = hpke::setup_sender::(&OpModeS::Base, pk, INFO_STR, r) .expect("Can setup the sender."); @@ -152,7 +152,7 @@ mod tests { fn decrypt>( sk: &IpaPrivateKey, - payload: &(IpaEncappedKey, I), + payload: &(IpaEncapsulatedKey, I), ) -> Result, HpkeError> { let (encap_key, ct) = payload; let mut decryption_context = hpke::setup_receiver::( diff --git a/ipa-core/src/protocol/ipa_prf/mod.rs b/ipa-core/src/protocol/ipa_prf/mod.rs index 47616d929..e7dc68f52 100644 --- a/ipa-core/src/protocol/ipa_prf/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/mod.rs @@ -1,10 +1,12 @@ -use std::num::NonZeroU32; +use std::{num::NonZeroU32, ops::Add}; +use generic_array::{ArrayLength, GenericArray}; use ipa_macros::Step; +use typenum::{Unsigned, U18}; use self::{quicksort::quicksort_ranges_by_key_insecure, shuffle::shuffle_inputs}; use crate::{ - error::Error, + error::{Error, UnwrapInfallible}, ff::{boolean::Boolean, boolean_array::BA64, CustomArray, Field, PrimeField, Serializable}, protocol::{ context::{UpgradableContext, UpgradedContext}, @@ -17,7 +19,6 @@ use crate::{ }, RecordId, }, - report::OprfReport, secret_sharing::{ replicated::{malicious::ExtendableField, semi_honest::AdditiveShare as Replicated}, SharedValue, @@ -40,6 +41,103 @@ pub(crate) enum Step { SortByTimestamp, } +#[derive(Debug)] +#[cfg_attr(test, derive(Clone, PartialEq, Eq))] +pub struct OPRFIPAInputRow { + pub match_key: Replicated, + pub is_trigger: Replicated, + pub breakdown_key: Replicated, + pub trigger_value: Replicated, + pub timestamp: Replicated, +} + +impl Serializable for OPRFIPAInputRow +where + Replicated: Serializable, + Replicated: Serializable, + Replicated: Serializable, + as Serializable>::Size: Add, + as Serializable>::Size: + Add<< as Serializable>::Size as Add>::Output>, + as Serializable>::Size: Add< + < as Serializable>::Size as Add< + < as Serializable>::Size as Add>::Output, + >>::Output, + >, + < as Serializable>::Size as Add< + < as Serializable>::Size as Add< + < as Serializable>::Size as Add>::Output, + >>::Output, + >>::Output: ArrayLength, +{ + type Size = < as Serializable>::Size as Add< + < as Serializable>::Size as Add< + < as Serializable>::Size as Add>::Output, + >>::Output, + >>::Output; + type DeserializationError = Error; + + fn serialize(&self, buf: &mut GenericArray) { + let mk_sz = as Serializable>::Size::USIZE; + let ts_sz = as Serializable>::Size::USIZE; + let bk_sz = as Serializable>::Size::USIZE; + let tv_sz = as Serializable>::Size::USIZE; + let it_sz = as Serializable>::Size::USIZE; + + self.match_key + .serialize(GenericArray::from_mut_slice(&mut buf[..mk_sz])); + + self.timestamp + 
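// Sketch (not part of the diff): the field layout that the `Serializable`
// impl for `OPRFIPAInputRow` writes -- match_key | timestamp | breakdown_key
// | trigger_value | is_trigger, back to back, with the total `Size` being the
// sum of the five replicated share sizes. The byte widths below are
// illustrative placeholders, not the real `Replicated<_>` serialized sizes.
const MK_SZ: usize = 16; // assumed width of a replicated BA64 share
const TS_SZ: usize = 5;  // assumed
const BK_SZ: usize = 2;  // assumed
const TV_SZ: usize = 2;  // assumed
const IT_SZ: usize = 1;  // assumed

fn serialize_row(mk: &[u8], ts: &[u8], bk: &[u8], tv: &[u8], it: &[u8]) -> Vec<u8> {
    assert_eq!(
        (mk.len(), ts.len(), bk.len(), tv.len(), it.len()),
        (MK_SZ, TS_SZ, BK_SZ, TV_SZ, IT_SZ)
    );
    let mut buf = Vec::with_capacity(MK_SZ + TS_SZ + BK_SZ + TV_SZ + IT_SZ);
    buf.extend_from_slice(mk);
    buf.extend_from_slice(ts);
    buf.extend_from_slice(bk);
    buf.extend_from_slice(tv);
    buf.extend_from_slice(it);
    buf
}

fn deserialize_row(buf: &[u8]) -> (&[u8], &[u8], &[u8], &[u8], &[u8]) {
    let (mk, rest) = buf.split_at(MK_SZ);
    let (ts, rest) = rest.split_at(TS_SZ);
    let (bk, rest) = rest.split_at(BK_SZ);
    let (tv, it) = rest.split_at(TV_SZ);
    (mk, ts, bk, tv, it)
}

fn main() {
    let buf = serialize_row(&[0; MK_SZ], &[1; TS_SZ], &[2; BK_SZ], &[3; TV_SZ], &[4; IT_SZ]);
    let (mk, ts, bk, tv, it) = deserialize_row(&buf);
    assert_eq!(mk.len() + ts.len() + bk.len() + tv.len() + it.len(), buf.len());
    assert_eq!(bk, &buf[MK_SZ + TS_SZ..MK_SZ + TS_SZ + BK_SZ]);
}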
.serialize(GenericArray::from_mut_slice(&mut buf[mk_sz..mk_sz + ts_sz])); + + self.breakdown_key.serialize(GenericArray::from_mut_slice( + &mut buf[mk_sz + ts_sz..mk_sz + ts_sz + bk_sz], + )); + + self.trigger_value.serialize(GenericArray::from_mut_slice( + &mut buf[mk_sz + ts_sz + bk_sz..mk_sz + ts_sz + bk_sz + tv_sz], + )); + + self.is_trigger.serialize(GenericArray::from_mut_slice( + &mut buf[mk_sz + ts_sz + bk_sz + tv_sz..mk_sz + ts_sz + bk_sz + tv_sz + it_sz], + )); + } + + fn deserialize(buf: &GenericArray) -> Result { + let mk_sz = as Serializable>::Size::USIZE; + let ts_sz = as Serializable>::Size::USIZE; + let bk_sz = as Serializable>::Size::USIZE; + let tv_sz = as Serializable>::Size::USIZE; + let it_sz = as Serializable>::Size::USIZE; + + let match_key = Replicated::::deserialize(GenericArray::from_slice(&buf[..mk_sz])) + .unwrap_infallible(); + let timestamp = + Replicated::::deserialize(GenericArray::from_slice(&buf[mk_sz..mk_sz + ts_sz])) + .map_err(|e| Error::ParseError(e.into()))?; + let breakdown_key = Replicated::::deserialize(GenericArray::from_slice( + &buf[mk_sz + ts_sz..mk_sz + ts_sz + bk_sz], + )) + .map_err(|e| Error::ParseError(e.into()))?; + let trigger_value = Replicated::::deserialize(GenericArray::from_slice( + &buf[mk_sz + ts_sz + bk_sz..mk_sz + ts_sz + bk_sz + tv_sz], + )) + .map_err(|e| Error::ParseError(e.into()))?; + let is_trigger = Replicated::::deserialize(GenericArray::from_slice( + &buf[mk_sz + ts_sz + bk_sz + tv_sz..mk_sz + ts_sz + bk_sz + tv_sz + it_sz], + )) + .map_err(|e| Error::ParseError(e.into()))?; + + Ok(Self { + match_key, + is_trigger, + breakdown_key, + trigger_value, + timestamp, + }) + } +} + /// IPA OPRF Protocol /// /// The output of this function is a vector of secret-shared totals, one per breakdown key @@ -63,7 +161,7 @@ pub(crate) enum Step { /// Propagates errors from config issues or while running the protocol pub async fn oprf_ipa( ctx: C, - input_rows: Vec>, + input_rows: Vec>, attribution_window_seconds: Option, ) -> Result>, Error> where @@ -105,7 +203,7 @@ where #[tracing::instrument(name = "compute_prf_for_inputs", skip_all)] async fn compute_prf_for_inputs( ctx: C, - input_rows: Vec>, + input_rows: Vec>, ) -> Result>, Error> where C: UpgradableContext, diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs b/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs index db982a0ad..62a7ac5fc 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs @@ -7,8 +7,10 @@ use crate::{ boolean_array::{BA112, BA64}, ArrayAccess, CustomArray, Expand, Field, }, - protocol::context::{UpgradableContext, UpgradedContext}, - report::OprfReport, + protocol::{ + context::{UpgradableContext, UpgradedContext}, + ipa_prf::OPRFIPAInputRow, + }, secret_sharing::{ replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing}, SharedValue, @@ -20,8 +22,8 @@ pub mod base; #[tracing::instrument(name = "shuffle_inputs", skip_all)] pub async fn shuffle_inputs( ctx: C, - input: Vec>, -) -> Result>, Error> + input: Vec>, +) -> Result>, Error> where C: UpgradableContext, C::UpgradedContext: UpgradedContext>, @@ -44,7 +46,7 @@ where // This function converts OprfReport to an AdditiveShare needed for shuffle protocol pub fn oprfreport_to_shuffle_input( - input: &OprfReport, + input: &OPRFIPAInputRow, ) -> AdditiveShare where YS: CustomArray::Element> + SharedValue, @@ -73,7 +75,9 @@ where } // This function converts AdditiveShare obtained from shuffle protocol to OprfReport -pub fn shuffled_to_oprfreport(input: 
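// Sketch (not part of the diff): the idea behind `oprfreport_to_shuffle_input`
// and `shuffled_to_oprfreport` -- every field of the row is packed into one
// wide bit array so the shuffle protocol only moves a single share type, and
// the fields are re-extracted at fixed bit offsets afterwards (timestamp last,
// as in the hunk above). The order and widths here (64/1/8/8/20 bits packed
// into a u128) are illustrative; the real code uses CustomArray widths such as
// BA112 for the packed share.
const MK_BITS: u32 = 64;
const IT_BITS: u32 = 1;
const BK_BITS: u32 = 8;
const TV_BITS: u32 = 8;
const TS_BITS: u32 = 20;

fn pack(mk: u64, is_trigger: bool, bk: u8, tv: u8, ts: u32) -> u128 {
    let mut out = 0u128;
    let mut off = 0u32;
    out |= u128::from(mk) << off;
    off += MK_BITS;
    out |= u128::from(is_trigger) << off;
    off += IT_BITS;
    out |= u128::from(bk) << off;
    off += BK_BITS;
    out |= u128::from(tv) << off;
    off += TV_BITS;
    out |= u128::from(ts & ((1u32 << TS_BITS) - 1)) << off;
    out
}

fn unpack(x: u128) -> (u64, bool, u8, u8, u32) {
    let mut off = 0u32;
    let mk = (x >> off) as u64;
    off += MK_BITS;
    let is_trigger = ((x >> off) & 1) == 1;
    off += IT_BITS;
    let bk = ((x >> off) & 0xff) as u8;
    off += BK_BITS;
    let tv = ((x >> off) & 0xff) as u8;
    off += TV_BITS;
    let ts = ((x >> off) & ((1u128 << TS_BITS) - 1)) as u32;
    (mk, is_trigger, bk, tv, ts)
}

fn main() {
    let packed = pack(0xDEAD_BEEF, true, 3, 7, 123_456);
    assert_eq!(unpack(packed), (0xDEAD_BEEF, true, 3, 7, 123_456));
}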
&AdditiveShare) -> OprfReport +pub fn shuffled_to_oprfreport( + input: &AdditiveShare, +) -> OPRFIPAInputRow where YS: SharedValue + CustomArray, BK: SharedValue + CustomArray, @@ -99,7 +103,7 @@ where offset += TV::BITS as usize; let timestamp = extract_from_shared_array::(input, offset); - OprfReport { + OPRFIPAInputRow { match_key, is_trigger, breakdown_key, diff --git a/ipa-core/src/query/executor.rs b/ipa-core/src/query/executor.rs index 5244c6c22..1f663499c 100644 --- a/ipa-core/src/query/executor.rs +++ b/ipa-core/src/query/executor.rs @@ -209,7 +209,7 @@ pub fn execute( move |prss, gateway, config, input| { let ctx = SemiHonestContext::new(prss, gateway); Box::pin( - OprfIpaQuery::<_, Fp32BitPrime>::new(ipa_config) + OprfIpaQuery::<_, Fp32BitPrime>::new(ipa_config, key_registry) .execute(ctx, config.size, input) .then(|res| ready(res.map(|out| Box::new(out) as Box))), ) @@ -223,7 +223,7 @@ pub fn execute( move |prss, gateway, config, input| { let ctx = SemiHonestContext::new(prss, gateway); Box::pin( - OprfIpaQuery::<_, crate::ff::Fp31>::new(ipa_config) + OprfIpaQuery::<_, crate::ff::Fp31>::new(ipa_config, key_registry) .execute(ctx, config.size, input) .then(|res| ready(res.map(|out| Box::new(out) as Box))), ) diff --git a/ipa-core/src/query/runner/oprf_ipa.rs b/ipa-core/src/query/runner/oprf_ipa.rs index c5b0776da..4c013d259 100644 --- a/ipa-core/src/query/runner/oprf_ipa.rs +++ b/ipa-core/src/query/runner/oprf_ipa.rs @@ -1,38 +1,44 @@ use std::marker::PhantomData; -use futures::TryStreamExt; +use futures::{stream::iter, StreamExt, TryStreamExt}; +use futures_util::stream::repeat; use crate::{ error::Error, ff::{ boolean::Boolean, boolean_array::{BA20, BA3, BA4, BA5, BA6, BA7, BA8}, - PrimeField, Serializable, + Field, PrimeField, Serializable, }, helpers::{ query::{IpaQueryConfig, QuerySize}, - BodyStream, RecordsStream, + BodyStream, LengthDelimitedStream, RecordsStream, }, + hpke::{KeyPair, KeyRegistry}, protocol::{ basics::ShareKnownValue, context::{UpgradableContext, UpgradedContext}, - ipa_prf::oprf_ipa, + ipa_prf::{oprf_ipa, OPRFIPAInputRow}, }, - report::OprfReport, - secret_sharing::replicated::{ - malicious::ExtendableField, semi_honest::AdditiveShare as Replicated, + report::{EncryptedOprfReport, EventType}, + secret_sharing::{ + replicated::{malicious::ExtendableField, semi_honest::AdditiveShare as Replicated}, + SharedValue, }, + sync::Arc, }; pub struct OprfIpaQuery { config: IpaQueryConfig, + key_registry: Arc>, phantom_data: PhantomData<(C, F)>, } impl OprfIpaQuery { - pub fn new(config: IpaQueryConfig) -> Self { + pub fn new(config: IpaQueryConfig, key_registry: Arc>) -> Self { Self { config, + key_registry, phantom_data: PhantomData, } } @@ -57,19 +63,52 @@ where ) -> Result>, Error> { let Self { config, + key_registry, phantom_data: _, } = self; tracing::info!("New query: {config:?}"); let sz = usize::from(query_size); let input = if config.plaintext_match_keys { - let mut v = RecordsStream::, _>::new(input_stream) + let mut v = RecordsStream::, _>::new(input_stream) .try_concat() .await?; v.truncate(sz); v } else { - panic!("Encrypted match key handling is not handled for OPRF flow as yet"); + LengthDelimitedStream::, _>::new(input_stream) + .map_err(Into::::into) + .map_ok(|enc_reports| { + iter(enc_reports.into_iter().map(|enc_report| { + enc_report + .decrypt(key_registry.as_ref()) + .map_err(Into::::into) + })) + }) + .try_flatten() + .take(sz) + .zip(repeat(ctx.clone())) + .map(|(res, ctx)| { + res.map(|report| { + let is_trigger = 
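// Sketch (not part of the diff): a minimal model of why `OprfIpaQuery::new`
// now takes the key registry -- the executor builds one query object per
// flavour and the HPKE key material is shared behind an `Arc` so the
// decryption path can borrow it without cloning keys. `QueryConfig` and
// `KeyRegistry` below are simplified stand-ins for the real
// `IpaQueryConfig` and `KeyRegistry<KeyPair>`.
use std::sync::Arc;

#[derive(Clone, Debug)]
struct QueryConfig {
    plaintext_match_keys: bool,
}

#[derive(Debug)]
struct KeyRegistry { /* private keys for decryption would live here */ }

struct OprfIpaQuery {
    config: QueryConfig,
    key_registry: Arc<KeyRegistry>,
}

impl OprfIpaQuery {
    fn new(config: QueryConfig, key_registry: Arc<KeyRegistry>) -> Self {
        Self { config, key_registry }
    }
}

fn main() {
    let registry = Arc::new(KeyRegistry {});
    // the same registry instance can back several query instantiations
    let q1 = OprfIpaQuery::new(QueryConfig { plaintext_match_keys: true }, Arc::clone(&registry));
    let q2 = OprfIpaQuery::new(QueryConfig { plaintext_match_keys: false }, registry);
    assert!(q1.config.plaintext_match_keys && !q2.config.plaintext_match_keys);
    assert_eq!(Arc::strong_count(&q1.key_registry), 2);
}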
Replicated::::share_known_value( + &ctx, + match report.event_type { + EventType::Source => Boolean::ZERO, + EventType::Trigger => Boolean::ONE, + }, + ); + + OPRFIPAInputRow { + timestamp: report.timestamp, + match_key: report.match_key, + is_trigger, + breakdown_key: report.breakdown_key, + trigger_value: report.trigger_value, + } + }) + }) + .try_collect::>() + .await? }; let aws = config.attribution_window_seconds; @@ -86,3 +125,111 @@ where } } } + +#[cfg(all(test, unit_test))] +mod tests { + use std::iter::zip; + + use rand::rngs::StdRng; + use rand_core::SeedableRng; + + use super::*; + use crate::{ + ff::{Field, Fp31}, + report::{OprfReport, DEFAULT_KEY_ID}, + secret_sharing::IntoShares, + test_fixture::{ipa::TestRawDataRecord, join3v, Reconstruct, TestWorld}, + }; + + #[tokio::test] + async fn encrypted_reports() { + const EXPECTED: &[u128] = &[0, 8, 5]; + + let records: Vec = vec![ + TestRawDataRecord { + timestamp: 0, + user_id: 12345, + is_trigger_report: false, + breakdown_key: 2, + trigger_value: 0, + }, + TestRawDataRecord { + timestamp: 4, + user_id: 68362, + is_trigger_report: false, + breakdown_key: 1, + trigger_value: 0, + }, + TestRawDataRecord { + timestamp: 10, + user_id: 12345, + is_trigger_report: true, + breakdown_key: 0, + trigger_value: 5, + }, + TestRawDataRecord { + timestamp: 12, + user_id: 68362, + is_trigger_report: true, + breakdown_key: 0, + trigger_value: 2, + }, + TestRawDataRecord { + timestamp: 20, + user_id: 68362, + is_trigger_report: false, + breakdown_key: 1, + trigger_value: 0, + }, + TestRawDataRecord { + timestamp: 30, + user_id: 68362, + is_trigger_report: true, + breakdown_key: 1, + trigger_value: 7, + }, + ]; + + let query_size = QuerySize::try_from(records.len()).unwrap(); + + let mut rng = StdRng::seed_from_u64(42); + let key_id = DEFAULT_KEY_ID; + let key_registry = Arc::new(KeyRegistry::random(1, &mut rng)); + + let mut buffers: [_; 3] = std::array::from_fn(|_| Vec::new()); + + let shares: [Vec>; 3] = records.into_iter().share(); + for (buf, shares) in zip(&mut buffers, shares) { + for share in shares { + share + .delimited_encrypt_to(key_id, key_registry.as_ref(), &mut rng, buf) + .unwrap(); + } + } + + let world = TestWorld::default(); + let contexts = world.contexts(); + #[allow(clippy::large_futures)] + let results = join3v(buffers.into_iter().zip(contexts).map(|(buffer, ctx)| { + let query_config = IpaQueryConfig { + num_multi_bits: 3, + per_user_credit_cap: 8, + attribution_window_seconds: None, + max_breakdown_key: 3, + plaintext_match_keys: false, + }; + let input = BodyStream::from(buffer); + OprfIpaQuery::<_, Fp31>::new(query_config, Arc::clone(&key_registry)) + .execute(ctx, query_size, input) + })) + .await; + + assert_eq!( + results.reconstruct()[0..3] + .iter() + .map(Field::as_u128) + .collect::>(), + EXPECTED + ); + } +} diff --git a/ipa-core/src/report.rs b/ipa-core/src/report.rs index 37f8b466b..6b16e4c9e 100644 --- a/ipa-core/src/report.rs +++ b/ipa-core/src/report.rs @@ -1,8 +1,6 @@ use std::{ - convert::Infallible, fmt::{Display, Formatter}, marker::PhantomData, - mem::size_of, ops::{Add, Deref}, }; @@ -10,17 +8,14 @@ use bytes::{BufMut, Bytes}; use generic_array::{ArrayLength, GenericArray}; use hpke::Serializable as _; use rand_core::{CryptoRng, RngCore}; -use typenum::{Unsigned, U1, U18, U8}; +use typenum::{Sum, Unsigned, U1, U16}; use crate::{ - error::{BoxError, Error, UnwrapInfallible}, - ff::{ - boolean::Boolean, boolean_array::BA64, GaloisField, Gf40Bit, Gf8Bit, PrimeField, - Serializable, - }, + 
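// Sketch (not part of the diff): the framing that the new non-plaintext
// branch consumes. `delimited_encrypt_to` prefixes every encrypted report
// with its length as a little-endian u16, so the helper side
// (`LengthDelimitedStream` in the hunk above) can split the incoming byte
// stream back into individual ciphertexts before decrypting each one. This
// is a synchronous, std-only model of that splitting step.
fn split_delimited(mut bytes: &[u8]) -> Result<Vec<&[u8]>, String> {
    let mut frames = Vec::new();
    while !bytes.is_empty() {
        if bytes.len() < 2 {
            return Err("truncated length prefix".into());
        }
        let len = u16::from_le_bytes([bytes[0], bytes[1]]) as usize;
        let rest = &bytes[2..];
        if rest.len() < len {
            return Err(format!("truncated frame: need {len}, have {}", rest.len()));
        }
        frames.push(&rest[..len]);
        bytes = &rest[len..];
    }
    Ok(frames)
}

fn main() {
    // two frames: 3 bytes and 1 byte
    let buf = [3, 0, b'a', b'b', b'c', 1, 0, b'z'];
    let frames = split_delimited(&buf).unwrap();
    assert_eq!(frames, vec![b"abc".as_slice(), b"z".as_slice()]);
}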
error::BoxError, + ff::{boolean_array::BA64, GaloisField, Gf40Bit, Gf8Bit, PrimeField, Serializable}, hpke::{ - open_in_place, seal_in_place, CryptError, FieldShareCrypt, Info, KeyPair, KeyRegistry, - PublicKeyRegistry, + open_in_place, seal_in_place, CryptError, EncapsulationSize, FieldShareCrypt, Info, + KeyPair, KeyRegistry, PublicKeyRegistry, TagSize, }, secret_sharing::{replicated::semi_honest::AdditiveShare as Replicated, SharedValue}, }; @@ -63,16 +58,16 @@ impl Serializable for EventType { fn serialize(&self, buf: &mut GenericArray) { let raw: &[u8] = match self { - EventType::Trigger => &[0], - EventType::Source => &[1], + EventType::Trigger => &[1], + EventType::Source => &[0], }; buf.copy_from_slice(raw); } fn deserialize(buf: &GenericArray) -> Result { match buf[0] { - 0 => Ok(EventType::Trigger), - 1 => Ok(EventType::Source), + 1 => Ok(EventType::Trigger), + 0 => Ok(EventType::Source), _ => Err(UnknownEventType(buf[0])), } } @@ -160,6 +155,8 @@ pub enum InvalidReportError { Crypt(#[from] CryptError), #[error("failed to deserialize field {0}: {1}")] DeserializationError(&'static str, #[source] BoxError), + #[error("report is too short: {0}, expected length at least: {1}")] + Length(usize, usize), } /// A binary report as submitted by a report collector, containing encrypted match key shares. @@ -346,7 +343,7 @@ where Replicated: Serializable, { /// # Panics - /// If report length does not fit in u16. + /// If report length does not fit in `u16`. pub fn encrypted_len(&self) -> u16 { let len = EncryptedReport::::SITE_DOMAIN_OFFSET + self.site_domain.as_bytes().len(); @@ -374,7 +371,7 @@ where key_registry: &impl PublicKeyRegistry, rng: &mut R, ) -> Result, InvalidReportError> { - let mut out = Vec::new(); + let mut out = Vec::with_capacity(usize::from(self.encrypted_len())); self.encrypt_to(key_id, key_registry, rng, &mut out)?; debug_assert_eq!(out.len(), usize::from(self.encrypted_len())); Ok(out) @@ -424,136 +421,384 @@ where } } -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct OprfReport +/// A binary report as submitted by a report collector, containing encrypted `OprfReport` +/// An `EncryptedOprfReport` consists of: +/// `ct_mk`: Enc(`match_key`) +/// `ct_btt`: Enc(`breakdown_key`, `trigger_value`, `timestamp`) +/// associated data of `ct_mk`: `key_id`, `epoch`, `event_type`, `site_domain`, +#[derive(Copy, Clone, Eq, PartialEq)] +pub struct EncryptedOprfReport where + B: Deref, BK: SharedValue, TV: SharedValue, TS: SharedValue, { - pub match_key: Replicated, - pub is_trigger: Replicated, - pub breakdown_key: Replicated, - pub trigger_value: Replicated, - pub timestamp: Replicated, + data: B, + phantom_data: PhantomData<(BK, TV, TS)>, } -impl Serializable for u64 { - type Size = U8; - type DeserializationError = Infallible; +// follows the outline of the implementation of `EncryptedReport` +// Report structure: +// * 0..a: `encap_key_1` +// * a..b: `mk_ciphertext` +// * b..c: `encap_key_2` +// * c..d: `btt_ciphertext` +// * d: `event_type` +// * d+1: `key_id` +// * d+2..d+4: `epoch` +// * d+4..: `site_domain` + +// btt ciphertext structure +// * 0..a `timestamp` +// * a..b `breakdown` +// * b..c `trigger value` +impl EncryptedOprfReport +where + B: Deref, + BK: SharedValue, + TV: SharedValue, + TS: SharedValue, + Replicated: Serializable, + Replicated: Serializable, + Replicated: Serializable, + as Serializable>::Size: Add< as Serializable>::Size>, + Sum< as Serializable>::Size, as Serializable>::Size>: + Add< as Serializable>::Size>, + Sum< + Sum< as 
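// Sketch (not part of the diff): the on-the-wire event-type byte after the
// change in the hunk above -- `Source` now serializes to 0 and `Trigger` to
// 1, and any other value is rejected. Modeled here with plain std
// conversions instead of the crate's `Serializable` trait.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum EventType {
    Source,
    Trigger,
}

impl From<&EventType> for u8 {
    fn from(e: &EventType) -> Self {
        match e {
            EventType::Source => 0,
            EventType::Trigger => 1,
        }
    }
}

impl TryFrom<u8> for EventType {
    type Error = u8; // the unknown byte, mirroring UnknownEventType(u8)

    fn try_from(v: u8) -> Result<Self, Self::Error> {
        match v {
            0 => Ok(EventType::Source),
            1 => Ok(EventType::Trigger),
            other => Err(other),
        }
    }
}

fn main() {
    assert_eq!(u8::from(&EventType::Trigger), 1);
    assert_eq!(EventType::try_from(0), Ok(EventType::Source));
    assert_eq!(EventType::try_from(7), Err(7));
}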
Serializable>::Size, as Serializable>::Size>, + as Serializable>::Size, + >: Add, + Sum< + Sum< + Sum< as Serializable>::Size, as Serializable>::Size>, + as Serializable>::Size, + >, + U16, + >: ArrayLength, +{ + const ENCAP_KEY_MK_OFFSET: usize = 0; + const CIPHERTEXT_MK_OFFSET: usize = Self::ENCAP_KEY_MK_OFFSET + EncapsulationSize::USIZE; + const ENCAP_KEY_BTT_OFFSET: usize = (Self::CIPHERTEXT_MK_OFFSET + + TagSize::USIZE + + as Serializable>::Size::USIZE); + const CIPHERTEXT_BTT_OFFSET: usize = Self::ENCAP_KEY_BTT_OFFSET + EncapsulationSize::USIZE; - fn serialize(&self, buf: &mut GenericArray) { - let raw = &self.to_le_bytes()[..buf.len()]; - buf.copy_from_slice(raw); + const EVENT_TYPE_OFFSET: usize = (Self::CIPHERTEXT_BTT_OFFSET + + TagSize::USIZE + + as Serializable>::Size::USIZE + + as Serializable>::Size::USIZE + + as Serializable>::Size::USIZE); + const KEY_IDENTIFIER_OFFSET: usize = Self::EVENT_TYPE_OFFSET + 1; + const EPOCH_OFFSET: usize = Self::KEY_IDENTIFIER_OFFSET + 1; + const SITE_DOMAIN_OFFSET: usize = Self::EPOCH_OFFSET + 2; + + // offsets within Ciphertext_BTT + const TS_OFFSET: usize = 0; + + const BK_OFFSET: usize = Self::TS_OFFSET + as Serializable>::Size::USIZE; + const TV_OFFSET: usize = Self::BK_OFFSET + as Serializable>::Size::USIZE; + const TV_END: usize = Self::TV_OFFSET + as Serializable>::Size::USIZE; + + pub fn encap_key_mk(&self) -> &[u8] { + &self.data[Self::ENCAP_KEY_MK_OFFSET..Self::CIPHERTEXT_MK_OFFSET] } - fn deserialize(buf: &GenericArray) -> Result { - let mut buf_to = [0u8; 8]; - buf_to[..buf.len()].copy_from_slice(buf); - Ok(u64::from_le_bytes(buf_to)) + pub fn mk_ciphertext(&self) -> &[u8] { + &self.data[Self::CIPHERTEXT_MK_OFFSET..Self::ENCAP_KEY_BTT_OFFSET] + } + + pub fn encap_key_btt(&self) -> &[u8] { + &self.data[Self::ENCAP_KEY_BTT_OFFSET..Self::CIPHERTEXT_BTT_OFFSET] + } + + pub fn btt_ciphertext(&self) -> &[u8] { + &self.data[Self::CIPHERTEXT_BTT_OFFSET..Self::EVENT_TYPE_OFFSET] + } + + /// ## Panics + /// Only if a `Report` constructor failed to validate the contents properly, which would be a bug. + pub fn event_type(&self) -> EventType { + EventType::try_from(self.data[Self::EVENT_TYPE_OFFSET]).unwrap() // validated on construction + } + + pub fn key_id(&self) -> KeyIdentifier { + self.data[Self::KEY_IDENTIFIER_OFFSET] + } + + /// ## Panics + /// Never. + pub fn epoch(&self) -> Epoch { + u16::from_le_bytes( + self.data[Self::EPOCH_OFFSET..Self::SITE_DOMAIN_OFFSET] + .try_into() + .unwrap(), // infallible slice-to-array conversion + ) + } + + /// ## Panics + /// Only if a `Report` constructor failed to validate the contents properly, which would be a bug. + pub fn site_domain(&self) -> &str { + std::str::from_utf8(&self.data[Self::SITE_DOMAIN_OFFSET..]).unwrap() // validated on construction + } + + /// ## Errors + /// If the report contents are invalid. + pub fn from_bytes(bytes: B) -> Result { + if bytes.len() <= Self::SITE_DOMAIN_OFFSET { + return Err(InvalidReportError::Length( + bytes.len(), + Self::SITE_DOMAIN_OFFSET, + )); + } + EventType::try_from(bytes[Self::EVENT_TYPE_OFFSET])?; + let site_domain = &bytes[Self::SITE_DOMAIN_OFFSET..]; + if !site_domain.is_ascii() { + return Err(NonAsciiStringError::from(site_domain).into()); + } + Ok(Self { + data: bytes, + phantom_data: PhantomData, + }) + } + + /// ## Errors + /// If the match key shares in the report cannot be decrypted (e.g. due to a + /// failure of the authenticated encryption). + /// ## Panics + /// Should not panic. 
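// Sketch (not part of the diff): how the `EncryptedOprfReport` byte offsets
// in the hunk above fall out once the generic sizes are fixed. X25519
// encapsulated keys are 32 bytes and the AES-GCM-128 tag is 16 bytes; the
// two share widths below are illustrative placeholders for the serialized
// Replicated<_> sizes.
const ENCAP_KEY: usize = 32; // X25519HkdfSha256 encapsulated key
const TAG: usize = 16;       // AesGcm128 authentication tag
const MK_SHARE: usize = 16;  // assumed serialized Replicated<BA64> size
const BTT_SHARE: usize = 9;  // assumed ts + bk + tv serialized share sizes

const ENCAP_KEY_MK: usize = 0;
const CIPHERTEXT_MK: usize = ENCAP_KEY_MK + ENCAP_KEY;
const ENCAP_KEY_BTT: usize = CIPHERTEXT_MK + MK_SHARE + TAG;
const CIPHERTEXT_BTT: usize = ENCAP_KEY_BTT + ENCAP_KEY;
const EVENT_TYPE: usize = CIPHERTEXT_BTT + BTT_SHARE + TAG;
const KEY_ID: usize = EVENT_TYPE + 1;
const EPOCH: usize = KEY_ID + 1;
const SITE_DOMAIN: usize = EPOCH + 2;

fn main() {
    // everything before SITE_DOMAIN is fixed-size; the site domain is the
    // only variable-length tail, which is why `from_bytes` only has to check
    // that the report is longer than SITE_DOMAIN before slicing.
    assert_eq!(CIPHERTEXT_MK, 32);
    assert_eq!(ENCAP_KEY_BTT, 32 + 16 + 16);
    assert_eq!(SITE_DOMAIN, ENCAP_KEY * 2 + TAG * 2 + MK_SHARE + BTT_SHARE + 4);
}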
Only panics if a `Report` constructor failed to validate the + /// contents properly, which would be a bug. + pub fn decrypt( + &self, + key_registry: &KeyRegistry, + ) -> Result, InvalidReportError> { + type CTMKLength = Sum< as Serializable>::Size, TagSize>; + type CTBTTLength = Sum< + Sum< + Sum< as Serializable>::Size, as Serializable>::Size>, + as Serializable>::Size, + >, + TagSize, + >; + + let info = Info::new( + self.key_id(), + self.epoch(), + self.event_type(), + HELPER_ORIGIN, + self.site_domain(), + ) + .unwrap(); // validated on construction + + let mut ct_mk: GenericArray = + *GenericArray::from_slice(self.mk_ciphertext()); + // let mut ct_mk = self.mk_ciphertext().to_vec(); + let plaintext_mk = open_in_place(key_registry, self.encap_key_mk(), &mut ct_mk, &info)?; + let mut ct_btt: GenericArray> = + GenericArray::from_slice(self.btt_ciphertext()).clone(); + // let mut ct_btt = self.btt_ciphertext().to_vec(); + let plaintext_btt = open_in_place(key_registry, self.encap_key_btt(), &mut ct_btt, &info)?; + + Ok(OprfReport:: { + timestamp: Replicated::::deserialize(GenericArray::from_slice( + &plaintext_btt[Self::TS_OFFSET..Self::BK_OFFSET], + )) + .map_err(|e| InvalidReportError::DeserializationError("timestamp", e.into()))?, + match_key: Replicated::::deserialize(GenericArray::from_slice(plaintext_mk)) + .map_err(|e| InvalidReportError::DeserializationError("matchkey", e.into()))?, + event_type: self.event_type(), + breakdown_key: Replicated::::deserialize(GenericArray::from_slice( + &plaintext_btt[Self::BK_OFFSET..Self::TV_OFFSET], + )) + .map_err(|e| InvalidReportError::DeserializationError("is_trigger", e.into()))?, + trigger_value: Replicated::::deserialize(GenericArray::from_slice( + &plaintext_btt[Self::TV_OFFSET..Self::TV_END], + )) + .map_err(|e| InvalidReportError::DeserializationError("trigger_value", e.into()))?, + epoch: self.epoch(), + site_domain: self.site_domain().to_owned(), + }) } } -impl Serializable for OprfReport +impl TryFrom for EncryptedOprfReport where + BK: SharedValue, + TV: SharedValue, + TS: SharedValue, Replicated: Serializable, Replicated: Serializable, Replicated: Serializable, - as Serializable>::Size: Add, - as Serializable>::Size: - Add<< as Serializable>::Size as Add>::Output>, - as Serializable>::Size: Add< - < as Serializable>::Size as Add< - < as Serializable>::Size as Add>::Output, - >>::Output, - >, - < as Serializable>::Size as Add< - < as Serializable>::Size as Add< - < as Serializable>::Size as Add>::Output, - >>::Output, - >>::Output: ArrayLength, + as Serializable>::Size: Add< as Serializable>::Size>, + Sum< as Serializable>::Size, as Serializable>::Size>: + Add< as Serializable>::Size>, + Sum< + Sum< as Serializable>::Size, as Serializable>::Size>, + as Serializable>::Size, + >: Add, + Sum< + Sum< + Sum< as Serializable>::Size, as Serializable>::Size>, + as Serializable>::Size, + >, + U16, + >: ArrayLength, { - type Size = < as Serializable>::Size as Add< - < as Serializable>::Size as Add< - < as Serializable>::Size as Add>::Output, - >>::Output, - >>::Output; - type DeserializationError = Error; + type Error = InvalidReportError; - fn serialize(&self, buf: &mut GenericArray) { - let sizeof_matchkey = size_of::() * 2; - let sizeof_eventtype = size_of::() * 2; - let ts_sz = as Serializable>::Size::USIZE; - let bk_sz = as Serializable>::Size::USIZE; - let tv_sz = as Serializable>::Size::USIZE; + fn try_from(bytes: Bytes) -> Result { + EncryptedOprfReport::from_bytes(bytes) + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct 
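// Sketch (not part of the diff): the shape of `decrypt` in the hunk above --
// the match-key ciphertext and the (timestamp, breakdown, trigger-value)
// ciphertext are opened separately with the same HPKE info/AAD, then the btt
// plaintext is sliced at fixed offsets. `open` is a stand-in for
// `open_in_place`, and the widths are illustrative.
const TS_SZ: usize = 5; // assumed Replicated<TS> size
const BK_SZ: usize = 2; // assumed Replicated<BK> size
const TV_SZ: usize = 2; // assumed Replicated<TV> size

struct DecryptedRow {
    match_key: Vec<u8>,
    ts: Vec<u8>,
    bk: Vec<u8>,
    tv: Vec<u8>,
}

fn decrypt_report(
    ct_mk: &[u8],
    ct_btt: &[u8],
    open: impl Fn(&[u8]) -> Result<Vec<u8>, String>, // stand-in for HPKE open_in_place
) -> Result<DecryptedRow, String> {
    let plaintext_mk = open(ct_mk)?;
    let plaintext_btt = open(ct_btt)?;
    if plaintext_btt.len() != TS_SZ + BK_SZ + TV_SZ {
        return Err("btt plaintext has unexpected length".into());
    }
    Ok(DecryptedRow {
        match_key: plaintext_mk,
        ts: plaintext_btt[..TS_SZ].to_vec(),
        bk: plaintext_btt[TS_SZ..TS_SZ + BK_SZ].to_vec(),
        tv: plaintext_btt[TS_SZ + BK_SZ..].to_vec(),
    })
}

fn main() {
    // identity "decryption" for the demo
    let open = |ct: &[u8]| -> Result<Vec<u8>, String> { Ok(ct.to_vec()) };
    let row = decrypt_report(&[7u8; 16], &[1u8; TS_SZ + BK_SZ + TV_SZ], open).unwrap();
    assert_eq!(row.match_key.len(), 16);
    assert_eq!((row.ts.len(), row.bk.len(), row.tv.len()), (TS_SZ, BK_SZ, TV_SZ));
}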
OprfReport +where + BK: SharedValue, + TV: SharedValue, + TS: SharedValue, +{ + pub match_key: Replicated, + pub event_type: EventType, + pub breakdown_key: Replicated, + pub trigger_value: Replicated, + pub timestamp: Replicated, + pub epoch: Epoch, + pub site_domain: String, +} + +impl OprfReport +where + BK: SharedValue, + TV: SharedValue, + TS: SharedValue, + Replicated: Serializable, + Replicated: Serializable, + Replicated: Serializable, + as Serializable>::Size: Add< as Serializable>::Size>, + Sum< as Serializable>::Size, as Serializable>::Size>: + Add< as Serializable>::Size>, + Sum< + Sum< as Serializable>::Size, as Serializable>::Size>, + as Serializable>::Size, + >: Add, + Sum< + Sum< + Sum< as Serializable>::Size, as Serializable>::Size>, + as Serializable>::Size, + >, + U16, + >: ArrayLength, +{ + // offsets for BTT Ciphertext + const TS_OFFSET: usize = 0; + const BK_OFFSET: usize = Self::TS_OFFSET + as Serializable>::Size::USIZE; + const TV_OFFSET: usize = Self::BK_OFFSET + as Serializable>::Size::USIZE; + const BTT_END: usize = Self::TV_OFFSET + as Serializable>::Size::USIZE; + + /// # Panics + /// If report length does not fit in `u16`. + pub fn encrypted_len(&self) -> u16 { + let len = EncryptedOprfReport::::SITE_DOMAIN_OFFSET + + self.site_domain.as_bytes().len(); + len.try_into().unwrap() + } + + /// # Errors + /// If there is a problem encrypting the report. + pub fn delimited_encrypt_to( + &self, + key_id: KeyIdentifier, + key_registry: &impl PublicKeyRegistry, + rng: &mut R, + out: &mut B, + ) -> Result<(), InvalidReportError> { + out.put_u16_le(self.encrypted_len()); + self.encrypt_to(key_id, key_registry, rng, out) + } + + /// # Errors + /// If there is a problem encrypting the report. + pub fn encrypt( + &self, + key_id: KeyIdentifier, + key_registry: &impl PublicKeyRegistry, + rng: &mut R, + ) -> Result, InvalidReportError> { + let mut out = Vec::with_capacity(usize::from(self.encrypted_len())); + self.encrypt_to(key_id, key_registry, rng, &mut out)?; + debug_assert_eq!(out.len(), usize::from(self.encrypted_len())); + Ok(out) + } + + /// # Errors + /// If there is a problem encrypting the report. 
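// Sketch (not part of the diff): how `encrypted_len` in the hunk above is
// used. The fixed-size prefix (encapsulated keys, ciphertexts, tags, event
// type, key id, epoch) is a compile-time constant, so the total encrypted
// size is that constant plus the site-domain length; it must fit in a u16
// because the delimited framing stores it as a little-endian u16. The
// 125-byte constant is the illustrative value from the offset sketch above,
// not the crate's real number.
const SITE_DOMAIN_OFFSET: usize = 125; // assumed fixed-size prefix length

fn encrypted_len(site_domain: &str) -> u16 {
    let len = SITE_DOMAIN_OFFSET + site_domain.len();
    u16::try_from(len).expect("encrypted report length must fit in u16")
}

fn main() {
    let len = encrypted_len("meta.com");
    assert_eq!(len, 125 + 8);
    // the writer preallocates exactly this much, which is what the
    // debug_assert in `encrypt()` checks after serialization
    let out: Vec<u8> = Vec::with_capacity(usize::from(len));
    assert!(out.capacity() >= usize::from(len));
}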
+ pub fn encrypt_to( + &self, + key_id: KeyIdentifier, + key_registry: &impl PublicKeyRegistry, + rng: &mut R, + out: &mut B, + ) -> Result<(), InvalidReportError> { + let info = Info::new( + key_id, + self.epoch, + self.event_type, + HELPER_ORIGIN, + self.site_domain.as_ref(), + )?; - self.match_key - .serialize(GenericArray::from_mut_slice(&mut buf[..sizeof_matchkey])); + let mut plaintext_mk = GenericArray::default(); + self.match_key.serialize(&mut plaintext_mk); + let mut plaintext_btt = vec![0u8; Self::BTT_END]; self.timestamp.serialize(GenericArray::from_mut_slice( - &mut buf[sizeof_matchkey..sizeof_matchkey + ts_sz], + &mut plaintext_btt[Self::TS_OFFSET..Self::BK_OFFSET], )); - self.breakdown_key.serialize(GenericArray::from_mut_slice( - &mut buf[sizeof_matchkey + ts_sz..sizeof_matchkey + ts_sz + bk_sz], + &mut plaintext_btt[Self::BK_OFFSET..Self::TV_OFFSET], )); - self.trigger_value.serialize(GenericArray::from_mut_slice( - &mut buf[sizeof_matchkey + ts_sz + bk_sz..sizeof_matchkey + ts_sz + bk_sz + tv_sz], + &mut plaintext_btt[Self::TV_OFFSET + ..(Self::TV_OFFSET + as Serializable>::Size::USIZE)], )); - self.is_trigger.serialize(GenericArray::from_mut_slice( - &mut buf[sizeof_matchkey + ts_sz + bk_sz + tv_sz - ..sizeof_matchkey + ts_sz + bk_sz + tv_sz + sizeof_eventtype], - )); - } + let (encap_key_mk, ciphertext_mk, tag_mk) = + seal_in_place(key_registry, plaintext_mk.as_mut(), &info, rng)?; - fn deserialize(buf: &GenericArray) -> Result { - let sizeof_matchkey = size_of::() * 2; - let sizeof_eventtype = size_of::() * 2; - - let ts_sz = as Serializable>::Size::USIZE; - let bk_sz = as Serializable>::Size::USIZE; - let tv_sz = as Serializable>::Size::USIZE; - - let match_key = - Replicated::::deserialize(GenericArray::from_slice(&buf[..sizeof_matchkey])) - .unwrap_infallible(); - let timestamp = Replicated::::deserialize(GenericArray::from_slice( - &buf[sizeof_matchkey..sizeof_matchkey + ts_sz], - )) - .map_err(|e| Error::ParseError(e.into()))?; - let breakdown_key = Replicated::::deserialize(GenericArray::from_slice( - &buf[sizeof_matchkey + ts_sz..sizeof_matchkey + ts_sz + bk_sz], - )) - .map_err(|e| Error::ParseError(e.into()))?; - let trigger_value = Replicated::::deserialize(GenericArray::from_slice( - &buf[sizeof_matchkey + ts_sz + bk_sz..sizeof_matchkey + ts_sz + bk_sz + tv_sz], - )) - .map_err(|e| Error::ParseError(e.into()))?; - let is_trigger = Replicated::::deserialize(GenericArray::from_slice( - &buf[sizeof_matchkey + ts_sz + bk_sz + tv_sz - ..sizeof_matchkey + ts_sz + bk_sz + tv_sz + sizeof_eventtype], - )) - .map_err(|e| Error::ParseError(e.into()))?; + let (encap_key_btt, ciphertext_btt, tag_btt) = + seal_in_place(key_registry, plaintext_btt.as_mut(), &info, rng)?; - Ok(Self { - match_key, - is_trigger, - breakdown_key, - trigger_value, - timestamp, - }) + out.put_slice(&encap_key_mk.to_bytes()); + out.put_slice(ciphertext_mk); + out.put_slice(&tag_mk.to_bytes()); + out.put_slice(&encap_key_btt.to_bytes()); + out.put_slice(ciphertext_btt); + out.put_slice(&tag_btt.to_bytes()); + out.put_slice(&[u8::from(&self.event_type)]); + out.put_slice(&[key_id]); + out.put_slice(&self.epoch.to_le_bytes()); + out.put_slice(self.site_domain.as_bytes()); + + Ok(()) } } #[cfg(all(test, unit_test))] mod test { - use rand::{distributions::Alphanumeric, rngs::StdRng, Rng}; + use rand::{distributions::Alphanumeric, rngs::StdRng, thread_rng, Rng}; use rand_core::SeedableRng; use super::*; - use crate::ff::{Fp32BitPrime, Gf40Bit, Gf8Bit}; + use crate::{ + ff::{ + boolean_array::{BA20, 
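// Sketch (not part of the diff): the byte order that `encrypt_to` writes in
// the hunk above. Given the two sealed ciphertexts and their tags, the
// encrypted report is just a concatenation; the inputs here are dummy byte
// slices standing in for the real HPKE outputs.
fn assemble_report(
    encap_key_mk: &[u8], ct_mk: &[u8], tag_mk: &[u8],
    encap_key_btt: &[u8], ct_btt: &[u8], tag_btt: &[u8],
    event_type: u8, key_id: u8, epoch: u16, site_domain: &str,
) -> Vec<u8> {
    let mut out = Vec::new();
    out.extend_from_slice(encap_key_mk);
    out.extend_from_slice(ct_mk);
    out.extend_from_slice(tag_mk);
    out.extend_from_slice(encap_key_btt);
    out.extend_from_slice(ct_btt);
    out.extend_from_slice(tag_btt);
    out.push(event_type);
    out.push(key_id);
    out.extend_from_slice(&epoch.to_le_bytes());
    out.extend_from_slice(site_domain.as_bytes());
    out
}

fn main() {
    let report = assemble_report(
        &[0; 32], &[1; 16], &[2; 16],
        &[3; 32], &[4; 9], &[5; 16],
        1, 0, 2024, "meta.com",
    );
    assert_eq!(report.len(), 32 + 16 + 16 + 32 + 9 + 16 + 1 + 1 + 2 + 8);
}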
BA3, BA8}, + Fp32BitPrime, Gf40Bit, Gf8Bit, + }, + report, + report::EventType::{Source, Trigger}, + secret_sharing::replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing}, + }; #[test] fn enc_dec_roundtrip() { @@ -583,6 +828,70 @@ mod test { assert_eq!(dec_report, report); } + #[test] + fn enc_dec_roundtrip_oprf() { + let mut rng = thread_rng(); + + let b: EventType = if rng.gen::() { Trigger } else { Source }; + + let report = OprfReport:: { + match_key: AdditiveShare::new(rng.gen(), rng.gen()), + timestamp: AdditiveShare::new(rng.gen(), rng.gen()), + breakdown_key: AdditiveShare::new(rng.gen(), rng.gen()), + trigger_value: AdditiveShare::new(rng.gen(), rng.gen()), + event_type: b, + epoch: rng.gen(), + site_domain: (&mut rng) + .sample_iter(Alphanumeric) + .map(char::from) + .take(10) + .collect(), + }; + + let key_registry = KeyRegistry::random(1, &mut rng); + let key_id = 0; + + let enc_report_bytes = report.encrypt(key_id, &key_registry, &mut rng).unwrap(); + let enc_report = EncryptedOprfReport::from_bytes(enc_report_bytes.as_slice()).unwrap(); + let dec_report: OprfReport = enc_report.decrypt(&key_registry).unwrap(); + + assert_eq!(dec_report, report); + } + + #[test] + fn test_decryption_fails() { + let mut rng = thread_rng(); + + let b: EventType = if rng.gen::() { Trigger } else { Source }; + + let report = OprfReport:: { + match_key: AdditiveShare::new(rng.gen(), rng.gen()), + timestamp: AdditiveShare::new(rng.gen(), rng.gen()), + breakdown_key: AdditiveShare::new(rng.gen(), rng.gen()), + trigger_value: AdditiveShare::new(rng.gen(), rng.gen()), + event_type: b, + epoch: rng.gen(), + site_domain: (&mut rng) + .sample_iter(Alphanumeric) + .map(char::from) + .take(10) + .collect(), + }; + + let enc_key_registry = KeyRegistry::random(1, &mut rng); + let enc_key_id = 0; + let dec_key_registry = KeyRegistry::random(1, &mut rng); + + let enc_report_bytes = report + .encrypt(enc_key_id, &enc_key_registry, &mut rng) + .unwrap(); + let enc_report: report::EncryptedOprfReport = + EncryptedOprfReport::from_bytes(enc_report_bytes.as_slice()).unwrap(); + let dec_report = enc_report.decrypt(&dec_key_registry); + + assert!(dec_report.is_err()); + } + #[test] fn decrypt() { let mut rng = StdRng::from_seed([1_u8; 32]); diff --git a/ipa-core/src/test_fixture/input/sharing.rs b/ipa-core/src/test_fixture/input/sharing.rs index 873b2cc8a..b184fbb1f 100644 --- a/ipa-core/src/test_fixture/input/sharing.rs +++ b/ipa-core/src/test_fixture/input/sharing.rs @@ -11,6 +11,7 @@ use crate::{ AccumulateCreditInputRow, ApplyAttributionWindowInputRow, CreditCappingInputRow, }, ipa::IPAInputRow, + ipa_prf::OPRFIPAInputRow, BreakdownKey, MatchKey, }, rand::Rng, @@ -199,6 +200,51 @@ where } } +impl IntoShares> for TestRawDataRecord +where + BK: SharedValue + Field + IntoShares>, + TV: SharedValue + Field + IntoShares>, + TS: SharedValue + Field + IntoShares>, +{ + fn share_with(self, rng: &mut R) -> [OprfReport; 3] { + let match_key = BA64::try_from(u128::from(self.user_id)) + .unwrap() + .share_with(rng); + let timestamp: [Replicated; 3] = TS::try_from(u128::from(self.timestamp)) + .unwrap() + .share_with(rng); + let breakdown_key = BK::try_from(self.breakdown_key.into()) + .unwrap() + .share_with(rng); + let trigger_value = TV::try_from(self.trigger_value.into()) + .unwrap() + .share_with(rng); + let event_type = if self.is_trigger_report { + EventType::Trigger + } else { + EventType::Source + }; + let epoch = 1; + let site_domain = DOMAINS[rng.gen_range(0..DOMAINS.len())].to_owned(); + + 
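// Sketch (not part of the diff): the replicated 2-of-3 sharing that
// `IntoShares<OprfReport<..>>` produces for the BA64 match key, modeled on a
// plain u64 with XOR. Helper i holds the pair (x_i, x_{i+1}); the secret is
// x_1 ^ x_2 ^ x_3, so no single helper learns anything, while the three left
// shares recombine to the value.
fn share(value: u64, r1: u64, r2: u64) -> [(u64, u64); 3] {
    // r1, r2 play the role of the RNG draws in share_with()
    let x1 = r1;
    let x2 = r2;
    let x3 = value ^ x1 ^ x2;
    [(x1, x2), (x2, x3), (x3, x1)]
}

fn reconstruct(shares: &[(u64, u64); 3]) -> u64 {
    shares[0].0 ^ shares[1].0 ^ shares[2].0
}

fn main() {
    let secret = 0x1234_5678_9abc_def0;
    let shares = share(secret, 0xdead_beef, 0xfeed_f00d);
    assert_eq!(reconstruct(&shares), secret);
    // replicated consistency: each helper's right share equals the next
    // helper's left share
    for i in 0..3 {
        assert_eq!(shares[i].1, shares[(i + 1) % 3].0);
    }
}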
zip(zip(match_key, zip(timestamp, breakdown_key)), trigger_value) + .map( + |((match_key_share, (ts_share, bk_share)), tv_share)| OprfReport { + timestamp: ts_share, + match_key: match_key_share, + event_type, + breakdown_key: bk_share, + trigger_value: tv_share, + epoch, + site_domain: site_domain.clone(), + }, + ) + .collect::>() + .try_into() + .unwrap() + } +} + impl IntoShares> for GenericReportTestInput where F: Field + IntoShares>, @@ -360,13 +406,13 @@ where } } -impl IntoShares> for TestRawDataRecord +impl IntoShares> for TestRawDataRecord where BK: SharedValue + Field + IntoShares>, TV: SharedValue + Field + IntoShares>, TS: SharedValue + Field + IntoShares>, { - fn share_with(self, rng: &mut R) -> [OprfReport; 3] { + fn share_with(self, rng: &mut R) -> [OPRFIPAInputRow; 3] { let is_trigger = Replicated::new( Boolean::from(self.is_trigger_report), Boolean::from(self.is_trigger_report), @@ -389,12 +435,14 @@ where repeat(is_trigger), ) .map( - |(((match_key_share, (ts_share, bk_share)), tv_share), is_trigger_share)| OprfReport { - timestamp: ts_share, - match_key: match_key_share, - is_trigger: is_trigger_share, - breakdown_key: bk_share, - trigger_value: tv_share, + |(((match_key_share, (ts_share, bk_share)), tv_share), is_trigger_share)| { + OPRFIPAInputRow { + timestamp: ts_share, + match_key: match_key_share, + is_trigger: is_trigger_share, + breakdown_key: bk_share, + trigger_value: tv_share, + } }, ) .collect::>() @@ -403,7 +451,7 @@ where } } -impl Reconstruct for [&OprfReport; 3] +impl Reconstruct for [&OPRFIPAInputRow; 3] where BK: SharedValue + Field, TV: SharedValue + Field, diff --git a/ipa-core/src/test_fixture/ipa.rs b/ipa-core/src/test_fixture/ipa.rs index 363b2e723..57739df2f 100644 --- a/ipa-core/src/test_fixture/ipa.rs +++ b/ipa-core/src/test_fixture/ipa.rs @@ -6,7 +6,7 @@ use crate::{ ff::{PrimeField, Serializable}, helpers::query::IpaQueryConfig, ipa_test_input, - protocol::{ipa::ipa, BreakdownKey, MatchKey}, + protocol::{ipa::ipa, ipa_prf::OPRFIPAInputRow, BreakdownKey, MatchKey}, secret_sharing::{ replicated::{ malicious, malicious::ExtendableField, semi_honest, @@ -250,7 +250,6 @@ pub async fn test_oprf_ipa( use crate::{ ff::boolean_array::{BA20, BA3, BA4, BA5, BA6, BA7, BA8}, protocol::ipa_prf::oprf_ipa, - report::OprfReport, test_fixture::Runner, }; @@ -258,7 +257,7 @@ pub async fn test_oprf_ipa( let result: Vec<_> = world .semi_honest( records.into_iter(), - |ctx, input_rows: Vec>| async move { + |ctx, input_rows: Vec>| async move { match config.per_user_credit_cap { 8 => oprf_ipa::<_, BA8, BA3, BA20, BA3, F>(ctx, input_rows, aws)
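// Sketch (not part of the diff): why the test fixture can hand every party
// the same (b, b) pair for `is_trigger` -- a public bit needs no randomness,
// and XOR-reconstructing three identical left shares returns the bit itself
// (the helper-side `share_known_value` builds an equivalent public sharing
// role by role). Same simplified replicated-sharing model as the previous
// sketch.
fn share_known_bit(b: bool) -> [(bool, bool); 3] {
    [(b, b); 3]
}

fn reconstruct_bit(shares: &[(bool, bool); 3]) -> bool {
    shares[0].0 ^ shares[1].0 ^ shares[2].0
}

fn main() {
    for b in [false, true] {
        let shares = share_known_bit(b);
        assert_eq!(reconstruct_bit(&shares), b);
        // still satisfies the replicated invariant: right_i == left_{i+1}
        for i in 0..3 {
            assert_eq!(shares[i].1, shares[(i + 1) % 3].0);
        }
    }
}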