Skip to content

Commit

Permalink
Merge pull request #1053 from benjaminsavage/remove_sparse_multiply
Browse files Browse the repository at this point in the history
Remove sparse multiply
  • Loading branch information
benjaminsavage authored May 13, 2024
2 parents a6924d8 + b766bed commit 498e4ec
Show file tree
Hide file tree
Showing 11 changed files with 41 additions and 567 deletions.
12 changes: 3 additions & 9 deletions ipa-core/src/protocol/basics/check_zero.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::{
error::Error,
ff::Field,
protocol::{
basics::{mul::semi_honest_multiply, reveal::Reveal, ZeroPositions},
basics::{mul::semi_honest_multiply, reveal::Reveal},
context::Context,
prss::{FromRandom, SharedRandomness},
RecordId,
Expand Down Expand Up @@ -54,14 +54,8 @@ where
{
let r_sharing: Replicated<F> = ctx.prss().generate(record_id);

let rv_share = semi_honest_multiply(
ctx.narrow(&Step::MultiplyWithR),
record_id,
&r_sharing,
v,
ZeroPositions::NONE,
)
.await?;
let rv_share =
semi_honest_multiply(ctx.narrow(&Step::MultiplyWithR), record_id, &r_sharing, v).await?;
let rv = F::from_array(
&rv_share
.reveal(ctx.narrow(&Step::RevealR), record_id)
Expand Down
2 changes: 1 addition & 1 deletion ipa-core/src/protocol/basics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use std::ops::Not;
#[cfg(feature = "descriptive-gate")]
pub use check_zero::check_zero;
pub use if_else::{if_else, select};
pub use mul::{BooleanArrayMul, MultiplyZeroPositions, SecureMul, ZeroPositions};
pub use mul::{BooleanArrayMul, SecureMul};
pub use reshare::Reshare;
pub use reveal::{partial_reveal, reveal, Reveal};
pub use share_known_value::ShareKnownValue;
Expand Down
8 changes: 3 additions & 5 deletions ipa-core/src/protocol/basics/mul/dzkp_malicious.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::{
error::Error,
ff::Field,
protocol::{
basics::{mul::semi_honest::multiplication_protocol, MultiplyZeroPositions, SecureMul},
basics::{mul::semi_honest::multiplication_protocol, SecureMul},
context::{
dzkp_field::DZKPCompatibleField, dzkp_validator::Segment, Context, DZKPContext,
DZKPUpgradedMaliciousContext,
Expand Down Expand Up @@ -32,7 +32,6 @@ pub async fn multiply<'a, F, const N: usize>(
record_id: RecordId,
a: &Replicated<F, N>,
b: &Replicated<F, N>,
zeros: MultiplyZeroPositions,
) -> Result<Replicated<F, N>, Error>
where
F: Field + DZKPCompatibleField<N>,
Expand All @@ -42,7 +41,7 @@ where
.prss()
.generate::<(<F as Vectorizable<N>>::Array, _), _>(record_id);

let z = multiplication_protocol(&ctx, record_id, a, b, &prss_left, &prss_right, zeros).await?;
let z = multiplication_protocol(&ctx, record_id, a, b, &prss_left, &prss_right).await?;

// create segment
let segment = Segment::from_entries(
Expand Down Expand Up @@ -71,12 +70,11 @@ impl<'a, F: Field + DZKPCompatibleField<N>, const N: usize>
rhs: &Self,
ctx: DZKPUpgradedMaliciousContext<'a>,
record_id: RecordId,
zeros_at: MultiplyZeroPositions,
) -> Result<Self, Error>
where
DZKPUpgradedMaliciousContext<'a>: 'fut,
{
multiply(ctx, record_id, self, rhs, zeros_at).await
multiply(ctx, record_id, self, rhs).await
}
}

Expand Down
8 changes: 2 additions & 6 deletions ipa-core/src/protocol/basics/mul/malicious.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use ipa_macros::Step;
use crate::{
error::Error,
protocol::{
basics::{mul::semi_honest_multiply, MultiplyZeroPositions, SecureMul, ZeroPositions},
basics::{mul::semi_honest_multiply, SecureMul},
context::{Context, UpgradedMaliciousContext},
RecordId,
},
Expand Down Expand Up @@ -59,7 +59,6 @@ pub async fn multiply<F>(
record_id: RecordId,
a: &MaliciousReplicated<F>,
b: &MaliciousReplicated<F>,
zeros_at: MultiplyZeroPositions,
) -> Result<MaliciousReplicated<F>, Error>
where
F: ExtendableField,
Expand Down Expand Up @@ -95,14 +94,12 @@ where
record_id,
a.x().access_without_downgrade(),
b_x,
zeros_at,
),
semi_honest_multiply(
duplicate_multiply_ctx.base_context(),
record_id,
a.rx(),
&b_induced_share,
(ZeroPositions::Pvvv, zeros_at.1),
),
)
.await?;
Expand All @@ -121,12 +118,11 @@ impl<'a, F: ExtendableField> SecureMul<UpgradedMaliciousContext<'a, F>> for Mali
rhs: &Self,
ctx: UpgradedMaliciousContext<'a, F>,
record_id: RecordId,
zeros_at: MultiplyZeroPositions,
) -> Result<Self, Error>
where
UpgradedMaliciousContext<'a, F>: 'fut,
{
multiply(ctx, record_id, self, rhs, zeros_at).await
multiply(ctx, record_id, self, rhs).await
}
}

Expand Down
8 changes: 2 additions & 6 deletions ipa-core/src/protocol/basics/mul/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,9 @@ mod dzkp_malicious;
#[cfg(feature = "descriptive-gate")]
pub(crate) mod malicious;
mod semi_honest;
pub(in crate::protocol) mod sparse;

#[cfg(feature = "descriptive-gate")]
pub use semi_honest::multiply as semi_honest_multiply;
pub use sparse::{MultiplyZeroPositions, ZeroPositions};

/// Trait to multiply secret shares. That requires communication and `multiply` function is async.
#[async_trait]
Expand All @@ -34,8 +32,7 @@ pub trait SecureMul<C: Context>: Send + Sync + Sized {
where
C: 'fut,
{
self.multiply_sparse(rhs, ctx, record_id, ZeroPositions::NONE)
.await
self.multiply_sparse(rhs, ctx, record_id).await
}

/// Multiply and return the result of `a` * `b`.
Expand All @@ -48,7 +45,6 @@ pub trait SecureMul<C: Context>: Send + Sync + Sized {
rhs: &Self,
ctx: C,
record_id: RecordId,
zeros_at: MultiplyZeroPositions,
) -> Result<Self, Error>
where
C: 'fut;
Expand Down Expand Up @@ -109,7 +105,7 @@ macro_rules! boolean_array_mul {
where
C: Context + 'fut,
{
semi_honest_mul(ctx, record_id, a, b, ZeroPositions::NONE)
semi_honest_mul(ctx, record_id, a, b)
}
}
};
Expand Down
79 changes: 21 additions & 58 deletions ipa-core/src/protocol/basics/mul/semi_honest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use crate::{
ff::{Field, PrimeField},
helpers::Direction,
protocol::{
basics::{mul::sparse::MultiplyWork, MultiplyZeroPositions},
context::{
dzkp_semi_honest::DZKPUpgraded as SemiHonestDZKPUpgraded,
semi_honest::{Context as SemiHonestContext, Upgraded as UpgradedSemiHonestContext},
Expand All @@ -15,8 +14,7 @@ use crate::{
RecordId,
},
secret_sharing::{
replicated::semi_honest::AdditiveShare as Replicated, FieldSimd, SharedValueArray,
Vectorizable,
replicated::semi_honest::AdditiveShare as Replicated, FieldSimd, Vectorizable,
},
sharding,
};
Expand All @@ -34,7 +32,6 @@ pub async fn multiply<C, F, const N: usize>(
record_id: RecordId,
a: &Replicated<F, N>,
b: &Replicated<F, N>,
zeros: MultiplyZeroPositions,
) -> Result<Replicated<F, N>, Error>
where
C: Context,
Expand All @@ -46,7 +43,7 @@ where
.prss()
.generate::<(<F as Vectorizable<N>>::Array, _), _>(record_id);

multiplication_protocol(&ctx, record_id, a, b, &prss_left, &prss_right, zeros).await
multiplication_protocol(&ctx, record_id, a, b, &prss_left, &prss_right).await
}

/// IKHC multiplication protocol
Expand All @@ -69,54 +66,29 @@ pub async fn multiplication_protocol<C, F, const N: usize>(
b: &Replicated<F, N>,
prss_left: &<F as Vectorizable<N>>::Array,
prss_right: &<F as Vectorizable<N>>::Array,
zeros: MultiplyZeroPositions,
) -> Result<Replicated<F, N>, Error>
where
C: Context,
F: Field + FieldSimd<N>,
{
let role = ctx.role();
let [need_to_recv, need_to_send, need_random_right] = zeros.work_for(role);
zeros.0.check(role, "a", a);
zeros.1.check(role, "b", b);

let mut rhs = a.right_arr().clone() * b.right_arr();

if need_to_send {
// Compute the value (d_i) we want to send to the right helper (i+1).
let right_d =
a.left_arr().clone() * b.right_arr() + a.right_arr().clone() * b.left_arr() - prss_left;

ctx.send_channel::<<F as Vectorizable<N>>::Array>(role.peer(Direction::Right))
.send(record_id, &right_d)
.await?;
rhs += right_d;
} else {
debug_assert_eq!(
a.left_arr().clone() * b.right_arr() + a.right_arr().clone() * b.left_arr(),
<<F as Vectorizable<N>>::Array as SharedValueArray<F>>::ZERO_ARRAY
);
}
// Add randomness to this value whether we sent or not, depending on whether the
// peer to the right needed to send. If they send, they subtract randomness,
// and we need to add to our share to compensate.
if need_random_right {
rhs += prss_right;
}

// Compute the value (d_i) we want to send to the right helper (i+1).
let right_d =
a.left_arr().clone() * b.right_arr() + a.right_arr().clone() * b.left_arr() - prss_left;

ctx.send_channel::<<F as Vectorizable<N>>::Array>(role.peer(Direction::Right))
.send(record_id, &right_d)
.await?;

let rhs = a.right_arr().clone() * b.right_arr() + right_d + prss_right;

// Sleep until helper on the left sends us their (d_i-1) value.
let mut lhs = a.left_arr().clone() * b.left_arr();
if need_to_recv {
let left_d: <F as Vectorizable<N>>::Array = ctx
.recv_channel(role.peer(Direction::Left))
.receive(record_id)
.await?;
lhs += left_d;
}
// If we send, we subtract randomness, so we need to add to our share.
if need_to_send {
lhs += prss_left;
}
let left_d: <F as Vectorizable<N>>::Array = ctx
.recv_channel(role.peer(Direction::Left))
.receive(record_id)
.await?;
let lhs = a.left_arr().clone() * b.left_arr() + left_d + prss_left;

Ok(Replicated::new_arr(lhs, rhs))
}
Expand All @@ -138,12 +110,11 @@ where
rhs: &Self,
ctx: SemiHonestContext<'a, B>,
record_id: RecordId,
zeros_at: MultiplyZeroPositions,
) -> Result<Self, Error>
where
SemiHonestContext<'a, B>: 'fut,
{
multiply(ctx, record_id, self, rhs, zeros_at).await
multiply(ctx, record_id, self, rhs).await
}
}

Expand All @@ -160,12 +131,11 @@ where
rhs: &Self,
ctx: UpgradedSemiHonestContext<'a, B, F>,
record_id: RecordId,
zeros_at: MultiplyZeroPositions,
) -> Result<Self, Error>
where
UpgradedSemiHonestContext<'a, B, F>: 'fut,
{
multiply(ctx, record_id, self, rhs, zeros_at).await
multiply(ctx, record_id, self, rhs).await
}
}

Expand All @@ -181,12 +151,11 @@ where
rhs: &Self,
ctx: SemiHonestDZKPUpgraded<'a, B>,
record_id: RecordId,
zeros_at: MultiplyZeroPositions,
) -> Result<Self, Error>
where
SemiHonestDZKPUpgraded<'a, B>: 'fut,
{
multiply(ctx, record_id, self, rhs, zeros_at).await
multiply(ctx, record_id, self, rhs).await
}
}

Expand All @@ -204,11 +173,7 @@ mod test {
use crate::{
ff::{Field, Fp31, Fp32BitPrime, U128Conversions},
helpers::TotalRecords,
protocol::{
basics::{SecureMul, ZeroPositions},
context::Context,
RecordId,
},
protocol::{basics::SecureMul, context::Context, RecordId},
rand::{thread_rng, Rng},
secret_sharing::replicated::semi_honest::AdditiveShare,
seq_join::SeqJoin,
Expand Down Expand Up @@ -323,7 +288,6 @@ mod test {
RecordId::from(0),
&a_shares,
&b_shares,
ZeroPositions::NONE,
)
.await
.unwrap()
Expand Down Expand Up @@ -423,7 +387,6 @@ mod test {
RecordId::from(i - 1),
&val,
iter.next().unwrap(),
ZeroPositions::NONE,
)
.await
.unwrap();
Expand Down
Loading

0 comments on commit 498e4ec

Please sign in to comment.