Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Real world test #822

Merged
merged 8 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 56 additions & 18 deletions src/bin/report_collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use hyper::http::uri::Scheme;
use ipa::{
cli::{
noise::{apply, ApplyDpArgs},
playbook::{make_clients, playbook_ipa, validate, InputSource},
playbook::{make_clients, playbook_ipa, playbook_oprf_ipa, validate, InputSource},
CsvSerializer, IpaQueryResult, Verbosity,
},
config::NetworkConfig,
Expand All @@ -26,7 +26,7 @@ use ipa::{
protocol::{BreakdownKey, MatchKey},
report::{KeyIdentifier, DEFAULT_KEY_ID},
test_fixture::{
ipa::{ipa_in_the_clear, CappingOrder, IpaSecurityModel, TestRawDataRecord},
ipa::{ipa_in_the_clear, CappingOrder, IpaQueryStyle, IpaSecurityModel, TestRawDataRecord},
EventGenerator, EventGeneratorConfig,
},
};
Expand Down Expand Up @@ -103,6 +103,8 @@ enum ReportCollectorCommand {
},
/// Apply differential privacy noise to IPA inputs
ApplyDpNoise(ApplyDpArgs),
/// Execute OPRF IPA in a semi-honest majority setting
OprfIpa(IpaQueryConfig),
}

#[derive(Debug, clap::Args)]
Expand Down Expand Up @@ -134,6 +136,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
IpaSecurityModel::SemiHonest,
config,
&clients,
IpaQueryStyle::SortInMpc,
)
.await?
}
Expand All @@ -144,6 +147,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
IpaSecurityModel::Malicious,
config,
&clients,
IpaQueryStyle::SortInMpc,
)
.await?
}
Expand All @@ -153,6 +157,17 @@ async fn main() -> Result<(), Box<dyn Error>> {
gen_args,
} => gen_inputs(count, seed, args.output_file, gen_args)?,
ReportCollectorCommand::ApplyDpNoise(ref dp_args) => apply_dp_noise(&args, dp_args)?,
ReportCollectorCommand::OprfIpa(config) => {
ipa(
&args,
&network,
IpaSecurityModel::SemiHonest,
config,
&clients,
IpaQueryStyle::Oprf,
)
.await?
}
};

Ok(())
Expand Down Expand Up @@ -221,16 +236,23 @@ async fn ipa(
security_model: IpaSecurityModel,
ipa_query_config: IpaQueryConfig,
helper_clients: &[MpcHelperClient; 3],
query_style: IpaQueryStyle,
) -> Result<(), Box<dyn Error>> {
let input = InputSource::from(&args.input);
let query_type: QueryType;
match security_model {
IpaSecurityModel::SemiHonest => {
match (security_model, &query_style) {
(IpaSecurityModel::SemiHonest, IpaQueryStyle::SortInMpc) => {
query_type = QueryType::SemiHonestIpa(ipa_query_config.clone());
}
IpaSecurityModel::Malicious => {
(IpaSecurityModel::Malicious, IpaQueryStyle::SortInMpc) => {
query_type = QueryType::MaliciousIpa(ipa_query_config.clone())
}
(IpaSecurityModel::SemiHonest, IpaQueryStyle::Oprf) => {
query_type = QueryType::OprfIpa(ipa_query_config.clone());
}
(IpaSecurityModel::Malicious, IpaQueryStyle::Oprf) => {
panic!("OPRF for malicious is not implemented as yet")
}
};

let input_rows = input.iter::<TestRawDataRecord>().collect::<Vec<_>>();
Expand All @@ -247,7 +269,10 @@ async fn ipa(
ipa_query_config.per_user_credit_cap,
ipa_query_config.attribution_window_seconds,
ipa_query_config.max_breakdown_key,
&CappingOrder::CapOldestFirst,
&(match query_style {
IpaQueryStyle::Oprf => CappingOrder::CapMostRecentFirst,
IpaQueryStyle::SortInMpc => CappingOrder::CapOldestFirst,
}),
);

// pad the output vector to the max breakdown key, to make sure it is aligned with the MPC results
Expand All @@ -260,18 +285,27 @@ async fn ipa(
};

let mut key_registries = KeyRegistries::default();
let actual = playbook_ipa::<Fp32BitPrime, MatchKey, BreakdownKey, _>(
&input_rows,
&helper_clients,
query_id,
ipa_query_config,
key_registries.init_from(network),
)
.await;

tracing::info!("{m:?}", m = ipa_query_config);

validate(&expected, &actual.breakdowns);
let actual = match query_style {
IpaQueryStyle::Oprf => {
playbook_oprf_ipa::<Fp32BitPrime>(
input_rows,
&helper_clients,
query_id,
ipa_query_config,
)
.await
}
IpaQueryStyle::SortInMpc => {
playbook_ipa::<Fp32BitPrime, MatchKey, BreakdownKey, _>(
&input_rows,
&helper_clients,
query_id,
ipa_query_config,
key_registries.init_from(network),
)
.await
}
};

if let Some(ref path) = args.output_file {
// it will be sad to lose the results if file already exists.
Expand Down Expand Up @@ -308,6 +342,10 @@ async fn ipa(
write!(file, "{}", serde_json::to_string_pretty(&actual)?)?;
}

tracing::info!("{m:?}", m = ipa_query_config);
benjaminsavage marked this conversation as resolved.
Show resolved Hide resolved

validate(&expected, &actual.breakdowns);

Ok(())
}

Expand Down
65 changes: 62 additions & 3 deletions src/cli/playbook/ipa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ use crate::{
hpke::PublicKeyRegistry,
ipa_test_input,
net::MpcHelperClient,
protocol::{ipa::IPAInputRow, BreakdownKey, MatchKey, QueryId},
protocol::{ipa::IPAInputRow, BreakdownKey, MatchKey, QueryId, Timestamp, TriggerValue},
query::QueryStatus,
report::{KeyIdentifier, Report},
report::{KeyIdentifier, OprfReport, Report},
secret_sharing::{replicated::semi_honest::AdditiveShare, IntoShares},
test_fixture::{input::GenericReportTestInput, ipa::TestRawDataRecord, Reconstruct},
};
Expand Down Expand Up @@ -99,6 +99,57 @@ where

let inputs = buffers.map(BodyStream::from);
tracing::info!("Starting query after finishing encryption");

do_processing::<F>(inputs, query_size, clients, query_id, query_config).await
}

/// Prepares the inputs for an OPRF IPA query and hands them to `do_processing`.
///
/// The records are ordered by `user_id`, secret-shared into three vectors of
/// `OprfReport<Timestamp, BreakdownKey, TriggerValue>`, and each vector is serialized
/// into its own byte buffer. The three buffers are then wrapped as `BodyStream`s and
/// passed, together with `clients`, `query_id` and `query_config`, to `do_processing`,
/// whose `IpaQueryResult` is returned unchanged.
///
/// * `records` - raw input records; taken by value because they are sorted in place.
/// * `clients` - the three helper clients, forwarded to `do_processing`.
/// * `query_id` - forwarded to `do_processing`.
/// * `query_config` - forwarded to `do_processing`.
pub async fn playbook_oprf_ipa<F>(
mut records: Vec<TestRawDataRecord>,
clients: &[MpcHelperClient; 3],
query_id: QueryId,
query_config: IpaQueryConfig,
) -> IpaQueryResult
where
F: PrimeField,
AdditiveShare<F>: Serializable,
{
// Three byte buffers, one for each of the three share vectors produced below.
let mut buffers: [_; 3] = std::array::from_fn(|_| Vec::new());
let query_size = records.len();

// Serialized size, in bytes, of a single `OprfReport` share.
let sz = <OprfReport<Timestamp, BreakdownKey, TriggerValue> as Serializable>::Size::USIZE;
// Pre-fill every buffer with zeroes so it can be split into `sz`-byte chunks, one per record.
for buffer in &mut buffers {
buffer.resize(query_size * sz, 0u8);
}

// TODO(richaj): This manual sorting will be removed once we have the PRF sharding in place.
// `sort_by` is a stable sort, so records with the same `user_id` keep their relative input
// order; the inputs are therefore expected to already be sorted by timestamp.
// The comparator orders records by descending `user_id`.
records.sort_by(|a, b| b.user_id.cmp(&a.user_id));

let shares: [Vec<OprfReport<Timestamp, BreakdownKey, TriggerValue>>; 3] =
records.iter().cloned().share();
// Serialize each share into its own `sz`-byte chunk of the matching buffer.
zip(&mut buffers, shares).for_each(|(buf, shares)| {
for (share, chunk) in zip(shares, buf.chunks_mut(sz)) {
share.serialize(GenericArray::from_mut_slice(chunk));
}
});
benjaminsavage marked this conversation as resolved.
Show resolved Hide resolved

// Wrap each serialized buffer as a `BodyStream` for `do_processing`.
let inputs = buffers.map(BodyStream::from);
tracing::info!("Starting query for OPRF");
benjaminsavage marked this conversation as resolved.
Show resolved Hide resolved

do_processing::<F>(inputs, query_size, clients, query_id, query_config).await
}

pub async fn do_processing<F>(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's select a better name =)

inputs: [BodyStream; 3],
query_size: usize,
clients: &[MpcHelperClient; 3],
query_id: QueryId,
query_config: IpaQueryConfig,
) -> IpaQueryResult
where
F: PrimeField,
AdditiveShare<F>: Serializable,
{
let mpc_time = Instant::now();
try_join_all(
inputs
Expand Down Expand Up @@ -143,12 +194,20 @@ where
.reconstruct();

let lat = mpc_time.elapsed();

tracing::info!("Running IPA for {query_size:?} records took {t:?}", t = lat);
let mut breakdowns = vec![0; usize::try_from(query_config.max_breakdown_key).unwrap()];
for (breakdown_key, trigger_value) in results.into_iter().enumerate() {
// TODO: make the data type used consistent with `ipa_in_the_clear`
// I think using u32 is wrong, we should move to u128
breakdowns[breakdown_key] += u32::try_from(trigger_value.as_u128()).unwrap();
assert!(
breakdown_key < query_config.max_breakdown_key.try_into().unwrap()
|| trigger_value == F::ZERO,
"trigger values were attributed to buckets more than max breakdown key"
);
if breakdown_key < query_config.max_breakdown_key.try_into().unwrap() {
breakdowns[breakdown_key] += u32::try_from(trigger_value.as_u128()).unwrap();
}
}

IpaQueryResult {
Expand Down
2 changes: 1 addition & 1 deletion src/cli/playbook/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub use input::InputSource;
pub use multiply::secure_mul;
use tokio::time::sleep;

pub use self::ipa::playbook_ipa;
pub use self::ipa::{playbook_ipa, playbook_oprf_ipa};
use crate::{
config::{ClientConfig, NetworkConfig, PeerConfig},
net::{ClientIdentity, MpcHelperClient},
Expand Down
3 changes: 3 additions & 0 deletions src/helpers/transport/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ pub enum QueryType {
MaliciousIpa(IpaQueryConfig),
SemiHonestSparseAggregate(SparseAggregateQueryConfig),
MaliciousSparseAggregate(SparseAggregateQueryConfig),
OprfIpa(IpaQueryConfig),
}

impl QueryType {
Expand All @@ -214,6 +215,7 @@ impl QueryType {
pub const MALICIOUS_IPA_STR: &'static str = "malicious-ipa";
pub const SEMIHONEST_AGGREGATE_STR: &'static str = "semihonest-sparse-aggregate";
pub const MALICIOUS_AGGREGATE_STR: &'static str = "malicious-sparse-aggregate";
pub const OPRF_IPA_STR: &'static str = "oprf_ipa";
}

/// TODO: should this `AsRef` impl (used for `Substep`) take into account config of IPA?
Expand All @@ -226,6 +228,7 @@ impl AsRef<str> for QueryType {
QueryType::MaliciousIpa(_) => Self::MALICIOUS_IPA_STR,
QueryType::SemiHonestSparseAggregate(_) => Self::SEMIHONEST_AGGREGATE_STR,
QueryType::MaliciousSparseAggregate(_) => Self::MALICIOUS_AGGREGATE_STR,
QueryType::OprfIpa(_) => Self::OPRF_IPA_STR,
}
}
}
Expand Down
8 changes: 7 additions & 1 deletion src/net/http_serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ pub mod query {
let Query(q) = req.extract().await?;
Ok(QueryType::MaliciousSparseAggregate(q))
}
QueryType::OPRF_IPA_STR => {
let Query(q) = req.extract().await?;
Ok(QueryType::OprfIpa(q))
}
other => Err(Error::bad_query_value("query_type", other)),
}?;
Ok(QueryConfigQueryParams(QueryConfig {
Expand All @@ -161,7 +165,9 @@ pub mod query {
match self.query_type {
#[cfg(any(test, feature = "test-fixture", feature = "cli"))]
QueryType::TestMultiply => Ok(()),
QueryType::SemiHonestIpa(config) | QueryType::MaliciousIpa(config) => {
QueryType::SemiHonestIpa(config)
| QueryType::MaliciousIpa(config)
| QueryType::OprfIpa(config) => {
write!(
f,
"&per_user_credit_cap={}&max_breakdown_key={}&num_multi_bits={}",
Expand Down
28 changes: 28 additions & 0 deletions src/query/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use rand_core::SeedableRng;
use shuttle::future as tokio;
use typenum::Unsigned;

use super::runner::OprfIpaQuery;
#[cfg(any(test, feature = "cli", feature = "test-fixture"))]
use crate::query::runner::execute_test_multiply;
use crate::{
Expand Down Expand Up @@ -202,6 +203,33 @@ pub fn execute(
},
)
}
(QueryType::OprfIpa(ipa_config), FieldType::Fp32BitPrime) => do_query(
config,
gateway,
input,
move |prss, gateway, config, input| {
let ctx = SemiHonestContext::new(prss, gateway);
Box::pin(
OprfIpaQuery::<_, Fp32BitPrime>::new(ipa_config)
.execute(ctx, config.size, input)
.then(|res| ready(res.map(|out| Box::new(out) as Box<dyn Result>))),
)
},
),
#[cfg(any(test, feature = "weak-field"))]
(QueryType::OprfIpa(ipa_config), FieldType::Fp31) => do_query(
config,
gateway,
input,
move |prss, gateway, config, input| {
let ctx = SemiHonestContext::new(prss, gateway);
Box::pin(
OprfIpaQuery::<_, Fp32BitPrime>::new(ipa_config)
.execute(ctx, config.size, input)
.then(|res| ready(res.map(|out| Box::new(out) as Box<dyn Result>))),
)
},
),
}
}

Expand Down
8 changes: 3 additions & 5 deletions src/query/runner/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use std::marker::PhantomData;

use futures_util::TryStreamExt;

use super::ipa::assert_stream_send;
use crate::{
error::Error,
ff::{Gf2, Gf8Bit, PrimeField, Serializable},
Expand Down Expand Up @@ -83,10 +82,9 @@ where

let input = {
//TODO: Replace `Gf8Bit` with an appropriate type specified by the config `contribution_bits`
let mut v = assert_stream_send(RecordsStream::<
SparseAggregateInputRow<Gf8Bit, BreakdownKey>,
_,
>::new(input_stream))
let mut v = RecordsStream::<SparseAggregateInputRow<Gf8Bit, BreakdownKey>, _>::new(
input_stream,
)
.try_concat()
.await?;
v.truncate(sz);
Expand Down
Loading