From 9292cc2209f15e141dbf8331c6912ce551b4d3f8 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Tue, 30 Jul 2024 16:55:01 -0700 Subject: [PATCH 1/2] Implement reveal for upgraded contexts only --- ipa-core/src/protocol/basics/check_zero.rs | 11 +- ipa-core/src/protocol/basics/mod.rs | 10 +- ipa-core/src/protocol/basics/reveal.rs | 225 +++++++++++++++------ ipa-core/src/protocol/context/mod.rs | 6 +- ipa-core/src/protocol/context/validator.rs | 11 +- ipa-core/src/protocol/ipa_prf/prf_eval.rs | 12 +- 6 files changed, 192 insertions(+), 83 deletions(-) diff --git a/ipa-core/src/protocol/basics/check_zero.rs b/ipa-core/src/protocol/basics/check_zero.rs index a6686acfd..c528cf363 100644 --- a/ipa-core/src/protocol/basics/check_zero.rs +++ b/ipa-core/src/protocol/basics/check_zero.rs @@ -2,7 +2,7 @@ use crate::{ error::Error, ff::Field, protocol::{ - basics::{mul::semi_honest_multiply, reveal::Reveal, step::CheckZeroStep as Step}, + basics::{malicious_reveal, mul::semi_honest_multiply, step::CheckZeroStep as Step}, context::Context, prss::{FromRandom, SharedRandomness}, RecordId, @@ -39,6 +39,9 @@ use crate::{ /// ## Errors /// Lots of things may go wrong here, from timeouts to bad output. They will be signalled /// back via the error response +/// ## Panics +/// If the full reveal of `rv_share` does not return a value, which would only happen if the +/// reveal implementation is broken. pub async fn check_zero(ctx: C, record_id: RecordId, v: &Replicated) -> Result where C: Context, @@ -49,9 +52,9 @@ where let rv_share = semi_honest_multiply(ctx.narrow(&Step::MultiplyWithR), record_id, &r_sharing, v).await?; let rv = F::from_array( - &rv_share - .reveal(ctx.narrow(&Step::RevealR), record_id) - .await?, + &malicious_reveal(ctx.narrow(&Step::RevealR), record_id, None, &rv_share) + .await? + .expect("full reveal should always return a value"), ); Ok(rv == F::ZERO) diff --git a/ipa-core/src/protocol/basics/mod.rs b/ipa-core/src/protocol/basics/mod.rs index a2cfff1fc..df3390482 100644 --- a/ipa-core/src/protocol/basics/mod.rs +++ b/ipa-core/src/protocol/basics/mod.rs @@ -13,7 +13,7 @@ pub use check_zero::check_zero; pub use if_else::select; pub use mul::{BooleanArrayMul, SecureMul}; pub use reshare::Reshare; -pub use reveal::{partial_reveal, reveal, Reveal}; +pub use reveal::{malicious_reveal, partial_reveal, reveal, semi_honest_reveal, Reveal}; pub use share_known_value::ShareKnownValue; use crate::{ @@ -22,7 +22,7 @@ use crate::{ protocol::{ context::{ Context, DZKPUpgradedMaliciousContext, DZKPUpgradedSemiHonestContext, - SemiHonestContext, UpgradedSemiHonestContext, + UpgradedSemiHonestContext, }, ipa_prf::{AGG_CHUNK, PRF_CHUNK}, prss::FromPrss, @@ -85,12 +85,6 @@ impl<'a, B: ShardBinding> BooleanProtocols BooleanProtocols, AGG_CHUNK> - for AdditiveShare -{ -} - impl<'a, B: ShardBinding> BooleanProtocols, AGG_CHUNK> for AdditiveShare { diff --git a/ipa-core/src/protocol/basics/reveal.rs b/ipa-core/src/protocol/basics/reveal.rs index 1d59df603..d52f74bed 100644 --- a/ipa-core/src/protocol/basics/reveal.rs +++ b/ipa-core/src/protocol/basics/reveal.rs @@ -5,9 +5,13 @@ use futures::TryFutureExt; use crate::{ error::Error, + ff::boolean::Boolean, helpers::{Direction, MaybeFuture, Role}, protocol::{ - context::{Context, UpgradedMaliciousContext}, + context::{ + Context, DZKPUpgradedMaliciousContext, DZKPUpgradedSemiHonestContext, + UpgradedMaliciousContext, UpgradedSemiHonestContext, + }, RecordId, }, secret_sharing::{ @@ -17,6 +21,7 @@ use crate::{ }, SharedValue, Vectorizable, }, + sharding::ShardBinding, }; /// Trait for reveal protocol to open a shared secret to all helpers inside the MPC ring. @@ -80,41 +85,77 @@ pub trait Reveal: Sized { /// Each helper sends their left share to the right helper. The helper then reconstructs their secret by adding the three shares /// i.e. their own shares and received share. #[embed_doc_image("reveal", "images/reveal.png")] -impl, const N: usize> Reveal - for Replicated +pub async fn semi_honest_reveal<'fut, C, V, const N: usize>( + ctx: C, + record_id: RecordId, + excluded: Option, + share: &'fut Replicated, +) -> Result>::Array>, Error> +where + C: Context + 'fut, + V: SharedValue + Vectorizable, +{ + let left = share.left_arr(); + let right = share.right_arr(); + + // Send shares, unless the target helper is excluded + if Some(ctx.role().peer(Direction::Right)) != excluded { + ctx.send_channel::<>::Array>(ctx.role().peer(Direction::Right)) + .send(record_id, left) + .await?; + } + + if Some(ctx.role()) == excluded { + Ok(None) + } else { + // Sleep until `helper's left` sends their share + let share: >::Array = ctx + .recv_channel(ctx.role().peer(Direction::Left)) + .receive(record_id) + .await?; + + Ok(Some(share + left + right)) + } +} + +impl<'a, B, V, const N: usize> Reveal, N> for Replicated +where + B: ShardBinding, + V: SharedValue + Vectorizable + ExtendableField, { type Output = >::Array; async fn generic_reveal<'fut>( &'fut self, - ctx: C, + ctx: UpgradedSemiHonestContext<'a, B, V>, record_id: RecordId, excluded: Option, - ) -> Result>::Array>, Error> + ) -> Result, Error> where - C: 'fut, + UpgradedSemiHonestContext<'a, B, V>: 'fut, { - let left = self.left_arr(); - let right = self.right_arr(); - - // Send shares, unless the target helper is excluded - if Some(ctx.role().peer(Direction::Right)) != excluded { - ctx.send_channel::<>::Array>(ctx.role().peer(Direction::Right)) - .send(record_id, left) - .await?; - } + semi_honest_reveal(ctx, record_id, excluded, self).await + } +} - if Some(ctx.role()) == excluded { - Ok(None) - } else { - // Sleep until `helper's left` sends their share - let share: >::Array = ctx - .recv_channel(ctx.role().peer(Direction::Left)) - .receive(record_id) - .await?; +impl<'a, B, const N: usize> Reveal, N> + for Replicated +where + B: ShardBinding, + Boolean: Vectorizable, +{ + type Output = >::Array; - Ok(Some(share + left + right)) - } + async fn generic_reveal<'fut>( + &'fut self, + ctx: DZKPUpgradedSemiHonestContext<'a, B>, + record_id: RecordId, + excluded: Option, + ) -> Result, Error> + where + DZKPUpgradedSemiHonestContext<'a, B>: 'fut, + { + semi_honest_reveal(ctx, record_id, excluded, self).await } } @@ -122,7 +163,62 @@ impl, const N: usize> Reveal /// It works similarly to semi-honest reveal, the key difference is that each helper sends its share /// to both helpers (right and left) and upon receiving 2 shares from peers it validates that they /// indeed match. -impl<'a, F: ExtendableField> Reveal, 1> for MaliciousReplicated { +pub async fn malicious_reveal<'fut, C, V, const N: usize>( + ctx: C, + record_id: RecordId, + excluded: Option, + share: &'fut Replicated, +) -> Result>::Array>, Error> +where + C: Context + 'fut, + V: SharedValue + Vectorizable, +{ + use futures::future::try_join; + + let left = share.left_arr(); + let right = share.right_arr(); + let left_sender = + ctx.send_channel::<>::Array>(ctx.role().peer(Direction::Left)); + let left_receiver = + ctx.recv_channel::<>::Array>(ctx.role().peer(Direction::Left)); + let right_sender = + ctx.send_channel::<>::Array>(ctx.role().peer(Direction::Right)); + let right_receiver = + ctx.recv_channel::<>::Array>(ctx.role().peer(Direction::Right)); + + // Send shares to the left and right helpers, unless excluded. + let send_left_fut = + MaybeFuture::future_or_ok(Some(ctx.role().peer(Direction::Left)) != excluded, || { + left_sender.send(record_id, right) + }); + + let send_right_fut = + MaybeFuture::future_or_ok(Some(ctx.role().peer(Direction::Right)) != excluded, || { + right_sender.send(record_id, left) + }); + try_join(send_left_fut, send_right_fut).await?; + + if Some(ctx.role()) == excluded { + Ok(None) + } else { + let (share_from_left, share_from_right) = try_join( + left_receiver.receive(record_id), + right_receiver.receive(record_id), + ) + .await?; + + if share_from_left == share_from_right { + Ok(Some(share_from_left + left + right)) + } else { + Err(Error::MaliciousRevealFailed) + } + } +} + +impl<'a, F> Reveal, 1> for Replicated +where + F: ExtendableField, +{ type Output = >::Array; async fn generic_reveal<'fut>( @@ -134,43 +230,48 @@ impl<'a, F: ExtendableField> Reveal, 1> for Mali where UpgradedMaliciousContext<'a, F>: 'fut, { - use futures::future::try_join; + malicious_reveal(ctx, record_id, excluded, self).await + } +} +impl<'a, F> Reveal, 1> for MaliciousReplicated +where + F: ExtendableField, +{ + type Output = >::Array; + + async fn generic_reveal<'fut>( + &'fut self, + ctx: UpgradedMaliciousContext<'a, F>, + record_id: RecordId, + excluded: Option, + ) -> Result>::Array>, Error> + where + UpgradedMaliciousContext<'a, F>: 'fut, + { use crate::secret_sharing::replicated::malicious::ThisCodeIsAuthorizedToDowngradeFromMalicious; - let (left, right) = self.x().access_without_downgrade().as_tuple(); - let left_sender = ctx.send_channel(ctx.role().peer(Direction::Left)); - let left_receiver = ctx.recv_channel::(ctx.role().peer(Direction::Left)); - let right_sender = ctx.send_channel(ctx.role().peer(Direction::Right)); - let right_receiver = ctx.recv_channel::(ctx.role().peer(Direction::Right)); - - // Send shares to the left and right helpers, unless excluded. - let send_left_fut = - MaybeFuture::future_or_ok(Some(ctx.role().peer(Direction::Left)) != excluded, || { - left_sender.send(record_id, right) - }); - - let send_right_fut = - MaybeFuture::future_or_ok(Some(ctx.role().peer(Direction::Right)) != excluded, || { - right_sender.send(record_id, left) - }); - try_join(send_left_fut, send_right_fut).await?; - - if Some(ctx.role()) == excluded { - Ok(None) - } else { - let (share_from_left, share_from_right) = try_join( - left_receiver.receive(record_id), - right_receiver.receive(record_id), - ) - .await?; + let x_share = self.x().access_without_downgrade(); + malicious_reveal(ctx, record_id, excluded, x_share).await + } +} - if share_from_left == share_from_right { - Ok(Some((left + right + share_from_left).into_array())) - } else { - Err(Error::MaliciousRevealFailed) - } - } +impl<'a, const N: usize> Reveal, N> for Replicated +where + Boolean: Vectorizable, +{ + type Output = >::Array; + + async fn generic_reveal<'fut>( + &'fut self, + ctx: DZKPUpgradedMaliciousContext<'a>, + record_id: RecordId, + excluded: Option, + ) -> Result, Error> + where + DZKPUpgradedMaliciousContext<'a>: 'fut, + { + malicious_reveal(ctx, record_id, excluded, self).await } } @@ -232,7 +333,7 @@ mod tests { let input = rng.gen::(); let results = world - .semi_honest(input, |ctx, share| async move { + .upgraded_semi_honest(input, |ctx, share| async move { TestField::from_array( &share .reveal(ctx.set_total_records(1), RecordId::from(0)) @@ -259,7 +360,7 @@ mod tests { for &excluded in Role::all() { let input = rng.gen::(); let results = world - .semi_honest(input, |ctx, share| async move { + .upgraded_semi_honest(input, |ctx, share| async move { share .partial_reveal(ctx.set_total_records(1), RecordId::from(0), excluded) .await @@ -289,7 +390,7 @@ mod tests { let input = rng.gen::(); let results = world - .semi_honest( + .upgraded_semi_honest( input, |ctx, share: AdditiveShare| async move { share diff --git a/ipa-core/src/protocol/context/mod.rs b/ipa-core/src/protocol/context/mod.rs index 0f6cc79d8..c82fa500f 100644 --- a/ipa-core/src/protocol/context/mod.rs +++ b/ipa-core/src/protocol/context/mod.rs @@ -732,13 +732,13 @@ mod tests { let input_size = input.len(); let snapshot = world.metrics_snapshot(); - // Malicious protocol has an amplification factor of 3 and constant overhead of 3. For each input row it + // Malicious protocol has an amplification factor of 3 and constant overhead of 5. For each input row it // (input size) upgrades input to malicious // (input size) executes toy protocol // (input size) propagates u and w // (1) multiply r * share of zero - // (2) reveals r (1 for check_zero, 1 for validate) - let comm_factor = |input_size| 3 * input_size + 3; + // (4) reveals r (2 for check_zero, 2 for validate) + let comm_factor = |input_size| 3 * input_size + 5; let records_sent_assert = snapshot .assert_metric(RECORDS_SENT) .total(3 * comm_factor(input_size)) diff --git a/ipa-core/src/protocol/context/validator.rs b/ipa-core/src/protocol/context/validator.rs index c1845dcc1..66f1d8919 100644 --- a/ipa-core/src/protocol/context/validator.rs +++ b/ipa-core/src/protocol/context/validator.rs @@ -12,7 +12,7 @@ use crate::{ ff::Field, helpers::{Direction, TotalRecords}, protocol::{ - basics::Reveal, + basics::malicious_reveal, context::{ step::{MaliciousProtocolStep as Step, ValidateStep}, Base, Context, MaliciousContext, UpgradableContext, UpgradedMaliciousContext, @@ -194,7 +194,10 @@ pub struct Malicious<'a, F: ExtendableField> { } #[async_trait] -impl<'a, F: ExtendableField> Validator, F> for Malicious<'a, F> { +impl<'a, F> Validator, F> for Malicious<'a, F> +where + F: ExtendableField, +{ /// Get a copy of the context that can be used for malicious protocol execution. fn context<'b>(&'b self) -> UpgradedMaliciousContext<'a, F> { self.protocol_ctx.clone() @@ -219,7 +222,9 @@ impl<'a, F: ExtendableField> Validator, F> for Malicious<'a .narrow(&ValidateStep::RevealR) .set_total_records(TotalRecords::ONE); let r = ::ExtendedField::from_array( - &self.r_share.reveal(narrow_ctx, RecordId::FIRST).await?, + &malicious_reveal(narrow_ctx, RecordId::FIRST, None, &self.r_share) + .await? + .expect("full reveal should always return a value"), ); let t = u_share - &(w_share * r); diff --git a/ipa-core/src/protocol/ipa_prf/prf_eval.rs b/ipa-core/src/protocol/ipa_prf/prf_eval.rs index 8515188ee..116ea7b56 100644 --- a/ipa-core/src/protocol/ipa_prf/prf_eval.rs +++ b/ipa-core/src/protocol/ipa_prf/prf_eval.rs @@ -5,7 +5,7 @@ use crate::{ ff::{boolean::Boolean, curve_points::RP25519, ec_prime_field::Fp25519, Expand}, helpers::TotalRecords, protocol::{ - basics::{Reveal, SecureMul}, + basics::{malicious_reveal, SecureMul}, context::Context, ipa_prf::step::PrfStep as Step, prss::{FromPrss, SharedRandomness}, @@ -99,8 +99,14 @@ where let sh_gr = AdditiveShare::::from(sh_r); //reconstruct (z,R) - let gr = sh_gr.reveal(ctx.narrow(&Step::RevealR), record_id).await?; - let z = y.reveal(ctx.narrow(&Step::Revealz), record_id).await?; + // TODO: these should invoke reveal via the trait when this function + // takes a context of an appropriate type. + let gr = malicious_reveal(ctx.narrow(&Step::RevealR), record_id, None, &sh_gr) + .await? + .unwrap(); + let z = malicious_reveal(ctx.narrow(&Step::Revealz), record_id, None, &y) + .await? + .unwrap(); //compute R^(1/z) to u64 Ok(zip(gr, z) From eb28127c3707ae907c285aeb47bba291203c5ee9 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Wed, 31 Jul 2024 12:30:19 -0700 Subject: [PATCH 2/2] Test of malicious behavior during zero check protocol --- .../src/helpers/transport/in_memory/config.rs | 52 +++-- .../helpers/transport/in_memory/transport.rs | 20 +- ipa-core/src/protocol/basics/check_zero.rs | 186 +++++++++++++++++- ipa-core/src/protocol/basics/reshare.rs | 4 + ipa-core/src/protocol/basics/reveal.rs | 6 +- ipa-core/src/test_fixture/world.rs | 5 +- 6 files changed, 246 insertions(+), 27 deletions(-) diff --git a/ipa-core/src/helpers/transport/in_memory/config.rs b/ipa-core/src/helpers/transport/in_memory/config.rs index 95834c318..893191bfb 100644 --- a/ipa-core/src/helpers/transport/in_memory/config.rs +++ b/ipa-core/src/helpers/transport/in_memory/config.rs @@ -1,4 +1,6 @@ -use std::borrow::Cow; +use std::{borrow::Cow, future::ready, pin::Pin}; + +use futures::{Future, FutureExt}; use crate::{ helpers::{HelperIdentity, Role, RoleAssignment}, @@ -42,19 +44,33 @@ pub trait StreamInterceptor: Send + Sync { /// from additive attacks without additional measures implemented /// at the transport layer, like checksumming, share consistency /// checks, etc. - fn peek(&self, ctx: &Self::Context, data: &mut Vec); + fn peek<'a>( + &'a self, + ctx: &'a Self::Context, + data: &'a mut Vec, + ) -> Pin + Send + 'a>>; } -impl) + Send + Sync + 'static> StreamInterceptor for F { +impl StreamInterceptor for F +where + for<'a> F: Fn(&'a InspectContext, &'a mut Vec) -> Pin + Send>> + + Send + + Sync + + 'a, +{ type Context = InspectContext; - fn peek(&self, ctx: &Self::Context, data: &mut Vec) { - (self)(ctx, data); + fn peek( + &self, + ctx: &Self::Context, + data: &mut Vec, + ) -> Pin + Send>> { + (self)(ctx, data) } } /// The general context provided to stream inspectors. -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct InspectContext { /// The shard index of this instance. /// This is `None` for non-sharded helpers. @@ -76,7 +92,7 @@ pub struct InspectContext { #[inline] #[must_use] pub fn passthrough() -> Arc> { - Arc::new(|_ctx: &InspectContext, _data: &mut Vec| {}) + Arc::new(|_ctx: &InspectContext, _data: &mut Vec| ready(()).boxed()) } /// This narrows the implementation of stream seeker @@ -93,7 +109,10 @@ pub struct MaliciousHelper { inner: F, } -impl) + Send + Sync> MaliciousHelper { +impl MaliciousHelper +where + F: Fn(MaliciousHelperContext, &mut Vec) -> Pin + Send>>, +{ pub fn new(role: Role, role_assignment: &RoleAssignment, peeker: F) -> Arc { Arc::new(Self { identity: role_assignment.identity(role), @@ -131,14 +150,23 @@ pub struct MaliciousHelperContext { pub gate: Gate, } -impl) + Send + Sync> StreamInterceptor - for MaliciousHelper +impl StreamInterceptor for MaliciousHelper +where + F: Fn(MaliciousHelperContext, &mut Vec) -> Pin + Send>> + + Send + + Sync, { type Context = InspectContext; - fn peek(&self, ctx: &Self::Context, data: &mut Vec) { + fn peek( + &self, + ctx: &Self::Context, + data: &mut Vec, + ) -> Pin + Send>> { if ctx.identity == self.identity { - (self.inner)(&self.context(ctx), data); + (self.inner)(self.context(ctx), data) + } else { + ready(()).boxed() } } } diff --git a/ipa-core/src/helpers/transport/in_memory/transport.rs b/ipa-core/src/helpers/transport/in_memory/transport.rs index 3c1a9e926..794c29d0e 100644 --- a/ipa-core/src/helpers/transport/in_memory/transport.rs +++ b/ipa-core/src/helpers/transport/in_memory/transport.rs @@ -1,6 +1,7 @@ use std::{ collections::HashMap, fmt::{Debug, Formatter}, + future::ready, io, pin::Pin, task::{Context, Poll}, @@ -12,7 +13,7 @@ use ::tokio::sync::{ }; use async_trait::async_trait; use bytes::Bytes; -use futures::{Stream, StreamExt}; +use futures::{FutureExt, Stream, StreamExt}; #[cfg(all(feature = "shuttle", test))] use shuttle::future as tokio; use tokio_stream::wrappers::ReceiverStream; @@ -21,8 +22,7 @@ use tracing::Instrument; use crate::{ error::BoxError, helpers::{ - in_memory_config, - in_memory_config::DynStreamInterceptor, + in_memory_config::{self, DynStreamInterceptor}, transport::{ in_memory::config::InspectContext, routing::{Addr, RouteId}, @@ -195,12 +195,16 @@ impl Transport for Weak> { channel .send(( addr, - InMemoryStream::wrap(data.map({ - move |mut chunk| { - if let Some(ref context) = context { - this.config.stream_interceptor.peek(context, &mut chunk); + InMemoryStream::wrap(data.then(move |mut chunk| { + let interceptor = Arc::clone(&this.config.stream_interceptor); + if let Some(context) = context.clone() { + async move { + interceptor.peek(&context, &mut chunk).await; + Ok(Bytes::from(chunk)) } - Ok(Bytes::from(chunk)) + .boxed() + } else { + ready(Ok(Bytes::from(chunk))).boxed() } })), ack_tx, diff --git a/ipa-core/src/protocol/basics/check_zero.rs b/ipa-core/src/protocol/basics/check_zero.rs index c528cf363..1cd69f14f 100644 --- a/ipa-core/src/protocol/basics/check_zero.rs +++ b/ipa-core/src/protocol/basics/check_zero.rs @@ -10,6 +10,10 @@ use crate::{ secret_sharing::replicated::semi_honest::AdditiveShare as Replicated, }; +#[cfg(test)] +static SH1: once_cell::sync::Lazy> = + once_cell::sync::Lazy::new(|| tests::NotifyOnceCell::new()); + /// A very simple protocol to check if a replicated secret sharing is a sharing of zero. /// /// NOTE: this protocol leaks information about `v` the helpers. Please only use this in cases where @@ -51,26 +55,82 @@ where let rv_share = semi_honest_multiply(ctx.narrow(&Step::MultiplyWithR), record_id, &r_sharing, v).await?; + tracing::info!("{:?}", &rv_share); let rv = F::from_array( &malicious_reveal(ctx.narrow(&Step::RevealR), record_id, None, &rv_share) .await? .expect("full reveal should always return a value"), ); + tracing::info!("{:?}", &rv); Ok(rv == F::ZERO) } +#[cfg(test)] +pub async fn check_zero_fp32bitprime( + ctx: C, + record_id: RecordId, + v: &Replicated, +) -> Result +where + C: Context, +{ + use crate::{ + ff::Fp32BitPrime, + secret_sharing::{replicated::ReplicatedSecretSharing, SharedValue}, + }; + + let r_sharing: Replicated = ctx.prss().generate(record_id); + + let rv_share = + semi_honest_multiply(ctx.narrow(&Step::MultiplyWithR), record_id, &r_sharing, v).await?; + tracing::info!("{:?}", &rv_share); + if ctx.role() == crate::helpers::Role::H1 { + SH1.set(rv_share.right()).unwrap(); + tracing::info!("sent sh1 = {:?}", rv_share.right()); + } + let rv = Fp32BitPrime::from_array( + &malicious_reveal(ctx.narrow(&Step::RevealR), record_id, None, &rv_share) + .await? + .expect("full reveal should always return a value"), + ); + + Ok(rv == Fp32BitPrime::ZERO) +} + #[cfg(all(test, unit_test))] mod tests { - use futures_util::future::try_join3; + use std::{ + future::ready, + pin::Pin, + sync::{Arc, Mutex}, + }; + + use futures::{ + future::{join, try_join3}, + Future, FutureExt, + }; + use generic_array::GenericArray; + use once_cell::sync::OnceCell; + use rand::Rng; + use tokio::sync::Notify; + use typenum::U4; use crate::{ error::Error, - ff::{Fp31, PrimeField, U128Conversions}, - protocol::{basics::check_zero, context::Context, RecordId}, + ff::{Fp31, Fp32BitPrime, PrimeField, Serializable, U128Conversions}, + helpers::{ + in_memory_config::{InspectContext, StreamInterceptor}, + HelperIdentity, TransportIdentity, + }, + protocol::{ + basics::check_zero::{check_zero, check_zero_fp32bitprime}, + context::Context, + RecordId, + }, rand::thread_rng, secret_sharing::{IntoShares, SharedValue}, - test_fixture::TestWorld, + test_fixture::{Runner, TestWorld, TestWorldConfig}, }; #[tokio::test] @@ -127,4 +187,122 @@ mod tests { Ok(()) } + + pub struct NotifyOnceCell { + inner: Mutex>, + } + + struct NotifyOnceCellInner { + cell: OnceCell, + notify: Arc, + } + + impl NotifyOnceCell { + pub fn new() -> Self { + Self { + inner: Mutex::new(NotifyOnceCellInner { + cell: OnceCell::new(), + notify: Arc::new(Notify::new()), + }), + } + } + + pub fn set(&self, value: T) -> Result<(), T> { + let inner = self.inner.lock().unwrap(); + inner.cell.set(value)?; + inner.notify.notify_waiters(); + Ok(()) + } + + pub fn get(&self) -> Pin + Send + '_>> { + let inner = self.inner.lock().unwrap(); + if let Some(value) = inner.cell.get() { + return ready(value.clone()).boxed(); + } + let notify = inner.notify.clone(); + async move { + notify.notified().await; + self.inner.lock().unwrap().cell.get().unwrap().clone() + } + .boxed() + } + } + + struct MaliciousCheckZeroInterceptor { + sh2: NotifyOnceCell, + } + + impl MaliciousCheckZeroInterceptor { + fn new() -> Self { + Self { + sh2: NotifyOnceCell::new(), + } + } + } + + impl StreamInterceptor for MaliciousCheckZeroInterceptor { + type Context = InspectContext; + + fn peek<'a>( + &'a self, + ctx: &'a Self::Context, + data: &'a mut Vec, + ) -> Pin + Send + 'a>> { + if ctx + .gate + .as_ref() + .contains(super::Step::MultiplyWithR.as_ref()) + && ctx.identity == HelperIdentity::ONE + || ctx.gate.as_ref().contains(super::Step::RevealR.as_ref()) + && ctx.identity == HelperIdentity::ONE + && ctx.dest == HelperIdentity::TWO.as_str() + { + async { + assert_eq!(data.len(), 4); + let (sh1, sh2) = join(super::SH1.get(), self.sh2.get()).await; + tracing::info!("got shares: {sh1:?} {sh2:?}"); + let adjusted_share = -sh1 - sh2; + tracing::info!("adjusted share {adjusted_share:?}"); + adjusted_share.serialize( + <&mut GenericArray>::try_from(data.as_mut_slice()).unwrap(), + ); + } + .boxed() + } else if ctx.gate.as_ref().contains(super::Step::RevealR.as_ref()) + && ctx.identity == HelperIdentity::TWO + && ctx.dest == HelperIdentity::ONE.as_str() + { + assert_eq!(data.len(), 4); + let sh2 = Fp32BitPrime::deserialize_unchecked( + <&GenericArray>::try_from(data.as_slice()).unwrap(), + ); + self.sh2.set(sh2).unwrap(); + tracing::info!("sent sh2 = {sh2:?}"); + ready(()).boxed() + } else { + ready(()).boxed() + } + } + } + + #[tokio::test] + async fn malicious_check_zero() { + let mut config = TestWorldConfig::default(); + config.stream_interceptor = Arc::new(MaliciousCheckZeroInterceptor::new()); + let world = TestWorld::new_with(&config); + let mut rng = thread_rng(); + let v = rng.gen::(); + + let [res0, res1, res2] = world + .semi_honest(v, |ctx, v| async move { + check_zero_fp32bitprime(ctx.set_total_records(1), RecordId::FIRST, &v) + .await + .unwrap() + }) + .await; + + assert_eq!(res0, false, "zero check failed on H1"); + assert_eq!(res1, false, "zero check failed on H2"); + assert_eq!(res2, false, "zero check failed on H3"); + } } diff --git a/ipa-core/src/protocol/basics/reshare.rs b/ipa-core/src/protocol/basics/reshare.rs index f8e3019f5..8154a7496 100644 --- a/ipa-core/src/protocol/basics/reshare.rs +++ b/ipa-core/src/protocol/basics/reshare.rs @@ -200,6 +200,9 @@ mod tests { mod malicious { + use std::future::ready; + + use futures::FutureExt; use rand::{distributions::Standard, prelude::Distribution}; use crate::{ @@ -290,6 +293,7 @@ mod tests { corrupt(data, small_value); } } + ready(()).boxed() }, ); let world = TestWorld::new_with(&config); diff --git a/ipa-core/src/protocol/basics/reveal.rs b/ipa-core/src/protocol/basics/reveal.rs index d52f74bed..ae6186f73 100644 --- a/ipa-core/src/protocol/basics/reveal.rs +++ b/ipa-core/src/protocol/basics/reveal.rs @@ -207,6 +207,7 @@ where ) .await?; + tracing::info!("reveal ({:?}): left {left:?} right {right:?} from left {share_from_left:?} from right {share_from_right:?}", ctx.role()); if share_from_left == share_from_right { Ok(Some(share_from_left + left + right)) } else { @@ -305,9 +306,9 @@ where #[cfg(all(test, unit_test))] mod tests { - use std::iter::zip; + use std::{future::ready, iter::zip}; - use futures::future::join_all; + use futures::{future::join_all, FutureExt}; use crate::{ error::Error, @@ -510,6 +511,7 @@ mod tests { let v = Fp31::deserialize_from_slice(data) + Fp31::ONE; v.serialize_to_slice(data); } + ready(()).boxed() }); let world = TestWorld::new_with(config); diff --git a/ipa-core/src/test_fixture/world.rs b/ipa-core/src/test_fixture/world.rs index 60fb642be..6aaba4db9 100644 --- a/ipa-core/src/test_fixture/world.rs +++ b/ipa-core/src/test_fixture/world.rs @@ -750,9 +750,11 @@ impl Distribute for Random { mod tests { use std::{ collections::{HashMap, HashSet}, + future::ready, sync::{Arc, Mutex}, }; + use futures::FutureExt; use futures_util::future::try_join4; use crate::{ @@ -847,10 +849,11 @@ mod tests { config.stream_interceptor = MaliciousHelper::new( Role::H1, config.role_assignment(), - |ctx: &MaliciousHelperContext, data: &mut Vec| { + |ctx: MaliciousHelperContext, data: &mut Vec| { if ctx.gate.as_ref().contains(STEP) { corrupt_byte(&mut data[0]); } + ready(()).boxed() }, );