Skip to content

Commit

Permalink
Merge pull request #814 from richajaindce/oneshot
Browse files Browse the repository at this point in the history
Oneshot for OPRF + enabling compact-gate for it
  • Loading branch information
benjaminsavage authored Oct 25, 2023
2 parents 090d94f + 51e05fe commit b396557
Show file tree
Hide file tree
Showing 12 changed files with 887 additions and 33 deletions.
30 changes: 21 additions & 9 deletions benches/oneshot/ipa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use ipa::{
ff::Fp32BitPrime,
helpers::{query::IpaQueryConfig, GatewayConfig},
test_fixture::{
ipa::{ipa_in_the_clear, test_ipa, IpaSecurityModel},
ipa::{ipa_in_the_clear, test_ipa, test_oprf_ipa, CappingOrder, IpaSecurityModel},
EventGenerator, EventGeneratorConfig, TestWorld, TestWorldConfig,
},
};
Expand Down Expand Up @@ -70,6 +70,8 @@ struct Args {
/// Needed for benches.
#[arg(long, hide = true)]
bench: bool,
#[arg(short = 'o', long)]
oprf: bool,
}

impl Args {
Expand Down Expand Up @@ -121,25 +123,35 @@ async fn run(args: Args) -> Result<(), Error> {
.take(args.query_size)
.collect::<Vec<_>>();

let order = if args.oprf {
CappingOrder::CapMostRecentFirst
} else {
CappingOrder::CapOldestFirst
};
let expected_results = ipa_in_the_clear(
&raw_data,
args.per_user_cap,
args.attribution_window(),
args.breakdown_keys,
&order,
);

let world = TestWorld::new_with(config.clone());
tracing::trace!("Preparation complete in {:?}", _prep_time.elapsed());

let _protocol_time = Instant::now();
test_ipa::<BenchField>(
&world,
&raw_data,
&expected_results,
args.config(),
args.mode,
)
.await;
if args.oprf {
test_oprf_ipa::<BenchField>(&world, raw_data, &expected_results, args.config()).await;
} else {
test_ipa::<BenchField>(
&world,
&raw_data,
&expected_results,
args.config(),
args.mode,
)
.await;
}
tracing::trace!(
"{m:?} IPA for {q} records took {t:?}",
m = args.mode,
Expand Down
3 changes: 3 additions & 0 deletions pre-commit
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ check "Concurrency tests" \
check "IPA benchmark" \
cargo bench --bench oneshot_ipa --no-default-features --features="enable-benches descriptive-gate" -- -n 62

check "IPA OPRF benchmark" \
cargo bench --bench oneshot_ipa --no-default-features --features="enable-benches descriptive-gate" -- -n 62 --oprf

check "Arithmetic circuit benchmark" \
cargo bench --bench oneshot_arithmetic --no-default-features --features "enable-benches descriptive-gate"

Expand Down
38 changes: 35 additions & 3 deletions scripts/collect_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,7 @@ def extract_intermediate_steps(steps):

return steps


if __name__ == "__main__":
steps = set()
def ipa_steps(steps):
for c in PER_USER_CAP:
for w in ATTRIBUTION_WINDOW:
for b in BREAKDOWN_KEYS:
Expand All @@ -171,6 +169,40 @@ def extract_intermediate_steps(steps):
print(" ".join(args), file=sys.stderr)
steps.update(collect_steps(args))

OPRF_BREAKDOWN_KEYS = [256]
OPRF_USER_CAP = [16, 64, 128]
OPRF_SECURITY_MODEL = ["semi-honest"]
OPRF_TRIGGER_VALUE = [7]

def oprf_steps(steps):
for c in OPRF_USER_CAP:
for w in ATTRIBUTION_WINDOW:
for b in OPRF_BREAKDOWN_KEYS:
for m in OPRF_SECURITY_MODEL:
for tv in OPRF_TRIGGER_VALUE:
args = ARGS + [
"-n",
str(QUERY_SIZE),
"-c",
str(c),
"-w",
str(w),
"-b",
str(b),
"-m",
m,
"-t",
str(tv),
"-o"
]
print(" ".join(args), file=sys.stderr)
steps.update(collect_steps(args))

if __name__ == "__main__":
steps = set()
ipa_steps(steps)
oprf_steps(steps)

full_steps = extract_intermediate_steps(steps)
sorted_steps = sorted(full_steps)

Expand Down
3 changes: 2 additions & 1 deletion src/bin/report_collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use ipa::{
protocol::{BreakdownKey, MatchKey},
report::{KeyIdentifier, DEFAULT_KEY_ID},
test_fixture::{
ipa::{ipa_in_the_clear, IpaSecurityModel, TestRawDataRecord},
ipa::{ipa_in_the_clear, CappingOrder, IpaSecurityModel, TestRawDataRecord},
EventGenerator, EventGeneratorConfig,
},
};
Expand Down Expand Up @@ -247,6 +247,7 @@ async fn ipa(
ipa_query_config.per_user_credit_cap,
ipa_query_config.attribution_window_seconds,
ipa_query_config.max_breakdown_key,
&CappingOrder::CapOldestFirst,
);

// pad the output vector to the max breakdown key, to make sure it is aligned with the MPC results
Expand Down
3 changes: 2 additions & 1 deletion src/protocol/ipa/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ pub mod tests {
test_executor::{run, run_with},
test_fixture::{
input::GenericReportTestInput,
ipa::{ipa_in_the_clear, test_ipa, IpaSecurityModel},
ipa::{ipa_in_the_clear, test_ipa, CappingOrder, IpaSecurityModel},
logging, EventGenerator, EventGeneratorConfig, Reconstruct, Runner, TestWorld,
TestWorldConfig,
},
Expand Down Expand Up @@ -815,6 +815,7 @@ pub mod tests {
per_user_cap,
ATTRIBUTION_WINDOW_SECONDS,
MAX_BREAKDOWN_KEY,
&CappingOrder::CapOldestFirst,
);

let config = TestWorldConfig {
Expand Down
5 changes: 3 additions & 2 deletions src/protocol/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ pub mod context;
pub mod dp;
pub mod ipa;
pub mod modulus_conversion;
#[cfg(feature = "descriptive-gate")]
pub mod prf_sharding;
pub mod prss;
pub mod sort;
Expand All @@ -22,11 +21,13 @@ pub use basics::BasicProtocols;

use crate::{
error::Error,
ff::{Gf40Bit, Gf8Bit},
ff::{Gf20Bit, Gf3Bit, Gf40Bit, Gf8Bit},
};

pub type MatchKey = Gf40Bit;
pub type BreakdownKey = Gf8Bit;
pub type TriggerValue = Gf3Bit;
pub type Timestamp = Gf20Bit;

/// Unique identifier of the MPC query requested by report collectors
/// TODO(615): Generating this unique id may be tricky as it may involve communication between helpers and
Expand Down
10 changes: 5 additions & 5 deletions src/protocol/prf_sharding/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ pub mod bucket;
pub mod feature_label_dot_product;

pub struct PrfShardedIpaInputRow<BK: GaloisField, TV: GaloisField, TS: GaloisField> {
prf_of_match_key: u64,
is_trigger_bit: Replicated<Gf2>,
breakdown_key: Replicated<BK>,
trigger_value: Replicated<TV>,
timestamp: Replicated<TS>,
pub prf_of_match_key: u64,
pub is_trigger_bit: Replicated<Gf2>,
pub breakdown_key: Replicated<BK>,
pub trigger_value: Replicated<TV>,
pub timestamp: Replicated<TS>,
}

impl<BK: GaloisField, TV: GaloisField, TS: GaloisField> PrfShardedIpaInputRow<BK, TV, TS> {
Expand Down
Loading

0 comments on commit b396557

Please sign in to comment.