Merge pull request #892 from private-attribution/quicksort_batches
Quicksort with friendlier approach for our infra
benjaminsavage authored Dec 13, 2023
2 parents f244371 + 9604c78 commit 4289946
Showing 1 changed file with 238 additions and 53 deletions.
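The gist of the change: instead of recursing partition by partition (narrowing a fresh Left/Right context for every subproblem), the sort now works in breadth-first passes. Each pass partitions every pending range, batches all of that pass's compare-and-reveal operations under a single QuicksortPass(n) step, and collects the still-unsorted sub-ranges for the next pass, so the number of distinct protocol steps grows with the recursion depth (typically around log n) rather than with the number of partitions. Below is a minimal plain-Rust sketch of that control flow; ordinary local comparisons stand in for the protocol's compare_gt + reveal round trip, and the function name is illustrative rather than part of the crate.

use std::ops::Range;

/// Pass-based quicksort sketch: each outer iteration partitions every
/// pending range around its first element, then queues the sub-ranges
/// that are still unsorted. Recursion depth becomes a pass counter.
fn quicksort_ranges_by_passes<T: Ord + Copy>(
    list: &mut [T],
    mut ranges_to_sort: Vec<Range<usize>>,
) {
    while !ranges_to_sort.is_empty() {
        let mut next_pass = Vec::with_capacity(ranges_to_sort.len() * 2);
        for range in &ranges_to_sort {
            if range.len() <= 1 {
                continue;
            }
            let pivot = list[range.start];
            // i is the index of the first element larger than the pivot
            let mut i = range.start + 1;
            for j in (range.start + 1)..range.end {
                // the MPC version batches these comparisons across all
                // ranges in the pass, one compare_gt + reveal per record
                if !(list[j] > pivot) {
                    list.swap(i, j);
                    i += 1;
                }
            }
            // place the pivot between the two partitions
            list.swap(i - 1, range.start);
            // whatever remains on either side goes into the next pass
            if i > range.start + 1 {
                next_pass.push(range.start..(i - 1));
            }
            if i + 1 < range.end {
                next_pass.push(i..range.end);
            }
        }
        ranges_to_sort = next_pass;
    }
}

Calling quicksort_ranges_by_passes(&mut v, vec![0..v.len()]) sorts a whole slice; passing several disjoint ranges, as the new test below does, sorts each range independently. The real implementation performs the element-vs-pivot tests obliviously over secret shares, which is why each pass first counts num_comparisons_needed and registers it as the context's total record count.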
ipa-core/src/protocol/ipa_prf/quicksort.rs (291 changes: 238 additions & 53 deletions)
@@ -1,3 +1,8 @@
+use std::{
+    iter::{repeat, zip},
+    ops::Range,
+};
+
use bitvec::prelude::{BitVec, Lsb0};
use futures::stream::{iter as stream_iter, TryStreamExt};
use ipa_macros::Step;
@@ -15,8 +20,8 @@ use crate::{

#[derive(Step)]
pub(crate) enum Step {
-    Left,
-    Right,
+    #[dynamic(1024)]
+    QuicksortPass(usize),
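+    // dynamic step: up to 1024 distinct pass instances; the pass count tracks
+    // recursion depth (typically ~log n), not the number of partitions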
    Compare,
    Reveal,
}
@@ -41,11 +46,12 @@ pub(crate) enum Step {
/// It terminates once there are no more ranges that need to be sorted.
/// # Errors
/// Will propagate errors from transport and a few typecasts
-pub async fn quicksort_by_key_insecure<C, K, F, S>(
+pub async fn quicksort_ranges_by_key_insecure<C, K, F, S>(
    ctx: C,
    list: &mut [S],
    desc: bool,
    get_key: F,
+    mut ranges_to_sort: Vec<Range<usize>>,
) -> Result<(), Error>
where
    C: Context,
@@ -54,65 +60,93 @@ where
    for<'a> &'a AdditiveShare<K>: IntoIterator<Item = AdditiveShare<K::Element>>,
    K: SharedValue + Field + CustomArray<Element = Boolean>,
{
-    // expected amount of recursion: compute ceil(log_2(list.len()))
-    let expected_depth = std::mem::size_of::<usize>() * 8 - list.len().leading_zeros() as usize;
-    // create stack
-    let mut stack: Vec<(C, usize, usize)> = Vec::with_capacity(expected_depth);
-
-    // initialize stack
-    stack.push((ctx, 0usize, list.len()));
-
-    // iterate through quicksort recursions
-    while let Some((ctx, b_l, b_r)) = stack.pop() {
-        // start of quicksort function
-        // check whether sort is needed
-        if b_l + 1 < b_r {
-            // set up iterator
-            let mut iterator = list[b_l..b_r].iter().map(get_key);
-            // first element is pivot, apply key extraction function f
-            let pivot = iterator.next().unwrap();
-            // create pointer to context for moving into closure
-            let pctx = &(ctx.set_total_records(b_r - (b_l + 1)));
-            // precompute comparison against pivot and reveal result in parallel
-            let comp: BitVec<usize, Lsb0> = seq_join(
-                ctx.active_work(),
-                stream_iter(iterator.enumerate().map(|(n, k)| async move {
-                    // Compare the current element against pivot and reveal the result.
-                    let comparison =
-                        compare_gt(pctx.narrow(&Step::Compare), RecordId::from(n), k, pivot)
-                            .await?
-                            .reveal(pctx.narrow(&Step::Reveal), RecordId::from(n))
-                            .await?;
-
-                    // reveal outcome of comparison
-                    // desc = true will flip the order of the sort
-                    Ok::<_, Error>(Boolean::from(false ^ desc) == comparison)
-                })),
-            )
-            .try_collect()
-            .await?;
+    let mut ranges_for_next_pass = Vec::with_capacity(ranges_to_sort.len() * 2);
+    let mut quicksort_pass = 1;
+
+    // iterate through all of the potentially incorrectly ordered ranges
+    // make one pass, comparing each element to the pivot and splitting into two more
+    // potentially incorrectly ordered ranges
+    while !ranges_to_sort.is_empty() {
+        // compute the total number of comparisons that will be needed
+        let mut num_comparisons_needed = 0;
+        for range in &ranges_to_sort {
+            if range.len() > 1 {
+                num_comparisons_needed += range.len() - 1;
+            }
+        }
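+        // e.g. pending ranges [0..5, 5..6, 6..9] need (5 - 1) + 0 + (3 - 1) = 6
+        // comparisons in this pass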

+        let c = ctx
+            .narrow(&Step::QuicksortPass(quicksort_pass))
+            .set_total_records(num_comparisons_needed);
+        let cmp_ctx = c.narrow(&Step::Compare);
+        let rvl_ctx = c.narrow(&Step::Reveal);
+
+        let comp: BitVec<usize, Lsb0> = seq_join(
+            ctx.active_work(),
+            stream_iter(
+                ranges_to_sort
+                    .iter()
+                    .filter(|r| r.len() > 1)
+                    .flat_map(|range| {
+                        // set up iterator
+                        let mut iterator = list[range.clone()].iter().map(get_key);
+                        // first element is pivot, apply key extraction function f
+                        let pivot = iterator.next().unwrap();
+                        zip(repeat(pivot), iterator)
+                    })
+                    .enumerate()
+                    .map(|(i, (pivot, k))| {
+                        let cmp_ctx = cmp_ctx.clone();
+                        let rvl_ctx = rvl_ctx.clone();
+                        let record_id = RecordId::from(i);
+                        async move {
+                            // Compare the current element against pivot and reveal the result.
+                            let comparison = compare_gt(cmp_ctx, record_id, k, pivot)
+                                .await?
+                                .reveal(rvl_ctx, record_id) // reveal outcome of comparison
+                                .await?;
+
+                            // desc = true will flip the order of the sort
+                            Ok::<_, Error>(Boolean::from(desc) == comparison)
+                        }
+                    }),
+            ),
+        )
+        .try_collect()
+        .await?;
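+        // comp now holds one revealed bit per comparison, concatenated range by
+        // range; n below tracks the offset of each range's slice of bits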

+        let mut n = 0;
+        for range in &ranges_to_sort {
+            if range.len() <= 1 {
+                continue;
+            }

            // swap elements based on comparisons
            // i is index of first element larger than pivot
-            let mut i = b_l + 1;
-            for (j, b) in comp.iter().enumerate() {
+            let mut i = range.start + 1;
+            for (j, b) in comp[n..(n + range.len() - 1)].iter().enumerate() {
                if *b {
-                    list.swap(i, j + b_l + 1);
+                    list.swap(i, j + range.start + 1);
                    i += 1;
                }
            }
+            n += range.len() - 1;

            // put pivot to index i-1
-            list.swap(i - 1, b_l);
+            list.swap(i - 1, range.start);

-            // push recursively calls to quicksort function on stack
-            if i > b_l + 1 {
-                stack.push((ctx.narrow(&Step::Left), b_l, i - 1));
+            // mark which ranges need to be sorted in the next pass
+            if i > range.start + 1 {
+                ranges_for_next_pass.push(range.start..(i - 1));
            }
-            if i + 1 < b_r {
-                stack.push((ctx.narrow(&Step::Right), i, b_r));
+            if i + 1 < range.end {
+                ranges_for_next_pass.push(i..range.end);
            }
        }
+
+        quicksort_pass += 1;
+        ranges_to_sort = ranges_for_next_pass;
+        ranges_for_next_pass = Vec::with_capacity(ranges_to_sort.len() * 2);
    }

    // no error happened, sorted successfully
@@ -121,15 +155,20 @@ where

#[cfg(all(test, unit_test))]
pub mod tests {
+    use std::cmp::Ordering;
+
+    use ipa_macros::Step;
+    use rand::Rng;

    use crate::{
        error::Error,
-        ff::{boolean_array::BA64, Field},
-        protocol::{context::Context, ipa_prf::quicksort::quicksort_by_key_insecure},
+        ff::{
+            boolean_array::{BA20, BA64},
+            Field,
+        },
+        protocol::{context::Context, ipa_prf::quicksort::quicksort_ranges_by_key_insecure},
        rand::thread_rng,
-        secret_sharing::replicated::semi_honest::AdditiveShare,
+        secret_sharing::{replicated::semi_honest::AdditiveShare, IntoShares},
        test_executor::run,
        test_fixture::{Reconstruct, Runner, TestWorld},
    };
@@ -148,7 +187,9 @@ pub mod tests {
        C: Context,
    {
        let mut list_mut = list.to_vec();
-        quicksort_by_key_insecure(ctx, &mut list_mut[..], desc, |x| x).await?;
+        #[allow(clippy::single_range_in_vec_init)]
+        quicksort_ranges_by_key_insecure(ctx, &mut list_mut[..], desc, |x| x, vec![0..list.len()])
+            .await?;
        let mut result: Vec<AdditiveShare<BA64>> = vec![];
        list_mut.iter().for_each(|x| result.push(x.clone()));
        Ok(result)
@@ -345,4 +386,148 @@ pub mod tests {
            }
        });
    }

+    #[derive(Clone, Copy, Debug)]
+    struct SillyStruct {
+        timestamp: BA20,
+        user_id: usize,
+    }
+
+    #[derive(Clone, Debug)]
+    struct SillyStructShare {
+        timestamp: AdditiveShare<BA20>,
+        user_id: usize,
+    }
+
+    impl IntoShares<SillyStructShare> for SillyStruct {
+        fn share_with<R: Rng>(self, rng: &mut R) -> [SillyStructShare; 3] {
+            let [t0, t1, t2] = self.timestamp.share_with(rng);
+            [
+                SillyStructShare {
+                    timestamp: t0,
+                    user_id: self.user_id,
+                },
+                SillyStructShare {
+                    timestamp: t1,
+                    user_id: self.user_id,
+                },
+                SillyStructShare {
+                    timestamp: t2,
+                    user_id: self.user_id,
+                },
+            ]
+        }
+    }
+
+    impl Reconstruct<SillyStruct> for [SillyStructShare; 3] {
+        fn reconstruct(&self) -> SillyStruct {
+            SillyStruct {
+                user_id: self[0].user_id,
+                timestamp: [
+                    self[0].timestamp.clone(),
+                    self[1].timestamp.clone(),
+                    self[2].timestamp.clone(),
+                ]
+                .reconstruct(),
+            }
+        }
+    }
+
+    impl Reconstruct<Vec<SillyStruct>> for [Vec<SillyStructShare>; 3] {
+        fn reconstruct(&self) -> Vec<SillyStruct> {
+            let mut res = Vec::with_capacity(self[0].len());
+            for i in 0..self[0].len() {
+                let elem = [self[0][i].clone(), self[1][i].clone(), self[2][i].clone()];
+                res.push(elem.reconstruct());
+            }
+            res
+        }
+    }
+
+    const TEST_USER_IDS: [usize; 8] = [1, 2, 3, 5, 8, 13, 21, 34];
+
+    // test for sorting multiple ranges in a longer list
+    #[test]
+    fn test_multiple_ranges() {
+        run(|| async move {
+            let world = TestWorld::default();
+            let mut rng = thread_rng();
+
+            // test cases for both ascending and descending
+            let bools = vec![false, true];
+
+            for desc in bools {
+                // generate a vector of structs corresponding to 8 users.
+                // Each user will have a different number of records
+                // Each struct will have a random timestamp
+                let mut records: Vec<SillyStruct> = Vec::with_capacity(TEST_USER_IDS.iter().sum());
+                for user_id in TEST_USER_IDS {
+                    for _ in 0..user_id {
+                        records.push(SillyStruct {
+                            timestamp: rng.gen::<BA20>(),
+                            user_id,
+                        });
+                    }
+                }
+
+                // convert expected into more readable format
+                let mut expected: Vec<(usize, u128)> = records
+                    .clone()
+                    .into_iter()
+                    .map(|x| (x.user_id, x.timestamp.as_u128()))
+                    .collect();
+                // sort expected
+                expected.sort_unstable_by(|a, b| match a.0.cmp(&b.0) {
+                    Ordering::Less => Ordering::Less,
+                    Ordering::Greater => Ordering::Greater,
+                    Ordering::Equal => {
+                        if desc {
+                            b.1.cmp(&a.1)
+                        } else {
+                            a.1.cmp(&b.1)
+                        }
+                    }
+                });
+
+                let (_, ranges_to_sort) = TEST_USER_IDS.iter().fold(
+                    (0, Vec::with_capacity(TEST_USER_IDS.len())),
+                    |acc, x| {
+                        let (start, mut ranges) = acc;
+                        let end = start + x;
+                        ranges.push(start..end);
+                        (end, ranges)
+                    },
+                );
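+                // e.g. user record counts [1, 2, 3, 5, ...] yield per-user ranges
+                // [0..1, 1..3, 3..6, 6..11, ...] over the concatenated list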

+                // compute mpc sort
+                let result: Vec<_> = world
+                    .semi_honest(records.into_iter(), |ctx, mut records| {
+                        let ranges_clone = ranges_to_sort.clone();
+                        async move {
+                            quicksort_ranges_by_key_insecure(
+                                ctx,
+                                &mut records[..],
+                                desc,
+                                |x| &x.timestamp,
+                                ranges_clone,
+                            )
+                            .await
+                            .unwrap();
+                            records
+                        }
+                    })
+                    .await
+                    .reconstruct();
+
+                assert_eq!(
+                    // convert into more readable format
+                    result
+                        .into_iter()
+                        .map(|x| (x.user_id, x.timestamp.as_u128()))
+                        .collect::<Vec<_>>(),
+                    expected
+                );
+            }
+        });
+    }
}
