Skip to content

Commit

Permalink
Merge pull request #784 from martinthomson/bitwise_to_onehot
Browse files Browse the repository at this point in the history
Comments, tweak, test, rename for bitwise_to_onehot
  • Loading branch information
martinthomson authored Sep 11, 2023
2 parents 73a2342 + 88368eb commit 6c60df5
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 49 deletions.
4 changes: 2 additions & 2 deletions src/protocol/attribution/aggregate_credit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::{
protocol::{
context::{UpgradableContext, UpgradedContext, Validator},
modulus_conversion::convert_bits,
sort::{check_everything, generate_permutation::ShuffledPermutationWrapper},
sort::{bitwise_to_onehot, generate_permutation::ShuffledPermutationWrapper},
step::BitOpStep,
BasicProtocols, RecordId,
},
Expand Down Expand Up @@ -114,7 +114,7 @@ where
let ceq = &equality_check_context;
let cmul = &check_times_credit_context;
async move {
let equality_checks = check_everything(ceq.clone(), i, &bk?).await?;
let equality_checks = bitwise_to_onehot(ceq.clone(), i, &bk?).await?;
ceq.try_join(equality_checks.into_iter().take(to_take).enumerate().map(
|(check_idx, check)| {
let step = BitOpStep::from(check_idx);
Expand Down
152 changes: 109 additions & 43 deletions src/protocol/sort/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,28 +82,29 @@ pub(crate) enum ReshareStep {
ReshareRx,
}

/// Convert a bitwise representation of a number into a one-hot encoding of that number.
/// That is, an array of value of 1 at the index corresponding to the value of the number,
/// and a 0 at all other indices.
///
/// This function accepts a sequence of N secret-shared bits.
/// When considered as a bitwise representation of an N-bit unsigned number, it's clear that there are exactly
/// This function accepts a sequence of N secret-shared bits, with the least significant bit at index 0.
/// When considered as a bitwise representation of an N-bit unsigned number, there are exactly
/// `2^N` possible values this could have.
/// This function checks all of these possible values, and returns a vector of secret-shared results.
/// Only one result will be a secret-sharing of one, all of the others will be secret-sharings of zero.
///
/// # Errors
/// If any multiplication fails, or if the record is too long (e.g. more than 64 multiplications required)
pub async fn check_everything<F, C, S>(
pub async fn bitwise_to_onehot<F, C, S>(
ctx: C,
record_idx: usize,
record: &[S],
number: &[S],
) -> Result<BitDecomposed<S>, Error>
where
F: Field,
C: Context,
S: LinearSecretSharing<F> + BasicProtocols<C, F>,
{
let num_bits = record.len();
let num_bits = number.len();
let precomputed_combinations =
pregenerate_all_combinations(ctx, record_idx, record, num_bits).await?;
generate_all_combinations(ctx, record_idx, number, num_bits).await?;

// This loop just iterates over all the possible values this N-bit input could potentially represent
// and checks if the bits are equal to this value. It does so by computing a linear combination of the
Expand All @@ -114,40 +115,49 @@ where
// https://en.wikipedia.org/wiki/Sierpi%C5%84ski_triangle#/media/File:Multigrade_operator_AND.svg.
//
// For example, for a three bit value, we have the following:
// 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1
// 0 | 1 | 0 | 1 | 0 | 1 | 0 | 1
// 0 | 0 | 1 | 1 | 0 | 0 | 1 | 1
// 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1
// 0 | 0 | 0 | 0 | 1 | 1 | 1 | 1
// 0 | 0 | 0 | 0 | 0 | 1 | 0 | 1
// 0 | 0 | 0 | 0 | 0 | 0 | 1 | 1
// 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1
// 0: 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1
// 1: 0 | 1 | 0 | 1 | 0 | 1 | 0 | 1
// 2: 0 | 0 | 1 | 1 | 0 | 0 | 1 | 1
// 3: 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1
// 4: 0 | 0 | 0 | 0 | 1 | 1 | 1 | 1
// 5: 0 | 0 | 0 | 0 | 0 | 1 | 0 | 1
// 6: 0 | 0 | 0 | 0 | 0 | 0 | 1 | 1
// 7: 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1
//
// This can be computed from row (i) and column (j) indices with i & j == i
//
// The sign of the inclusion is less obvious, but we discovered that this
// can be found by taking the same row (i) and column (j) indices:
// 1. Invert the row index and bitwise AND the values: a = !i & j
// 1. Invert the row index and bitwise AND the indices: a = !i & j
// 2. Count the number of bits that are set: b = a.count_ones()
// 3. An odd number means a positive coefficient; an odd number means a negative.
//
// For example, for a three bit value, step 1 produces (in binary):
// 000 | 001 | 010 | 011 | 100 | 101 | 110 | 111
// 000 | 000 | 010 | 010 | 100 | 100 | 110 | 110
// 000 | 001 | 000 | 001 | 100 | 101 | 100 | 101
// 000 | 000 | 000 | 000 | 100 | 100 | 100 | 100
// 000 | 001 | 010 | 011 | 000 | 001 | 010 | 011
// 000 | 000 | 010 | 010 | 000 | 000 | 010 | 010
// 000 | 001 | 000 | 001 | 000 | 001 | 000 | 001
// 000 | 000 | 000 | 000 | 000 | 000 | 000 | 000
// 0: 000 | 001 | 010 | 011 | 100 | 101 | 110 | 111
// 1: 000 | 000 | 010 | 010 | 100 | 100 | 110 | 110
// 2: 000 | 001 | 000 | 001 | 100 | 101 | 100 | 101
// 3: 000 | 000 | 000 | 000 | 100 | 100 | 100 | 100
// 4: 000 | 001 | 010 | 011 | 000 | 001 | 010 | 011
// 5: 000 | 000 | 010 | 010 | 000 | 000 | 010 | 010
// 6: 000 | 001 | 000 | 001 | 000 | 001 | 000 | 001
// 7: 000 | 000 | 000 | 000 | 000 | 000 | 000 | 000
//
// Where 000, 101, 011, and 110 mean positive contributions, and
// 001, 010, 100, and 111 mean negative contributions.
//
// 0: + | - | - | + | - | + | + | -
// 1: . | + | . | - | . | - | . | +
// 2: . | . | + | - | . | . | - | +
// 3: . | . | . | + | . | . | . | -
// 4: . | . | . | . | + | - | - | +
// 5: . | . | . | . | . | + | . | -
// 6: . | . | . | . | . | . | + | -
// 7: . | . | . | . | . | . | . | +
Ok(BitDecomposed::decompose(1 << num_bits, |i| {
let mut check = S::ZERO;
for (j, combination) in precomputed_combinations.iter().enumerate() {
let bit: i8 = i8::from((i & j) == i);
if bit > 0 {
// Small optimization: skip the blank area and start with the first "+".
let mut check = precomputed_combinations[i].clone();
for (j, combination) in precomputed_combinations.iter().enumerate().skip(i + 1) {
if (i & j) == i {
if (!i & j).count_ones() & 1 == 1 {
check -= combination;
} else {
Expand Down Expand Up @@ -184,7 +194,7 @@ where
// Operation complexity of this function is `2^n-n-1` where `n` is the number of bits.
// Circuit depth is equal to `n-2`.
// This gets inefficient very quickly as a result.
async fn pregenerate_all_combinations<F, C, S>(
async fn generate_all_combinations<F, C, S>(
ctx: C,
record_idx: usize,
input: &[S],
Expand All @@ -196,25 +206,81 @@ where
S: SecretSharing<F> + BasicProtocols<C, F>,
{
let record_id = RecordId::from(record_idx);
let mut precomputed_combinations = Vec::with_capacity(1 << num_bits);
precomputed_combinations.push(S::share_known_value(&ctx, F::ONE));
let mut all_combinations = Vec::with_capacity(1 << num_bits);
all_combinations.push(S::share_known_value(&ctx, F::ONE));
for (bit_idx, bit) in input.iter().enumerate() {
let step = 1 << bit_idx;
// Concurrency needed here because we are operating on different bits for the same record.
let mut multiplication_results = ctx
.parallel_join(precomputed_combinations.iter().skip(1).enumerate().map(
|(j, precomputed_combination)| {
let mut multiplication_results =
ctx.parallel_join(all_combinations.iter().skip(1).enumerate().map(
|(j, combination)| {
let child_idx = j + step;
precomputed_combination.multiply(
bit,
ctx.narrow(&BitOpStep::from(child_idx)),
record_id,
)
combination.multiply(bit, ctx.narrow(&BitOpStep::from(child_idx)), record_id)
},
))
.await?;
precomputed_combinations.push(bit.clone());
precomputed_combinations.append(&mut multiplication_results);
all_combinations.push(bit.clone());
all_combinations.append(&mut multiplication_results);
}
Ok(all_combinations)
}

#[cfg(all(test, unit_test))]
mod test {
use futures::future::join4;

use crate::{
ff::{Field, Fp31},
protocol::{context::Context, sort::bitwise_to_onehot},
secret_sharing::{BitDecomposed, SharedValue},
seq_join::SeqJoin,
test_fixture::{Reconstruct, Runner, TestWorld},
};

async fn check_onehot(bits: u32) {
let world = TestWorld::default();

// Construct bitwise sharings of all values from 0 to 2^BITS-1.
let input = (0..(1 << bits)).map(move |i| {
BitDecomposed::decompose(bits, |j| {
Fp31::truncate_from(u128::from((i & (1 << j)) == (1 << j)))
})
});

let result = world
.semi_honest(input, |ctx, m_shares| async move {
let ctx = ctx.set_total_records(m_shares.len());
ctx.try_join(
m_shares
.iter()
.enumerate()
.map(|(i, n)| bitwise_to_onehot(ctx.clone(), i, n)),
)
.await
.unwrap()
})
.await
.reconstruct();

for (i, onehot) in result.into_iter().enumerate() {
for (j, v) in onehot.into_iter().enumerate() {
if i == j {
assert_eq!(Fp31::ONE, v);
} else {
assert_eq!(Fp31::ZERO, v);
}
}
}
}

#[tokio::test]
async fn several_onehot() {
_ = join4(
check_onehot(1),
check_onehot(2),
check_onehot(3),
check_onehot(4),
)
.await;
}
Ok(precomputed_combinations)
}
8 changes: 4 additions & 4 deletions src/protocol/sort/multi_bit_permutation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::{
error::Error,
ff::PrimeField,
protocol::{
basics::SumOfProducts, context::UpgradedContext, sort::check_everything, BasicProtocols,
basics::SumOfProducts, context::UpgradedContext, sort::bitwise_to_onehot, BasicProtocols,
RecordId,
},
secret_sharing::{
Expand Down Expand Up @@ -66,7 +66,7 @@ where
.iter()
.zip(repeat(ctx.set_total_records(num_records)))
.enumerate()
.map(|(idx, (record, ctx))| check_everything(ctx, idx, record)),
.map(|(idx, (record, ctx))| bitwise_to_onehot(ctx, idx, record)),
)
.await?;

Expand Down Expand Up @@ -117,7 +117,7 @@ mod tests {
ff::{Field, Fp31},
protocol::{
context::{Context, UpgradableContext, Validator},
sort::check_everything,
sort::bitwise_to_onehot,
},
secret_sharing::{BitDecomposed, SharedValue},
seq_join::SeqJoin,
Expand Down Expand Up @@ -170,7 +170,7 @@ mod tests {
let ctx = ctx.set_total_records(num_records);
let mut equality_check_futures = Vec::with_capacity(num_records);
for (i, record) in m_shares.iter().enumerate() {
equality_check_futures.push(check_everything(ctx.clone(), i, record));
equality_check_futures.push(bitwise_to_onehot(ctx.clone(), i, record));
}
ctx.try_join(equality_check_futures).await.unwrap()
})
Expand Down

0 comments on commit 6c60df5

Please sign in to comment.