From c456668fc66f21ca381a85374d24e5f00a3f4306 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Thu, 3 Oct 2024 15:15:14 -0700 Subject: [PATCH 1/2] Support sharded malicious protocols inside test infrastructure This is the final PR in the series (#1333, #1315) that enables writing sharded malicious protocols. The support is added for `Runner::malicious` function, and it can be used to execute malicious circuits. Other methods like `upgraded_semi_honest` and `upgraded_malicious` haven't been touched yet - I am not convinced that these are necessary and we can update them as we go. The core goal here is to unblock efforts to implement Hybrid protocol --- ipa-core/src/protocol/basics/mul/malicious.rs | 13 +- ipa-core/src/protocol/context/malicious.rs | 32 ++-- ipa-core/src/protocol/context/mod.rs | 1 + ipa-core/src/protocol/context/semi_honest.rs | 6 +- ipa-core/src/test_fixture/world.rs | 139 +++++++++++++++--- 5 files changed, 151 insertions(+), 40 deletions(-) diff --git a/ipa-core/src/protocol/basics/mul/malicious.rs b/ipa-core/src/protocol/basics/mul/malicious.rs index e55d855d6..92bb6bee7 100644 --- a/ipa-core/src/protocol/basics/mul/malicious.rs +++ b/ipa-core/src/protocol/basics/mul/malicious.rs @@ -16,6 +16,7 @@ use crate::{ malicious::{AdditiveShare as MaliciousReplicated, ExtendableFieldSimd}, semi_honest::AdditiveShare as Replicated, }, + sharding::ShardBinding, }; /// @@ -49,8 +50,8 @@ use crate::{ /// back via the error response /// ## Panics /// Panics if the mutex is found to be poisoned -pub async fn mac_multiply( - ctx: UpgradedMaliciousContext<'_, F>, +pub async fn mac_multiply( + ctx: UpgradedMaliciousContext<'_, F, B>, record_id: RecordId, a: &MaliciousReplicated, b: &MaliciousReplicated, @@ -108,19 +109,19 @@ where /// Implement secure multiplication for malicious contexts with replicated secret sharing. #[async_trait] -impl<'a, F: ExtendableFieldSimd, const N: usize> SecureMul> - for MaliciousReplicated +impl<'a, F: ExtendableFieldSimd, B: ShardBinding, const N: usize> + SecureMul> for MaliciousReplicated where Replicated: FromPrss, { async fn multiply<'fut>( &self, rhs: &Self, - ctx: UpgradedMaliciousContext<'a, F>, + ctx: UpgradedMaliciousContext<'a, F, B>, record_id: RecordId, ) -> Result where - UpgradedMaliciousContext<'a, F>: 'fut, + UpgradedMaliciousContext<'a, F, B>: 'fut, { mac_multiply(ctx, record_id, self, rhs).await } diff --git a/ipa-core/src/protocol/context/malicious.rs b/ipa-core/src/protocol/context/malicious.rs index def33e950..ac008a19c 100644 --- a/ipa-core/src/protocol/context/malicious.rs +++ b/ipa-core/src/protocol/context/malicious.rs @@ -323,8 +323,10 @@ impl<'a, F: ExtendableField, B: ShardBinding> SeqJoin for Upgraded<'a, F, B> { /// protocols should be generic over `SecretShare` trait and not requiring this cast and taking /// `ProtocolContext<'a, S: SecretShare, F: Field>` as the context. If that is not possible, /// this implementation makes it easier to reinterpret the context as semi-honest. -impl<'a, F: ExtendableField> SpecialAccessToUpgradedContext for Upgraded<'a, F, NotSharded> { - type Base = Base<'a>; +impl<'a, F: ExtendableField, B: ShardBinding> SpecialAccessToUpgradedContext + for Upgraded<'a, F, B> +{ + type Base = Base<'a, B>; fn base_context(self) -> Self::Base { self.base_ctx.inner @@ -340,7 +342,7 @@ impl Debug for Upgraded<'_, F, B> { /// Upgrading a semi-honest replicated share using malicious context produces /// a MAC-secured share with the same vectorization factor. #[async_trait] -impl<'a, V: ExtendableFieldSimd, const N: usize> Upgradable> +impl<'a, V: ExtendableFieldSimd, B: ShardBinding, const N: usize> Upgradable> for Replicated where Replicated<::ExtendedField, N>: FromPrss, @@ -349,7 +351,7 @@ where async fn upgrade( self, - ctx: Upgraded<'a, V, NotSharded>, + ctx: Upgraded<'a, V, B>, record_id: RecordId, ) -> Result { let ctx = ctx.narrow(&UpgradeStep); @@ -383,7 +385,7 @@ where #[cfg(all(test, descriptive_gate))] #[async_trait] -impl<'a, V: ExtendableFieldSimd, const N: usize> Upgradable> +impl<'a, V: ExtendableFieldSimd, B: ShardBinding, const N: usize> Upgradable> for (Replicated, Replicated) where Replicated<::ExtendedField, N>: FromPrss, @@ -392,7 +394,7 @@ where async fn upgrade( self, - ctx: Upgraded<'a, V, NotSharded>, + ctx: Upgraded<'a, V, B>, record_id: RecordId, ) -> Result { let (l, r) = self; @@ -404,12 +406,12 @@ where #[cfg(all(test, descriptive_gate))] #[async_trait] -impl<'a, V: ExtendableField> Upgradable> for () { +impl<'a, V: ExtendableField, B: ShardBinding> Upgradable> for () { type Output = (); async fn upgrade( self, - _context: Upgraded<'a, V, NotSharded>, + _context: Upgraded<'a, V, B>, _record_id: RecordId, ) -> Result { Ok(()) @@ -418,28 +420,30 @@ impl<'a, V: ExtendableField> Upgradable> for () { #[cfg(all(test, descriptive_gate))] #[async_trait] -impl<'a, V, U> Upgradable> for Vec +impl<'a, V, U, B> Upgradable> for Vec where V: ExtendableField, - U: Upgradable, Output: Send> + Send + 'a, + U: Upgradable, Output: Send> + Send + 'a, + B: ShardBinding, { type Output = Vec; async fn upgrade( self, - ctx: Upgraded<'a, V, NotSharded>, + ctx: Upgraded<'a, V, B>, record_id: RecordId, ) -> Result { /// Need a standalone function to avoid GAT issue that apparently can manifest /// even with `async_trait`. - fn upgrade_vec<'a, V, U>( - ctx: Upgraded<'a, V, NotSharded>, + fn upgrade_vec<'a, V, U, B>( + ctx: Upgraded<'a, V, B>, record_id: RecordId, input: Vec, ) -> impl std::future::Future, Error>> + 'a where V: ExtendableField, - U: Upgradable> + 'a, + U: Upgradable> + 'a, + B: ShardBinding, { let mut upgraded = Vec::with_capacity(input.len()); async move { diff --git a/ipa-core/src/protocol/context/mod.rs b/ipa-core/src/protocol/context/mod.rs index 0651b74a4..627ffc0df 100644 --- a/ipa-core/src/protocol/context/mod.rs +++ b/ipa-core/src/protocol/context/mod.rs @@ -26,6 +26,7 @@ pub type SemiHonestContext<'a, B = NotSharded> = semi_honest::Context<'a, B>; pub type ShardedSemiHonestContext<'a> = semi_honest::Context<'a, Sharded>; pub type MaliciousContext<'a, B = NotSharded> = malicious::Context<'a, B>; +pub type ShardedMaliciousContext<'a> = malicious::Context<'a, Sharded>; pub type UpgradedMaliciousContext<'a, F, B = NotSharded> = malicious::Upgraded<'a, F, B>; #[cfg(all(feature = "in-memory-infra", any(test, feature = "test-fixture")))] diff --git a/ipa-core/src/protocol/context/semi_honest.rs b/ipa-core/src/protocol/context/semi_honest.rs index 65f4e644e..bd8c2e260 100644 --- a/ipa-core/src/protocol/context/semi_honest.rs +++ b/ipa-core/src/protocol/context/semi_honest.rs @@ -302,14 +302,14 @@ impl Debug for Upgraded<'_, B, F> { } #[async_trait] -impl<'a, V: ExtendableField + Vectorizable, const N: usize> - Upgradable> for Replicated +impl<'a, V: ExtendableField + Vectorizable, B: ShardBinding, const N: usize> + Upgradable> for Replicated { type Output = Replicated; async fn upgrade( self, - _context: Upgraded<'a, NotSharded, V>, + _context: Upgraded<'a, B, V>, _record_id: RecordId, ) -> Result { Ok(self) diff --git a/ipa-core/src/test_fixture/world.rs b/ipa-core/src/test_fixture/world.rs index 1c337b10e..b54d76abb 100644 --- a/ipa-core/src/test_fixture/world.rs +++ b/ipa-core/src/test_fixture/world.rs @@ -23,8 +23,8 @@ use crate::{ context::{ dzkp_validator::DZKPValidator, upgrade::Upgradable, Context, DZKPUpgradedMaliciousContext, MaliciousContext, SemiHonestContext, - ShardedSemiHonestContext, UpgradableContext, UpgradedContext, UpgradedMaliciousContext, - UpgradedSemiHonestContext, Validator, TEST_DZKP_STEPS, + ShardedMaliciousContext, ShardedSemiHonestContext, UpgradableContext, UpgradedContext, + UpgradedMaliciousContext, UpgradedSemiHonestContext, Validator, TEST_DZKP_STEPS, }, prss::Endpoint as PrssEndpoint, Gate, QueryId, RecordId, @@ -369,6 +369,10 @@ where pub trait Runner { /// This could be also derived from [`S`], but maybe that's too much for that trait. type SemiHonestContext<'ctx>: Context; + /// The type of context used to run protocols that are secure against + /// active adversaries. It varies depending on whether sharding is used or not. + type MaliciousContext<'ctx>: Context; + /// Run with a context that can be upgraded, but is only good for semi-honest. async fn semi_honest<'a, I, A, O, H, R>( &'a self, @@ -396,12 +400,12 @@ pub trait Runner { R: Future + Send; /// Run with a context that can be upgraded to malicious. - async fn malicious<'a, I, A, O, H, R>(&'a self, input: I, helper_fn: H) -> [O; 3] + async fn malicious<'a, I, A, O, H, R>(&'a self, input: I, helper_fn: H) -> S::Container<[O; 3]> where - I: IntoShares + Send + 'static, + I: RunnerInput, A: Send, O: Send + Debug, - H: Fn(MaliciousContext<'a>, A) -> R + Send + Sync, + H: Fn(Self::MaliciousContext<'a>, S::Container) -> R + Send + Sync, R: Future + Send; /// Run with a context that has already been upgraded to malicious. @@ -444,6 +448,7 @@ impl Runner> for TestWorld> { type SemiHonestContext<'ctx> = ShardedSemiHonestContext<'ctx>; + type MaliciousContext<'ctx> = ShardedMaliciousContext<'ctx>; async fn semi_honest<'a, I, A, O, H, R>(&'a self, input: I, helper_fn: H) -> Vec<[O; 3]> where I: RunnerInput, A>, @@ -494,15 +499,39 @@ impl Runner> unimplemented!() } - async fn malicious<'a, I, A, O, H, R>(&'a self, _input: I, _helper_fn: H) -> [O; 3] + async fn malicious<'a, I, A, O, H, R>(&'a self, input: I, helper_fn: H) -> Vec<[O; 3]> where - I: IntoShares + Send + 'static, + I: RunnerInput, A>, A: Send, O: Send + Debug, - H: Fn(MaliciousContext<'a>, A) -> R + Send + Sync, + H: Fn( + Self::MaliciousContext<'a>, + as ShardingScheme>::Container, + ) -> R + + Send + + Sync, R: Future + Send, { - unimplemented!() + let shards = self.shards(); + let [h1, h2, h3]: [[Vec; SHARDS]; 3] = input.share().map(D::distribute); + let gate = self.next_gate(); + // todo!() + + // No clippy, you're wrong, it is not redundant, it allows shard_fn to be `Copy` + #[allow(clippy::redundant_closure)] + let shard_fn = |ctx, input| helper_fn(ctx, input); + zip(shards.into_iter(), zip(zip(h1, h2), h3)) + .map(|(shard, ((h1, h2), h3))| { + ShardWorld::::run_either( + shard.malicious_contexts(&gate), + self.metrics_handle.span(), + [h1, h2, h3], + shard_fn, + ) + }) + .collect::>() + .collect::>() + .await } async fn upgraded_malicious<'a, F, I, A, M, O, H, R, P>( @@ -541,6 +570,7 @@ impl Runner> #[async_trait] impl Runner for TestWorld { type SemiHonestContext<'ctx> = SemiHonestContext<'ctx>; + type MaliciousContext<'ctx> = MaliciousContext<'ctx>; async fn semi_honest<'a, I, A, O, H, R>(&'a self, input: I, helper_fn: H) -> [O; 3] where @@ -583,10 +613,10 @@ impl Runner for TestWorld { async fn malicious<'a, I, A, O, H, R>(&'a self, input: I, helper_fn: H) -> [O; 3] where - I: IntoShares + Send + 'static, + I: RunnerInput, A: Send, O: Send + Debug, - H: Fn(MaliciousContext<'a>, A) -> R + Send + Sync, + H: Fn(Self::MaliciousContext<'a>, A) -> R + Send + Sync, R: Future + Send, { ShardWorld::::run_either( @@ -778,9 +808,14 @@ impl ShardWorld { /// # Panics /// Panics if world has more or less than 3 gateways/participants #[must_use] - pub fn malicious_contexts(&self, gate: &Gate) -> [MaliciousContext<'_>; 3] { + pub fn malicious_contexts(&self, gate: &Gate) -> [MaliciousContext<'_, B>; 3] { zip3_ref(&self.participants, &self.gateways).map(|(participant, gateway)| { - MaliciousContext::new_with_gate(participant, gateway, gate.clone(), NotSharded) + MaliciousContext::new_with_gate( + participant, + gateway, + gate.clone(), + self.shard_info.clone(), + ) }) } } @@ -816,7 +851,8 @@ impl Distribute for Random { } } -#[cfg(all(test, unit_test))] +// #[cfg(all(test, unit_test))] +#[cfg(test)] mod tests { use std::{ collections::{HashMap, HashSet}, @@ -826,12 +862,20 @@ mod tests { use futures_util::future::try_join4; use crate::{ - ff::{boolean_array::BA3, Field, Fp31, U128Conversions}, + ff::{boolean::Boolean, boolean_array::BA3, Field, Fp31, U128Conversions}, helpers::{ in_memory_config::{MaliciousHelper, MaliciousHelperContext}, - Direction, Role, + Direction, Role, TotalRecords, + }, + protocol::{ + basics::SecureMul, + context::{ + dzkp_validator::DZKPValidator, upgrade::Upgradable, Context, UpgradableContext, + UpgradedContext, Validator, TEST_DZKP_STEPS, + }, + prss::SharedRandomness, + RecordId, }, - protocol::{context::Context, prss::SharedRandomness, RecordId}, secret_sharing::{ replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing}, SharedValue, @@ -961,4 +1005,65 @@ mod tests { assert_eq!(shares[1].right(), shares[2].left()); }); } + + #[test] + fn zkp_malicious_sharded() { + run(|| async { + let world: TestWorld> = + TestWorld::with_shards(TestWorldConfig::default()); + let input = vec![Boolean::truncate_from(0_u32), Boolean::truncate_from(1_u32)]; + let r = world + .malicious(input.clone().into_iter(), |ctx, input| async move { + assert_eq!(1, input.iter().len()); + let ctx = ctx.set_total_records(TotalRecords::ONE); + let validator = ctx.dzkp_validator(TEST_DZKP_STEPS, 1); + let ctx = validator.context(); + let r = input[0] + .multiply(&input[0], ctx, RecordId::FIRST) + .await + .unwrap(); + validator.validate().await.unwrap(); + + vec![r] + }) + .await + .into_iter() + .flat_map(|v| v.reconstruct()) + .collect::>(); + + assert_eq!(input, r); + }); + } + + #[test] + fn mac_malicious_sharded() { + run(|| async { + let world: TestWorld> = + TestWorld::with_shards(TestWorldConfig::default()); + let input = vec![Fp31::truncate_from(0_u32), Fp31::truncate_from(1_u32)]; + let r = world + .malicious(input.clone().into_iter(), |ctx, input| async move { + assert_eq!(1, input.iter().len()); + let validator = ctx.set_total_records(1).validator(); + let ctx = validator.context(); + let (a_upgraded, b_upgraded) = (input[0].clone(), input[0].clone()) + .upgrade(ctx.clone(), RecordId::FIRST) + .await + .unwrap(); + let _ = a_upgraded + .multiply(&b_upgraded, ctx.narrow("multiply"), RecordId::FIRST) + .await + .unwrap(); + ctx.validate_record(RecordId::FIRST).await.unwrap(); + + input + }) + .await + .into_iter() + .flat_map(|v| v.reconstruct()) + .collect::>(); + + assert_eq!(input, r); + }); + } } From 824ee6dff8989f9d0210c79c3643005c0110384a Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Thu, 3 Oct 2024 16:51:29 -0700 Subject: [PATCH 2/2] Fix compact gate tests --- ipa-core/src/test_fixture/world.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ipa-core/src/test_fixture/world.rs b/ipa-core/src/test_fixture/world.rs index b54d76abb..b2ea2759a 100644 --- a/ipa-core/src/test_fixture/world.rs +++ b/ipa-core/src/test_fixture/world.rs @@ -851,8 +851,7 @@ impl Distribute for Random { } } -// #[cfg(all(test, unit_test))] -#[cfg(test)] +#[cfg(all(test, unit_test))] mod tests { use std::{ collections::{HashMap, HashSet},