sparse aggregation protocol #783
Conversation
}
-/// Binary-share aggregation protocol.
+/// Binary-share aggregation protocol for a sparse breakdown key vector input.
Since this is the interface of this protocol, can you also add a comment about the input and output format, i.e., that the input is an array of [conversion value, breakdown key] pairs, and also describe the output?
    },
)
.await?,
upgrade_bit_shares(
This can be done in parallel with the line above.
Changed to use try_join on the tuples.
streams
    .try_fold(vec![S::ZERO; num_buckets], |mut acc, bucket| async move {
        for (i, b) in bucket.into_iter().enumerate() {
            acc[i] += &b;
        }
        Ok(acc)
    })
    .await?;

Ok(aggregate)
    .await
I might be misunderstanding this, but this looks like it's operating across all of the streams, but I thought each stream was able to progress independently.
I think try_fold applies async computations on each stream independently. cc: @martinthomson
I'm pretty sure that this will work sequentially.
@akoshelev and @martinthomson - can you please advise on how to write this code for maximum streaming performance?
So the use of a fold() here is right. You want to operate on a per-row basis and accumulate across each of the buckets.
When we decompose the breakdown key into its one hot encoding, that operation is not inherently parallel, so each row is going to need a tree of combinations (a*b, b*c, a*c, a*b*c, etc...). We can parallelize those somewhat, and we do some of that already (in pregenerate_all_combinations). There might be some room for optimizing the circuit depth if we go deeper than 3 (abc*d isn't ((a*b)*c)*d as we currently calculate, but (a*b)*(c*d), ...), but that's small dice.
Then we have a bunch of operations that seem like they could be parallelized. The value is multiplied for each of the buckets of the decomposed breakdown key. Those could proceed in parallel proper. This code only sort of does that. It runs eq_ctx.try_join()
, which is a sequential operation for large numbers of inputs. But we have a small number of inputs in this case, so it is really a parallel operation (until we have more breakdown keys than active work). I'd probably recommend a parallel join here: we don't expect the number of breakdown keys to exceed our ability to track multiplications for each.
Either way, what this structure does is make progress on row N contingent on row N-1 completing. That's OK, I think.
The alternative is to let each bucket in the histogram proceed independently, but that is probably worse overall. The only reason you might want to pursue that approach is that it might (at some point) allow us to properly parallelize the multiplication stage (as in, actually use multiple threads).
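To make the fold semantics concrete, here is a minimal clear-text sketch (no secret sharing; `aggregate_clear` and its shapes are hypothetical illustrations, not the repo's API) of how each row's one-hot-times-value vector is threaded through a sequential accumulator:

```rust
// Clear-text model of the stream's try_fold: each row yields a per-bucket
// vector (the one-hot expansion of its breakdown key times its value), and
// the accumulator is threaded through the rows one at a time.
fn aggregate_clear(rows: &[(usize, u64)], num_buckets: usize) -> Vec<u64> {
    rows.iter().fold(vec![0u64; num_buckets], |mut acc, &(bk, v)| {
        // Stand-in for the decomposed, value-multiplied row.
        let bucket: Vec<u64> = (0..num_buckets)
            .map(|i| if i == bk { v } else { 0 })
            .collect();
        for (i, b) in bucket.into_iter().enumerate() {
            acc[i] += b;
        }
        acc
    })
}
```

As in the comment above, row N's accumulation waits on row N-1, matching the point that progress is contingent row to row.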
src/protocol/aggregation/mod.rs
Outdated
let streams = seq_join(
    ctx.active_work(),
    stream_iter(breakdown_keys)
        .zip(stream_iter(contribution_values))
        .enumerate()
        .map(|(i, (bk, v))| {
            let eq_ctx = &equality_check_ctx;
            let mul_ctx = &check_times_value_ctx;
            async move {
                let equality_checks = check_everything(eq_ctx.clone(), i, &bk).await?;
                let value = &v;
                eq_ctx
                    .try_join(
                        equality_checks
                            .into_iter()
                            .take(num_buckets)
                            .enumerate()
                            .map(|(check_idx, check)| {
                                let step = BitOpStep::from(check_idx);
                                let c = mul_ctx.narrow(&step);
                                let record_id = RecordId::from(i);
                                async move {
                                    check
                                        .multiply(
                                            &value.to_additive_sharing_in_large_field(),
                                            c,
                                            record_id,
                                        )
                                        .await
                                }
                            }),
                    )
                    .await
            }
I find this really difficult to read / digest / understand.
Is there anything we can do to improve the code readability here? Perhaps something that leads to less indentation? Maybe factor something out into another function?
Extracted the multiplication part.
The tests pass, so the logic is clearly accurate, but I find it quite hard to follow. I would like to try to find a way to make it easier to read / digest / understand.
const INPUT: &[u32] = &[0, 0, 18, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 10, 0, 0, 6, 0];
const EXPECTED: &[u128] = &[28, 0, 0, 6, 1, 0, 0, 8];

const INPUT: &[(u32, u32)] = &[
I had trouble reading the ordering here. I read this as a k-v map, where it is a v-k map instead.
I also got caught up on the same thing.
Swapped the order and added a comment.
src/protocol/aggregation/mod.rs
Outdated
#[tracing::instrument(name = "aggregate_values_per_bucket", skip_all)]
pub async fn aggregate_values_per_bucket<F, C, S>(
Do you want to use the term "sparse" here for consistency?
src/protocol/aggregation/mod.rs
Outdated
async move {
    let equality_checks = check_everything(eq_ctx.clone(), i, &bk).await?;
    equality_bits_times_value(&c, equality_checks, num_buckets, v, num_records, i)
        .await
My assumption is that we want to send work over to other streams right here (as opposed to in a second loop).
Right at this point (after this await) we should have N secret-shares. I think this is the place to just iterate over all of them and essentially just Send them over to other threads for processing.
I think creating a second loop might work, but it's confusing because it requires you to understand that the loop is being evaluated in a lazy way...
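The fan-out idea can be sketched with plain std threads (everything here is a hypothetical illustration: the channel, the doubling stand-in for the bucket multiplication, and `fan_out` itself; the repo's real work items are futures, not values):

```rust
use std::sync::mpsc;
use std::thread;

// Hypothetical sketch: once a row's N secret-share work items exist, send
// each to its own worker thread over a channel and collect the results.
fn fan_out(shares: Vec<u64>) -> u64 {
    let (tx, rx) = mpsc::channel();
    thread::scope(|s| {
        for share in shares {
            let tx = tx.clone();
            s.spawn(move || {
                // Stand-in for the per-bucket multiplication work.
                tx.send(share * 2).unwrap();
            });
        }
    });
    drop(tx); // close the channel so the receiver iterator terminates
    rx.iter().sum()
}
```

`thread::scope` joins all workers before returning, so the per-row results are complete before the next row is processed.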
.take(num_buckets)
.enumerate()
.map(|(check_idx, check)| {
    let step = BitOpStep::from(check_idx);
We are kind of abusing "BitOpStep" now. This was designed for bitwise operations and cannot go above 64. There might be thousands of breakdowns - so maybe we should think more about this.
I have a PR in review that could solve this. We can create a step like BitOpStep but with a number of our choice.
src/protocol/aggregation/mod.rs
Outdated
let v = &value_bits;
async move {
    check
        .multiply(&v.to_additive_sharing_in_large_field(), c, record_id)
We should not be doing this in a loop. We should do it only once outside of the loop.
Funny. I had the same comment, but I hadn't even thought of the performance implication, just that this function had no business dealing with the conversion.
src/protocol/aggregation/mod.rs
Outdated
input_rows: &[SparseAggregateInputRow<CV, BK>],
) -> Vec<BitDecomposed<Replicated<Gf2>>>
num_bits: u32,
f: H,
It's not clear from the name (f) or the comment (non-existent) what this param does.
Added the comment.
OK, I got thinking...
src/protocol/aggregation/mod.rs
Outdated
let eq_ctx = &equality_check_ctx;
let mul_ctx = &check_times_value_ctx;
async move {
    let equality_checks = check_everything(eq_ctx.clone(), i, &bk).await?;
OK, I'm having real trouble with check_everything as a name. The function takes an array of arithmetic shares of the bits of a number and returns a sparse array with a share of 1 at the index indicated by the value of that number (and 0 elsewhere).
binary_number_to_sparse_index() or binary_to_onehot()?
I like binary_to_onehot() or bitwise_to_onehot().
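For reference, a clear-text version of what that rename would describe (a hypothetical helper; the real function operates on arithmetic shares and never learns the index):

```rust
// Clear-text bitwise_to_onehot: take a number's bits (LSB first) and return
// a vector with a 1 at that number's index and 0 everywhere else.
fn bitwise_to_onehot(bits: &[u8]) -> Vec<u8> {
    // Reassemble the index from the bits.
    let index = bits
        .iter()
        .enumerate()
        .fold(0usize, |acc, (i, &b)| acc | ((b as usize) << i));
    let mut onehot = vec![0u8; 1 << bits.len()];
    onehot[index] = 1;
    onehot
}
```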
// Generate N streams for each bucket specified by the breakdown key (N = |breakdown_keys|).
// A stream is pipeline of contribution values multiplied by the "check bit". A check bit is
// a bit that is a share of 1 if the breakdown key matches the bucket, and 0 otherwise.
let streams = seq_join(
This is not a set of streams, but a single stream.
src/protocol/aggregation/mod.rs
Outdated
        .collect::<Vec<_>>()
}
.map(|row| BitDecomposed::decompose(num_bits, |i| f(row, i)))
.collect::<Vec<_>>();
Please don't collect if you only iterate (or stream) the results.
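The suggestion amounts to returning the iterator itself instead of a Vec, e.g. (hypothetical signature; the repo's decomposition works on shares, here plain bits):

```rust
// Return an iterator of bit-decomposed rows instead of collecting a Vec;
// downstream code can stream the rows without materializing them all.
fn decompose_all(rows: &[u32], num_bits: u32) -> impl Iterator<Item = Vec<u8>> + '_ {
    rows.iter()
        .map(move |&row| (0..num_bits).map(|i| ((row >> i) & 1) as u8).collect())
}
```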
src/protocol/aggregation/mod.rs
Outdated
    stream_iter(gf2_bits),
    0..num_bits,
)
.try_collect::<Vec<_>>()
IMPORTANT: This collect is worse because it covers all inputs. You should be able to pass streams into the aggregation function. This is probably the single biggest thing you can do to improve throughput.
src/protocol/aggregation/mod.rs
Outdated
async move {
    check
        .multiply(
            &value.to_additive_sharing_in_large_field(),
The depth of this indentation bothers me some. Factoring might help, but fundamentally you are just doing two operations:
- decomposition of the breakdown key into its one hot encoding
- multiplication of the elements of that vector by the value
Maybe the way to construct this is then approximately:
let decomposed = seq_join(
    eq_ctx.active_work(),
    breakdown_keys.enumerate().map(|(row, bk)| check_everything(eq_ctx.clone(), row, &bk)),
);
let multiplied = seq_join(
    eq_ctx.active_work(),
    decomposed.zip(values).enumerate().map(|(row, (buckets, v))| {
        mul_ctx.parallel_join(buckets.enumerate().map(|(column, b)| {
            b.multiply(&v, mul_ctx.narrow(&BitOpStep::from(column)), RecordId::from(row))
        }))
    }),
);
let histogram = multiplied.try_fold(...); // as below
I'm not sure about the use of BitOpStep here. If we want a large histogram, we'll need something bigger.
src/protocol/aggregation/mod.rs
Outdated
async move {
    check
        .multiply(
            &value.to_additive_sharing_in_large_field(),
This conversion to additive shares should be done outside this function.
src/protocol/aggregation/mod.rs
Outdated
debug_assert!(contribution_values.len() == breakdown_keys.len());
let num_records = contribution_values.len();
// for now, we assume that the bucket count is 2^BK::BITS
let num_buckets = 1 << breakdown_keys[0].len();
This hides an important future optimization.
Our code for decomposing a value into a one hot encoding assumes that the value spans the full range of its bit width.
Take the case where you have 0..=4 (that is, 4 is a possible value). In this case, you will receive a 3 bit input as there are 5 possible values. Now, with those constraints in mind, you can determine if the value is 4 with a single check: you check if the high bit is set.
Our current implementation uses the following table for a three bit input:
// 0: + | - | - | + | - | + | + | -
// 1: . | + | . | - | . | - | . | +
// 2: . | . | + | - | . | . | - | +
// 3: . | . | . | + | . | . | . | -
// 4: . | . | . | . | + | - | - | +
// 5: . | . | . | . | . | + | . | -
// 6: . | . | . | . | . | . | + | -
// 7: . | . | . | . | . | . | . | +
Which does not lead to any savings. The comparison with 0 always involves every combination that we compute.
First idea
With a smaller range you could maybe invert the checks. Comparing 7-index puts you in the thin end of the table, which means that you only need to compute as many combinations as values that you need to compare.
That's not a win though, because the computations on the right are the expensive ones. You still need to compute the combination x1*x2 if you want to find x1*x2*x3. Also, the first three values are free to compute anyway, but it suggests that there are savings to be had as the table grows in size.
Note that while it seems like computing 7-index in bitwise logic is not cheap, you aren't really doing that here. You are just changing your expansions from xi*xj and similar to (1-xi)*(1-xj).
Second idea
The savings more likely come from not needing to distinguish between 5 and 1 if 5 is out of range. A value can be considered to be 1 if the 1 bit is set and the 2 bit is clear. There is no need to look at the 4 bit. So you get is_one = x1*(1-x2).
The first cut of that produces a table like this.
// 0: + | - | - | + | - | + | + | -
// 1: . | + | . | -
// 2: . | . | + | -
// 3: . | . | . | +
// 4: . | . | . | . | +
That's awkward, because it doesn't save any effort. That 0 row is computing (1-x1)*(1-x2)*(1-x3), which has all of the combinations.
But we can still borrow the inversion idea from above. Then, rather than computing all the combinations of x3 with the previous combinations, we can compute the single combination (1-x1)*(1-x2)*(1-x3) directly. Rather than three additional multiplications (one of the values on the right half of the table is just x3, remember), there are now just two additional multiplications.
// x x
// 0^: . | . | . | . | . | . | . | +
// 1v: . | + | . | -
// 2v: . | . | + | -
// 3v: . | . | . | +
// 4v: . | . | . | . | +
OK, so we saved a whole multiplication, so it is clear that there is potential here. But this savings will be eliminated when we have a larger range. Take 0..=5. The value we use for the 1 is an intermediate value, so that doesn't cost more. But the additional multiplication needed to distinguish 4 and 5 eliminates the saved multiplication.
// 0^: . | . | . | . | . | . | . | + = x1'*x2'*x3'
// 1^: . | . | . | . | . | . | + | - = (1-x1')*x2'*x3' = x2'*x3'-x1'*x2'*x3' {no extra}
// 2v: . | . | + | -
// 3v: . | . | . | +
// 4v: . | . | . | . | + | - = x3 * (1-x1) {one extra multiply}
// 5v: . | . | . | . | . | + = x3 * x1
Once we get to 0..=6, we might as well go back to the original model.
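The 0..=4 case above can be sanity-checked in the clear (both helpers are hypothetical illustrations, not repo code):

```rust
// With the input constrained to 0..=4, the values 5, 6, 7 never occur, so
// "value == 4" collapses to a single test of the high bit.
fn to_bits(v: u32) -> [u8; 3] {
    [(v & 1) as u8, ((v >> 1) & 1) as u8, ((v >> 2) & 1) as u8] // LSB first
}

fn is_four_via_high_bit(bits: [u8; 3]) -> bool {
    bits[2] == 1 // bit 2 carries the value 4
}
```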
My first impression is that the savings here may not justify the increased complexity. This seems like a very small potential saving compared to other things we can do.
Thanks for reviewing!
I found a post on the web that solved a similar error by wrapping the function and simply returning it. You could try Line 17 in e3d9748.
Force-pushed from 12da6d5 to 905f997.
clippy is updated and complaining. To be rebased on top of #782.
This diff implements the sparse aggregation protocol. The input is a vector of (contribution_value, breakdown_key) pairs where breakdown_key indicates to which bucket the contribution_value should be added. Since the breakdown keys are secret shared, we use check_everything (Sierpiński triangle thing) for equality checks. A contribution value is multiplied by a share of 1 or 0 so that the value is only contributed when the breakdown_key matches the bucket index.
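In the clear, the arithmetic the protocol emulates looks like this (a hypothetical model; all names are illustrative, and the real protocol never sees values or keys in plaintext):

```rust
// Clear-text model: multiply each contribution value by a 1/0 check bit per
// bucket (1 iff the breakdown key equals the bucket index), then sum the
// products per bucket into the histogram.
fn sparse_aggregate(rows: &[(u64, usize)], num_buckets: usize) -> Vec<u64> {
    let mut histogram = vec![0u64; num_buckets];
    for &(value, bk) in rows {
        for (bucket, slot) in histogram.iter_mut().enumerate() {
            let check = u64::from(bucket == bk); // share of 1 or 0 in the MPC
            *slot += check * value;
        }
    }
    histogram
}
```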