Skip to content

Commit

Permalink
Merge pull request #1335 from akoshelev/malicious-sharded-context-3
Browse files Browse the repository at this point in the history
Support sharded malicious protocols inside the test infrastructure
  • Loading branch information
akoshelev authored Oct 4, 2024
2 parents 27de807 + 824ee6d commit b56e21c
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 39 deletions.
13 changes: 7 additions & 6 deletions ipa-core/src/protocol/basics/mul/malicious.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use crate::{
malicious::{AdditiveShare as MaliciousReplicated, ExtendableFieldSimd},
semi_honest::AdditiveShare as Replicated,
},
sharding::ShardBinding,
};

///
Expand Down Expand Up @@ -49,8 +50,8 @@ use crate::{
/// back via the error response
/// ## Panics
/// Panics if the mutex is found to be poisoned
pub async fn mac_multiply<F, const N: usize>(
ctx: UpgradedMaliciousContext<'_, F>,
pub async fn mac_multiply<F, B: ShardBinding, const N: usize>(
ctx: UpgradedMaliciousContext<'_, F, B>,
record_id: RecordId,
a: &MaliciousReplicated<F, N>,
b: &MaliciousReplicated<F, N>,
Expand Down Expand Up @@ -108,19 +109,19 @@ where

/// Implement secure multiplication for malicious contexts with replicated secret sharing.
#[async_trait]
impl<'a, F: ExtendableFieldSimd<N>, const N: usize> SecureMul<UpgradedMaliciousContext<'a, F>>
for MaliciousReplicated<F, N>
impl<'a, F: ExtendableFieldSimd<N>, B: ShardBinding, const N: usize>
SecureMul<UpgradedMaliciousContext<'a, F, B>> for MaliciousReplicated<F, N>
where
Replicated<F::ExtendedField, N>: FromPrss,
{
async fn multiply<'fut>(
&self,
rhs: &Self,
ctx: UpgradedMaliciousContext<'a, F>,
ctx: UpgradedMaliciousContext<'a, F, B>,
record_id: RecordId,
) -> Result<Self, Error>
where
UpgradedMaliciousContext<'a, F>: 'fut,
UpgradedMaliciousContext<'a, F, B>: 'fut,
{
mac_multiply(ctx, record_id, self, rhs).await
}
Expand Down
32 changes: 18 additions & 14 deletions ipa-core/src/protocol/context/malicious.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,8 +323,10 @@ impl<'a, F: ExtendableField, B: ShardBinding> SeqJoin for Upgraded<'a, F, B> {
/// protocols should be generic over `SecretShare` trait and not requiring this cast and taking
/// `ProtocolContext<'a, S: SecretShare<F>, F: Field>` as the context. If that is not possible,
/// this implementation makes it easier to reinterpret the context as semi-honest.
impl<'a, F: ExtendableField> SpecialAccessToUpgradedContext<F> for Upgraded<'a, F, NotSharded> {
type Base = Base<'a>;
impl<'a, F: ExtendableField, B: ShardBinding> SpecialAccessToUpgradedContext<F>
for Upgraded<'a, F, B>
{
type Base = Base<'a, B>;

fn base_context(self) -> Self::Base {
self.base_ctx.inner
Expand All @@ -340,7 +342,7 @@ impl<F: ExtendableField, B: ShardBinding> Debug for Upgraded<'_, F, B> {
/// Upgrading a semi-honest replicated share using malicious context produces
/// a MAC-secured share with the same vectorization factor.
#[async_trait]
impl<'a, V: ExtendableFieldSimd<N>, const N: usize> Upgradable<Upgraded<'a, V, NotSharded>>
impl<'a, V: ExtendableFieldSimd<N>, B: ShardBinding, const N: usize> Upgradable<Upgraded<'a, V, B>>
for Replicated<V, N>
where
Replicated<<V as ExtendableField>::ExtendedField, N>: FromPrss,
Expand All @@ -349,7 +351,7 @@ where

async fn upgrade(
self,
ctx: Upgraded<'a, V, NotSharded>,
ctx: Upgraded<'a, V, B>,
record_id: RecordId,
) -> Result<Self::Output, Error> {
let ctx = ctx.narrow(&UpgradeStep);
Expand Down Expand Up @@ -383,7 +385,7 @@ where
#[cfg(all(test, descriptive_gate))]
#[async_trait]
impl<'a, V: ExtendableFieldSimd<N>, const N: usize> Upgradable<Upgraded<'a, V, NotSharded>>
impl<'a, V: ExtendableFieldSimd<N>, B: ShardBinding, const N: usize> Upgradable<Upgraded<'a, V, B>>
for (Replicated<V, N>, Replicated<V, N>)
where
Replicated<<V as ExtendableField>::ExtendedField, N>: FromPrss,
Expand All @@ -392,7 +394,7 @@ where

async fn upgrade(
self,
ctx: Upgraded<'a, V, NotSharded>,
ctx: Upgraded<'a, V, B>,
record_id: RecordId,
) -> Result<Self::Output, Error> {
let (l, r) = self;
Expand All @@ -404,12 +406,12 @@ where

#[cfg(all(test, descriptive_gate))]
#[async_trait]
impl<'a, V: ExtendableField> Upgradable<Upgraded<'a, V, NotSharded>> for () {
impl<'a, V: ExtendableField, B: ShardBinding> Upgradable<Upgraded<'a, V, B>> for () {
type Output = ();

async fn upgrade(
self,
_context: Upgraded<'a, V, NotSharded>,
_context: Upgraded<'a, V, B>,
_record_id: RecordId,
) -> Result<Self::Output, Error> {
Ok(())
Expand All @@ -418,28 +420,30 @@ impl<'a, V: ExtendableField> Upgradable<Upgraded<'a, V, NotSharded>> for () {

#[cfg(all(test, descriptive_gate))]
#[async_trait]
impl<'a, V, U> Upgradable<Upgraded<'a, V, NotSharded>> for Vec<U>
impl<'a, V, U, B> Upgradable<Upgraded<'a, V, B>> for Vec<U>
where
V: ExtendableField,
U: Upgradable<Upgraded<'a, V, NotSharded>, Output: Send> + Send + 'a,
U: Upgradable<Upgraded<'a, V, B>, Output: Send> + Send + 'a,
B: ShardBinding,
{
type Output = Vec<U::Output>;

async fn upgrade(
self,
ctx: Upgraded<'a, V, NotSharded>,
ctx: Upgraded<'a, V, B>,
record_id: RecordId,
) -> Result<Self::Output, Error> {
/// Need a standalone function to avoid GAT issue that apparently can manifest
/// even with `async_trait`.
fn upgrade_vec<'a, V, U>(
ctx: Upgraded<'a, V, NotSharded>,
fn upgrade_vec<'a, V, U, B>(
ctx: Upgraded<'a, V, B>,
record_id: RecordId,
input: Vec<U>,
) -> impl std::future::Future<Output = Result<Vec<U::Output>, Error>> + 'a
where
V: ExtendableField,
U: Upgradable<Upgraded<'a, V, NotSharded>> + 'a,
U: Upgradable<Upgraded<'a, V, B>> + 'a,
B: ShardBinding,
{
let mut upgraded = Vec::with_capacity(input.len());
async move {
Expand Down
1 change: 1 addition & 0 deletions ipa-core/src/protocol/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub type SemiHonestContext<'a, B = NotSharded> = semi_honest::Context<'a, B>;
pub type ShardedSemiHonestContext<'a> = semi_honest::Context<'a, Sharded>;

pub type MaliciousContext<'a, B = NotSharded> = malicious::Context<'a, B>;
pub type ShardedMaliciousContext<'a> = malicious::Context<'a, Sharded>;
pub type UpgradedMaliciousContext<'a, F, B = NotSharded> = malicious::Upgraded<'a, F, B>;

#[cfg(all(feature = "in-memory-infra", any(test, feature = "test-fixture")))]
Expand Down
6 changes: 3 additions & 3 deletions ipa-core/src/protocol/context/semi_honest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,14 +302,14 @@ impl<B: ShardBinding, F: ExtendableField> Debug for Upgraded<'_, B, F> {
}

#[async_trait]
impl<'a, V: ExtendableField + Vectorizable<N>, const N: usize>
Upgradable<Upgraded<'a, NotSharded, V>> for Replicated<V, N>
impl<'a, V: ExtendableField + Vectorizable<N>, B: ShardBinding, const N: usize>
Upgradable<Upgraded<'a, B, V>> for Replicated<V, N>
{
type Output = Replicated<V, N>;

async fn upgrade(
self,
_context: Upgraded<'a, NotSharded, V>,
_context: Upgraded<'a, B, V>,
_record_id: RecordId,
) -> Result<Self::Output, Error> {
Ok(self)
Expand Down
136 changes: 120 additions & 16 deletions ipa-core/src/test_fixture/world.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ use crate::{
context::{
dzkp_validator::DZKPValidator, upgrade::Upgradable, Context,
DZKPUpgradedMaliciousContext, MaliciousContext, SemiHonestContext,
ShardedSemiHonestContext, UpgradableContext, UpgradedContext, UpgradedMaliciousContext,
UpgradedSemiHonestContext, Validator, TEST_DZKP_STEPS,
ShardedMaliciousContext, ShardedSemiHonestContext, UpgradableContext, UpgradedContext,
UpgradedMaliciousContext, UpgradedSemiHonestContext, Validator, TEST_DZKP_STEPS,
},
prss::Endpoint as PrssEndpoint,
Gate, QueryId, RecordId,
Expand Down Expand Up @@ -369,6 +369,10 @@ where
pub trait Runner<S: ShardingScheme> {
/// This could be also derived from [`S`], but maybe that's too much for that trait.
type SemiHonestContext<'ctx>: Context;
/// The type of context used to run protocols that are secure against
/// active adversaries. It varies depending on whether sharding is used or not.
type MaliciousContext<'ctx>: Context;

/// Run with a context that can be upgraded, but is only good for semi-honest.
async fn semi_honest<'a, I, A, O, H, R>(
&'a self,
Expand Down Expand Up @@ -396,12 +400,12 @@ pub trait Runner<S: ShardingScheme> {
R: Future<Output = O> + Send;

/// Run with a context that can be upgraded to malicious.
async fn malicious<'a, I, A, O, H, R>(&'a self, input: I, helper_fn: H) -> [O; 3]
async fn malicious<'a, I, A, O, H, R>(&'a self, input: I, helper_fn: H) -> S::Container<[O; 3]>
where
I: IntoShares<A> + Send + 'static,
I: RunnerInput<S, A>,
A: Send,
O: Send + Debug,
H: Fn(MaliciousContext<'a>, A) -> R + Send + Sync,
H: Fn(Self::MaliciousContext<'a>, S::Container<A>) -> R + Send + Sync,
R: Future<Output = O> + Send;

/// Run with a context that has already been upgraded to malicious.
Expand Down Expand Up @@ -444,6 +448,7 @@ impl<const SHARDS: usize, D: Distribute> Runner<WithShards<SHARDS, D>>
for TestWorld<WithShards<SHARDS, D>>
{
type SemiHonestContext<'ctx> = ShardedSemiHonestContext<'ctx>;
type MaliciousContext<'ctx> = ShardedMaliciousContext<'ctx>;
async fn semi_honest<'a, I, A, O, H, R>(&'a self, input: I, helper_fn: H) -> Vec<[O; 3]>
where
I: RunnerInput<WithShards<SHARDS, D>, A>,
Expand Down Expand Up @@ -494,15 +499,39 @@ impl<const SHARDS: usize, D: Distribute> Runner<WithShards<SHARDS, D>>
unimplemented!()
}

async fn malicious<'a, I, A, O, H, R>(&'a self, _input: I, _helper_fn: H) -> [O; 3]
async fn malicious<'a, I, A, O, H, R>(&'a self, input: I, helper_fn: H) -> Vec<[O; 3]>
where
I: IntoShares<A> + Send + 'static,
I: RunnerInput<WithShards<SHARDS, D>, A>,
A: Send,
O: Send + Debug,
H: Fn(MaliciousContext<'a>, A) -> R + Send + Sync,
H: Fn(
Self::MaliciousContext<'a>,
<WithShards<SHARDS> as ShardingScheme>::Container<A>,
) -> R
+ Send
+ Sync,
R: Future<Output = O> + Send,
{
unimplemented!()
let shards = self.shards();
let [h1, h2, h3]: [[Vec<A>; SHARDS]; 3] = input.share().map(D::distribute);
let gate = self.next_gate();
// todo!()

// No clippy, you're wrong, it is not redundant, it allows shard_fn to be `Copy`
#[allow(clippy::redundant_closure)]
let shard_fn = |ctx, input| helper_fn(ctx, input);
zip(shards.into_iter(), zip(zip(h1, h2), h3))
.map(|(shard, ((h1, h2), h3))| {
ShardWorld::<Sharded>::run_either(
shard.malicious_contexts(&gate),
self.metrics_handle.span(),
[h1, h2, h3],
shard_fn,
)
})
.collect::<FuturesOrdered<_>>()
.collect::<Vec<_>>()
.await
}

async fn upgraded_malicious<'a, F, I, A, M, O, H, R, P>(
Expand Down Expand Up @@ -541,6 +570,7 @@ impl<const SHARDS: usize, D: Distribute> Runner<WithShards<SHARDS, D>>
#[async_trait]
impl Runner<NotSharded> for TestWorld<NotSharded> {
type SemiHonestContext<'ctx> = SemiHonestContext<'ctx>;
type MaliciousContext<'ctx> = MaliciousContext<'ctx>;

async fn semi_honest<'a, I, A, O, H, R>(&'a self, input: I, helper_fn: H) -> [O; 3]
where
Expand Down Expand Up @@ -583,10 +613,10 @@ impl Runner<NotSharded> for TestWorld<NotSharded> {

async fn malicious<'a, I, A, O, H, R>(&'a self, input: I, helper_fn: H) -> [O; 3]
where
I: IntoShares<A> + Send + 'static,
I: RunnerInput<NotSharded, A>,
A: Send,
O: Send + Debug,
H: Fn(MaliciousContext<'a>, A) -> R + Send + Sync,
H: Fn(Self::MaliciousContext<'a>, A) -> R + Send + Sync,
R: Future<Output = O> + Send,
{
ShardWorld::<NotSharded>::run_either(
Expand Down Expand Up @@ -778,9 +808,14 @@ impl<B: ShardBinding> ShardWorld<B> {
/// # Panics
/// Panics if world has more or less than 3 gateways/participants
#[must_use]
pub fn malicious_contexts(&self, gate: &Gate) -> [MaliciousContext<'_>; 3] {
pub fn malicious_contexts(&self, gate: &Gate) -> [MaliciousContext<'_, B>; 3] {
zip3_ref(&self.participants, &self.gateways).map(|(participant, gateway)| {
MaliciousContext::new_with_gate(participant, gateway, gate.clone(), NotSharded)
MaliciousContext::new_with_gate(
participant,
gateway,
gate.clone(),
self.shard_info.clone(),
)
})
}
}
Expand Down Expand Up @@ -826,12 +861,20 @@ mod tests {
use futures_util::future::try_join4;

use crate::{
ff::{boolean_array::BA3, Field, Fp31, U128Conversions},
ff::{boolean::Boolean, boolean_array::BA3, Field, Fp31, U128Conversions},
helpers::{
in_memory_config::{MaliciousHelper, MaliciousHelperContext},
Direction, Role,
Direction, Role, TotalRecords,
},
protocol::{
basics::SecureMul,
context::{
dzkp_validator::DZKPValidator, upgrade::Upgradable, Context, UpgradableContext,
UpgradedContext, Validator, TEST_DZKP_STEPS,
},
prss::SharedRandomness,
RecordId,
},
protocol::{context::Context, prss::SharedRandomness, RecordId},
secret_sharing::{
replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing},
SharedValue,
Expand Down Expand Up @@ -961,4 +1004,65 @@ mod tests {
assert_eq!(shares[1].right(), shares[2].left());
});
}

#[test]
fn zkp_malicious_sharded() {
run(|| async {
let world: TestWorld<WithShards<2>> =
TestWorld::with_shards(TestWorldConfig::default());
let input = vec![Boolean::truncate_from(0_u32), Boolean::truncate_from(1_u32)];
let r = world
.malicious(input.clone().into_iter(), |ctx, input| async move {
assert_eq!(1, input.iter().len());
let ctx = ctx.set_total_records(TotalRecords::ONE);
let validator = ctx.dzkp_validator(TEST_DZKP_STEPS, 1);
let ctx = validator.context();
let r = input[0]
.multiply(&input[0], ctx, RecordId::FIRST)
.await
.unwrap();
validator.validate().await.unwrap();

vec![r]
})
.await
.into_iter()
.flat_map(|v| v.reconstruct())
.collect::<Vec<_>>();

assert_eq!(input, r);
});
}

#[test]
fn mac_malicious_sharded() {
run(|| async {
let world: TestWorld<WithShards<2>> =
TestWorld::with_shards(TestWorldConfig::default());
let input = vec![Fp31::truncate_from(0_u32), Fp31::truncate_from(1_u32)];
let r = world
.malicious(input.clone().into_iter(), |ctx, input| async move {
assert_eq!(1, input.iter().len());
let validator = ctx.set_total_records(1).validator();
let ctx = validator.context();
let (a_upgraded, b_upgraded) = (input[0].clone(), input[0].clone())
.upgrade(ctx.clone(), RecordId::FIRST)
.await
.unwrap();
let _ = a_upgraded
.multiply(&b_upgraded, ctx.narrow("multiply"), RecordId::FIRST)
.await
.unwrap();
ctx.validate_record(RecordId::FIRST).await.unwrap();

input
})
.await
.into_iter()
.flat_map(|v| v.reconstruct())
.collect::<Vec<_>>();

assert_eq!(input, r);
});
}
}

0 comments on commit b56e21c

Please sign in to comment.