Skip to content

Commit

Permalink
PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
andyleiserson committed Mar 11, 2024
1 parent fa59f74 commit 7f48867
Showing 1 changed file with 44 additions and 36 deletions.
80 changes: 44 additions & 36 deletions ipa-core/src/protocol/basics/reveal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@ use futures::TryFutureExt;

use crate::{
error::Error,
helpers::{Direction, Role},
helpers::{Direction, MaybeFuture, Role},
protocol::{context::Context, RecordId},
secret_sharing::{
replicated::semi_honest::AdditiveShare as Replicated, SharedValue, Vectorizable,
},
seq_join::SeqJoin,
};
#[cfg(feature = "descriptive-gate")]
use crate::{
Expand All @@ -36,34 +35,34 @@ pub trait Reveal<C: Context, const N: usize>: Sized {
where
C: 'fut,
{
// Passing `left_out = None` guarantees any ok result is `Some`.
// Passing `excluded = None` guarantees any ok result is `Some`.
self.generic_reveal(ctx, record_id, None)
.map_ok(Option::unwrap)
}

/// partial reveal protocol to open a shared secret to all helpers except helper `left_out` inside the MPC ring.
/// Partial reveal protocol to open a shared secret to all helpers except helper `excluded` inside the MPC ring.
fn partial_reveal<'fut>(
&'fut self,
ctx: C,
record_id: RecordId,
left_out: Role,
excluded: Role,
) -> impl Future<Output = Result<Option<Self::Output>, Error>> + Send + 'fut
where
C: 'fut,
{
self.generic_reveal(ctx, record_id, Some(left_out))
self.generic_reveal(ctx, record_id, Some(excluded))
}

/// Generic reveal implementation usable for both `reveal` and `partial_reveal`.
///
/// When `left_out` is `None`, open a shared secret to all helpers in the MPC ring.
/// When `left_out` is `Some`, open a shared secret to all helpers except the helper
/// specified in `left_out`.
/// When `excluded` is `None`, open a shared secret to all helpers in the MPC ring.
/// When `excluded` is `Some`, open a shared secret to all helpers except the helper
/// specified in `excluded`.
fn generic_reveal<'fut>(
&'fut self,
ctx: C,
record_id: RecordId,
left_out: Option<Role>,
excluded: Option<Role>,
) -> impl Future<Output = Result<Option<Self::Output>, Error>> + Send + 'fut
where
C: 'fut;
Expand All @@ -90,22 +89,22 @@ impl<C: Context, V: SharedValue + Vectorizable<N>, const N: usize> Reveal<C, N>
&'fut self,
ctx: C,
record_id: RecordId,
left_out: Option<Role>,
excluded: Option<Role>,
) -> Result<Option<<V as Vectorizable<N>>::Array>, Error>
where
C: 'fut,
{
let left = self.left_arr();
let right = self.right_arr();

// send except to excluded helper (if any)
if Some(ctx.role().peer(Direction::Right)) != left_out {
// Send shares, unless the target helper is excluded
if Some(ctx.role().peer(Direction::Right)) != excluded {
ctx.send_channel::<<V as Vectorizable<N>>::Array>(ctx.role().peer(Direction::Right))
.send(record_id, left)
.await?;
}

if Some(ctx.role()) == left_out {
if Some(ctx.role()) == excluded {
Ok(None)
} else {
// Sleep until `helper's left` sends their share
Expand All @@ -131,7 +130,7 @@ impl<'a, F: ExtendableField> Reveal<UpgradedMaliciousContext<'a, F>, 1> for Mali
&'fut self,
ctx: UpgradedMaliciousContext<'a, F>,
record_id: RecordId,
left_out: Option<Role>,
excluded: Option<Role>,
) -> Result<Option<<F as Vectorizable<1>>::Array>, Error>
where
UpgradedMaliciousContext<'a, F>: 'fut,
Expand All @@ -146,16 +145,19 @@ impl<'a, F: ExtendableField> Reveal<UpgradedMaliciousContext<'a, F>, 1> for Mali
let right_sender = ctx.send_channel(ctx.role().peer(Direction::Right));
let right_receiver = ctx.recv_channel::<F>(ctx.role().peer(Direction::Right));

// Send share to helpers to the right and left
// send except to left_out
let send_left_fut = (Some(ctx.role().peer(Direction::Left)) != left_out)
.then(|| left_sender.send(record_id, right));
let send_right_fut = (Some(ctx.role().peer(Direction::Right)) != left_out)
.then(|| right_sender.send(record_id, left));
ctx.parallel_join(send_left_fut.into_iter().chain(send_right_fut))
.await?;
// Send shares to the left and right helpers, unless excluded.
let send_left_fut =
MaybeFuture::future_or_ok(Some(ctx.role().peer(Direction::Left)) != excluded, || {
left_sender.send(record_id, right)
});

let send_right_fut =
MaybeFuture::future_or_ok(Some(ctx.role().peer(Direction::Right)) != excluded, || {
right_sender.send(record_id, left)
});
try_join(send_left_fut, send_right_fut).await?;

if Some(ctx.role()) == left_out {
if Some(ctx.role()) == excluded {
Ok(None)
} else {
let (share_from_left, share_from_right) = try_join(
Expand Down Expand Up @@ -249,20 +251,20 @@ mod tests {
let mut rng = thread_rng();
let world = TestWorld::default();

for &left_out in Role::all() {
for &excluded in Role::all() {
let input = rng.gen::<TestField>();
let results = world
.semi_honest(input, |ctx, share| async move {
share
.partial_reveal(ctx.set_total_records(1), RecordId::from(0), left_out)
.partial_reveal(ctx.set_total_records(1), RecordId::from(0), excluded)
.await
.unwrap()
.map(|revealed| TestField::from_array(&revealed))
})
.await;

for &helper in Role::all() {
if helper == left_out {
if helper == excluded {
assert_eq!(None, results[helper]);
} else {
assert_eq!(Some(input), results[helper]);
Expand Down Expand Up @@ -342,7 +344,7 @@ mod tests {
let mut rng = thread_rng();
let world = TestWorld::default();

for &left_out in Role::all() {
for &excluded in Role::all() {
let sh_ctx = world.malicious_contexts();
let v = sh_ctx.map(UpgradableContext::validator);
let m_ctx: [_; 3] = v
Expand All @@ -364,15 +366,15 @@ mod tests {
let results = join_all(zip(m_ctx.clone().into_iter(), m_shares).map(
|(m_ctx, m_share)| async move {
m_share
.partial_reveal(m_ctx, record_id, left_out)
.partial_reveal(m_ctx, record_id, excluded)
.await
.unwrap()
},
))
.await;

for &helper in Role::all() {
if helper == left_out {
if helper == excluded {
assert_eq!(None, results[helper]);
} else {
assert_eq!(Some(input.into_array()), results[helper]);
Expand Down Expand Up @@ -419,11 +421,11 @@ mod tests {
.await;

assert!(matches!(result, Err(Error::MaliciousRevealFailed)));
})
});
}

#[test]
pub fn malicious_partial_validation_fail() {
pub fn malicious_partial_validation_fail() {
run(|| async {
let mut rng = thread_rng();
let world = TestWorld::default();
Expand All @@ -447,19 +449,25 @@ mod tests {
let result = try_join3(
m_shares[0].partial_reveal(m_ctx[0].clone(), record_id, Role::H3),
m_shares[1].partial_reveal(m_ctx[1].clone(), record_id, Role::H3),
reveal_with_additive_attack(m_ctx[2].clone(), record_id, &m_shares[2], true, Fp31::ONE),
reveal_with_additive_attack(
m_ctx[2].clone(),
record_id,
&m_shares[2],
true,
Fp31::ONE,
),
)
.await;

assert!(matches!(result, Err(Error::MaliciousRevealFailed)));
})
});
}

pub async fn reveal_with_additive_attack<F: ExtendableField>(
ctx: UpgradedMaliciousContext<'_, F>,
record_id: RecordId,
input: &MaliciousReplicated<F>,
left_out: bool,
excluded: bool,
additive_error: F,
) -> Result<Option<F>, Error> {
let (left, right) = input.x().access_without_downgrade().as_tuple();
Expand All @@ -475,7 +483,7 @@ mod tests {
)
.await?;

if left_out {
if excluded {
Ok(None)
} else {
let (share_from_left, _share_from_right): (F, F) =
Expand Down

0 comments on commit 7f48867

Please sign in to comment.