diff --git a/src/protocol/oprf/mod.rs b/src/protocol/oprf/mod.rs index 4d1f0f0398..74491318ea 100644 --- a/src/protocol/oprf/mod.rs +++ b/src/protocol/oprf/mod.rs @@ -175,7 +175,6 @@ pub async fn oprf_shuffle( input_rows: &[OPRFInputRow], _config: QueryConfig, ) -> Result, Error> { - let role = ctx.role(); let batch_size = u32::try_from(input_rows.len()).map_err(|_e| { Error::FieldValueTruncation(format!( "Cannot truncate the number of input rows {} to u32", @@ -186,64 +185,31 @@ pub async fn oprf_shuffle( let my_shares = split_shares_and_get_left(input_rows); let shared_with_rhs = split_shares_and_get_right(input_rows); - match role { - Role::H1 => run_h1(&ctx, &role, batch_size, my_shares, shared_with_rhs).await, - Role::H2 => run_h2(&ctx, &role, batch_size, my_shares, shared_with_rhs).await, - Role::H3 => run_h3(&ctx, &role, batch_size, my_shares, shared_with_rhs).await, + match ctx.role() { + Role::H1 => run_h1(&ctx, batch_size, my_shares, shared_with_rhs).await, + Role::H2 => run_h2(&ctx, batch_size, my_shares, shared_with_rhs).await, + Role::H3 => run_h3(&ctx, batch_size, my_shares, shared_with_rhs).await, } } -async fn run_h1( - ctx: &C, - role: &Role, - batch_size: u32, - my_shares: L, - rhs_shared: R, -) -> Result, Error> +async fn run_h1(ctx: &C, batch_size: u32, a: L, b: R) -> Result, Error> where C: Context, L: IntoIterator, R: IntoIterator, { - let a = my_shares; - let b = rhs_shared; - // // 1. Generate permutations - let pi_12 = generate_pseudorandom_permutation( - batch_size, - ctx, - &OPRFShuffleStep::GeneratePi12, - Direction::Right, - ); - let pi_31 = generate_pseudorandom_permutation( - batch_size, - ctx, - &OPRFShuffleStep::GeneratePi31, - Direction::Left, - ); - // - // 2. Generate random tables - let z_12 = generate_random_table_with_peer( - batch_size, - ctx, - &OPRFShuffleStep::GenerateZ12, - Direction::Right, - ); - let z_31 = generate_random_table_with_peer( - batch_size, - ctx, - &OPRFShuffleStep::GenerateZ31, - Direction::Left, - ); + let (pi_31, pi_12) = generate_permutations_with_peers(batch_size, ctx); - let a_hat = generate_random_table_with_peer( + // 2. Generate random tables + let (z_31, z_12) = generate_random_tables_with_peers(batch_size, ctx); + let a_hat = generate_random_table( batch_size, ctx, &OPRFShuffleStep::GenerateAHat, Direction::Left, ); - - let b_hat = generate_random_table_with_peer( + let b_hat = generate_random_table( batch_size, ctx, &OPRFShuffleStep::GenerateBHat, @@ -253,14 +219,13 @@ where // 3. Run computations let x_1_arg = add_single_shares(add_single_shares(a, b), z_12); let x_1 = permute(&pi_12, x_1_arg); - let x_2_arg = add_single_shares(x_1.iter(), z_31.iter()); let x_2 = permute(&pi_31, x_2_arg); send_to_peer( ctx, &OPRFShuffleStep::TransferX2, - role.peer(Direction::Right), + Direction::Right, x_2.clone(), ) .await?; @@ -269,72 +234,32 @@ where Ok(res) } -async fn run_h2( - ctx: &C, - role: &Role, - batch_size: u32, - _my_shares: L, - shared_with_rhs: R, -) -> Result, Error> +async fn run_h2(ctx: &C, batch_size: u32, _b: L, c: R) -> Result, Error> where C: Context, L: IntoIterator, R: IntoIterator, { - let c = shared_with_rhs; - // 1. Generate permutations - let pi_12 = generate_pseudorandom_permutation( - batch_size, - ctx, - &OPRFShuffleStep::GeneratePi12, - Direction::Left, - ); - - let pi_23 = generate_pseudorandom_permutation( - batch_size, - ctx, - &OPRFShuffleStep::GeneratePi23, - Direction::Right, - ); - + let (pi_12, pi_23) = generate_permutations_with_peers(batch_size, ctx); // 2. Generate random tables - let z_12 = generate_random_table_with_peer( - batch_size, - ctx, - &OPRFShuffleStep::GenerateZ12, - Direction::Left, - ); - - let z_23 = generate_random_table_with_peer( - batch_size, - ctx, - &OPRFShuffleStep::GenerateZ23, - Direction::Right, - ); - - let b_hat = generate_random_table_with_peer( + let (z_12, z_23) = generate_random_tables_with_peers(batch_size, ctx); + let b_hat = generate_random_table( batch_size, ctx, &OPRFShuffleStep::GenerateBHat, Direction::Left, ); - // // 3. Run computations let y_1_arg = add_single_shares(c, z_12.into_iter()); let y_1 = permute(&pi_12, y_1_arg); let ((), x_2) = try_join!( - send_to_peer( - ctx, - &OPRFShuffleStep::TransferY1, - role.peer(Direction::Right), - y_1, - ), + send_to_peer(ctx, &OPRFShuffleStep::TransferY1, Direction::Right, y_1), receive_from_peer( ctx, &OPRFShuffleStep::TransferX2, - role.peer(Direction::Left), + Direction::Left, batch_size, ), )?; @@ -342,87 +267,33 @@ where let x_3_arg = add_single_shares(x_2.into_iter(), z_23.into_iter()); let x_3 = permute(&pi_23, x_3_arg); let c_hat_1 = add_single_shares(x_3.iter(), b_hat.iter()).collect::>(); - - let ((), c_hat_2) = try_join!( - send_to_peer( - ctx, - &OPRFShuffleStep::TransferCHat1, - role.peer(Direction::Right), - c_hat_1.clone(), - ), - receive_from_peer( - ctx, - &OPRFShuffleStep::TransferCHat2, - role.peer(Direction::Right), - batch_size, - ) - )?; - + let c_hat_2 = exchange_c_hat(ctx, batch_size, c_hat_1.clone()).await?; let c_hat = add_single_shares(c_hat_1.iter(), c_hat_2.iter()); let res = combine_shares(b_hat, c_hat); Ok(res) } -async fn run_h3( - ctx: &C, - role: &Role, - batch_size: u32, - _my_shares: L, - _shared_with_rhs: R, -) -> Result, Error> +async fn run_h3(ctx: &C, batch_size: u32, _c: L, _a: R) -> Result, Error> where C: Context, L: IntoIterator, R: IntoIterator, { - // H3 does not need any secret shares. - // Its "C" shares are processed by helper2, Its "A" shares are processed by helper 1 - /* - let c = my_shares; - let a = rhs_shared; - */ - // 1. Generate permutations - let pi_23 = generate_pseudorandom_permutation( - batch_size, - ctx, - &OPRFShuffleStep::GeneratePi23, - Direction::Left, - ); - let pi_31 = generate_pseudorandom_permutation( - batch_size, - ctx, - &OPRFShuffleStep::GeneratePi31, - Direction::Right, - ); - + let (pi_23, pi_31) = generate_permutations_with_peers(batch_size, ctx); // 2. Generate random tables - let z_23 = generate_random_table_with_peer( - batch_size, - ctx, - &OPRFShuffleStep::GenerateZ23, - Direction::Left, - ); - - let z_31 = generate_random_table_with_peer( - batch_size, - ctx, - &OPRFShuffleStep::GenerateZ31, - Direction::Right, - ); - - let a_hat = generate_random_table_with_peer( + let (z_23, z_31) = generate_random_tables_with_peers(batch_size, ctx); + let a_hat = generate_random_table( batch_size, ctx, &OPRFShuffleStep::GenerateAHat, Direction::Right, ); - // 3. Run computations let y_1 = receive_from_peer( ctx, &OPRFShuffleStep::TransferY1, - role.peer(Direction::Left), + Direction::Left, batch_size, ) .await?; @@ -432,22 +303,7 @@ where let y_3_arg = add_single_shares(y_2, z_23); let y_3 = permute(&pi_23, y_3_arg); let c_hat_2 = add_single_shares(y_3, a_hat.clone()).collect::>(); - - let (c_hat_1, ()) = try_join!( - receive_from_peer( - ctx, - &OPRFShuffleStep::TransferCHat1, - role.peer(Direction::Left), - batch_size, - ), - send_to_peer( - ctx, - &OPRFShuffleStep::TransferCHat2, - role.peer(Direction::Left), - c_hat_2.clone(), - ) - )?; - + let c_hat_1 = exchange_c_hat(ctx, batch_size, c_hat_2.clone()).await?; let c_hat = add_single_shares(c_hat_1, c_hat_2); let res = combine_shares(c_hat, a_hat); Ok(res) @@ -493,7 +349,22 @@ where l.into_iter().zip(r).map(|(a, b)| a + b) } -fn generate_random_table_with_peer( +fn generate_random_tables_with_peers( + batch_size: u32, + ctx: &C, +) -> (Vec, Vec) { + let (step_left, step_right) = match ctx.role() { + Role::H1 => (OPRFShuffleStep::GenerateZ31, OPRFShuffleStep::GenerateZ12), + Role::H2 => (OPRFShuffleStep::GenerateZ12, OPRFShuffleStep::GenerateZ23), + Role::H3 => (OPRFShuffleStep::GenerateZ23, OPRFShuffleStep::GenerateZ12), + }; + + let with_left = generate_random_table(batch_size, ctx, &step_left, Direction::Left); + let with_right = generate_random_table(batch_size, ctx, &step_right, Direction::Right); + (with_left, with_right) +} + +fn generate_random_table( batch_size: u32, ctx: &C, step: &OPRFShuffleStep, @@ -526,9 +397,10 @@ where async fn send_to_peer>( ctx: &C, step: &OPRFShuffleStep, - role: Role, + direction: Direction, items: I, ) -> Result<(), Error> { + let role = ctx.role().peer(direction); let send_channel = ctx.narrow(step).send_channel(role); for (record_id, row) in items.into_iter().enumerate() { send_channel.send(RecordId::from(record_id), row).await?; @@ -539,9 +411,10 @@ async fn send_to_peer async fn receive_from_peer( ctx: &C, step: &OPRFShuffleStep, - role: Role, + direction: Direction, batch_size: u32, ) -> Result, Error> { + let role = ctx.role().peer(direction); let receive_channel: ReceivingEnd = ctx.narrow(step).recv_channel(role); let mut output: Vec = Vec::with_capacity(batch_size as usize); @@ -553,8 +426,48 @@ async fn receive_from_peer( Ok(output) } +async fn exchange_c_hat>( + ctx: &C, + batch_size: u32, + part_to_send: I, +) -> Result, Error> { + let (step_send, step_recv, dir) = match ctx.role() { + Role::H2 => ( + OPRFShuffleStep::TransferCHat1, + OPRFShuffleStep::TransferCHat2, + Direction::Right, + ), + Role::H3 => ( + OPRFShuffleStep::TransferCHat2, + OPRFShuffleStep::TransferCHat1, + Direction::Left, + ), + role => unreachable!("Role {:?} does not participate in C_hat computation", role), + }; + + let ((), received_part) = try_join!( + send_to_peer(ctx, &step_send, dir, part_to_send), + receive_from_peer(ctx, &step_recv, dir, batch_size), + )?; + + Ok(received_part) +} + // --------------------------- permutation-related function --------------------------------------------- // +fn generate_permutations_with_peers(batch_size: u32, ctx: &C) -> (Vec, Vec) { + let (step_left, step_right) = match &ctx.role() { + Role::H1 => (OPRFShuffleStep::GeneratePi31, OPRFShuffleStep::GeneratePi12), + Role::H2 => (OPRFShuffleStep::GeneratePi12, OPRFShuffleStep::GeneratePi23), + Role::H3 => (OPRFShuffleStep::GeneratePi23, OPRFShuffleStep::GeneratePi12), + }; + + let with_left = generate_pseudorandom_permutation(batch_size, ctx, &step_left, Direction::Left); + let with_right = + generate_pseudorandom_permutation(batch_size, ctx, &step_right, Direction::Right); + (with_left, with_right) +} + fn generate_pseudorandom_permutation( batch_size: u32, ctx: &C,