
sparse aggregation protocol #783

Merged · 7 commits · Sep 25, 2023

Conversation

@taikiy (Contributor) commented Aug 25, 2023:

Clippy has been updated and is now complaining; this will be rebased on top of #782.

This diff implements the sparse aggregation protocol. The input is a vector of (contribution_value, breakdown_key) pairs, where the breakdown_key indicates the bucket to which the contribution_value should be added. Since the breakdown keys are secret-shared, we use check_everything (the Sierpiński-triangle construction) for the equality checks. Each contribution value is multiplied by a share of 1 or 0, so that the value is contributed only when its breakdown_key matches the bucket index.
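For reference, a minimal plaintext sketch of what the protocol computes (the real protocol operates on secret shares; the names here are illustrative only):

/// Plaintext analogue of sparse aggregation: each (contribution_value,
/// breakdown_key) pair adds its value to the bucket named by the key.
fn aggregate_plaintext(pairs: &[(u32, usize)], num_buckets: usize) -> Vec<u64> {
    let mut histogram = vec![0u64; num_buckets];
    for &(value, bk) in pairs {
        // In the MPC version, `value` is instead multiplied by a shared 0/1
        // equality check against every bucket index, and the products are
        // accumulated into the histogram.
        histogram[bk] += u64::from(value);
    }
    histogram
}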

}

/// Binary-share aggregation protocol.
/// Binary-share aggregation protocol for a sparse breakdown key vector input.
Contributor:
Since this is the interface of this protocol, can you also add a comment about the input and output format, i.e., that the input is an array of [conversion value, breakdown key] pairs, and describe the output as well?

},
)
.await?,
upgrade_bit_shares(
Collaborator:
This can be done in parallel with the line above.

Contributor Author:
Changed to use try_join on the tuples.
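The shape of that change, as a hedged fragment (the future arguments are elided and the variable names invented; the actual diff may differ):

// Hypothetical: upgrade breakdown keys and contribution values concurrently
// by joining the two futures as a tuple instead of awaiting them in sequence.
use futures::try_join;

let (upgraded_bks, upgraded_cvs) = try_join!(
    upgrade_bit_shares(/* breakdown keys */),
    upgrade_bit_shares(/* contribution values */),
)?;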

Comment on lines +166 to +173
streams
    .try_fold(vec![S::ZERO; num_buckets], |mut acc, bucket| async move {
        for (i, b) in bucket.into_iter().enumerate() {
            acc[i] += &b;
        }
        Ok(acc)
    })
    .await?;

Ok(aggregate)
.await
Collaborator:
I might be misunderstanding this, but it looks like this operates across all of the streams; I thought each stream was able to progress independently.

Contributor Author:
I think try_fold applies async computations on each stream independently. cc: @martinthomson

Member:
I'm pretty sure that this will work sequentially.

Collaborator:
@akoshelev and @martinthomson - can you please advise on how to write this code for maximum streaming performance?

@martinthomson (Member) commented Sep 7, 2023:
So the use of a fold() here is right. You want to operate on a per-row basis and accumulate across each of the buckets.

When we decompose the breakdown key into its one hot encoding, that operation is not inherently parallel, so each row is going to need a tree of combinations (a*b, b*c, a*c, a*b*c, etc...). We can parallelize those somewhat, and we do some of that already (in pregenerate_all_combinations). There might be some room for optimizing the circuit depth if we go deeper than 3 (abc*d isn't ((a*b)*c)*d as we currently calculate, but (a*b)*(c*d), ...), but that's small dice.

Then we have a bunch of operations that seem like they could be parallelized. The value is multiplied for each of the buckets of the decomposed breakdown key. Those could proceed in parallel proper. This code only sort of does that. It runs eq_ctx.try_join(), which is a sequential operation for large numbers of inputs. But we have a small number of inputs in this case, so it is really a parallel operation (until we have more breakdown keys than active work). I'd probably recommend a parallel join here: we don't expect the number of breakdown keys to exceed our ability to track multiplications for each.

Either way, what this structure does is make progress on row N contingent on row N-1 completing. That's OK, I think.

The alternative is to let each bucket in the histogram proceed independently, but that is probably worse overall. The only reason you might want to pursue that approach is that it might (at some point) allow us to properly parallelize the multiplication stage (as in, actually use multiple threads).
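To make the depth point concrete, a plaintext sketch (illustrative only; in the protocol each pairwise product would be an MPC multiplication):

// ((a*b)*c)*d takes three sequential multiplications; a balanced tree computes
// the same product in ceil(log2(n)) rounds of pairwise multiplications.
fn product_tree(mut xs: Vec<u64>) -> u64 {
    assert!(!xs.is_empty());
    while xs.len() > 1 {
        xs = xs.chunks(2).map(|pair| pair.iter().product()).collect();
    }
    xs[0]
}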

Comment on lines 129 to 162
let streams = seq_join(
    ctx.active_work(),
    stream_iter(breakdown_keys)
        .zip(stream_iter(contribution_values))
        .enumerate()
        .map(|(i, (bk, v))| {
            let eq_ctx = &equality_check_ctx;
            let mul_ctx = &check_times_value_ctx;
            async move {
                let equality_checks = check_everything(eq_ctx.clone(), i, &bk).await?;
                let value = &v;
                eq_ctx
                    .try_join(
                        equality_checks
                            .into_iter()
                            .take(num_buckets)
                            .enumerate()
                            .map(|(check_idx, check)| {
                                let step = BitOpStep::from(check_idx);
                                let c = mul_ctx.narrow(&step);
                                let record_id = RecordId::from(i);
                                async move {
                                    check
                                        .multiply(
                                            &value.to_additive_sharing_in_large_field(),
                                            c,
                                            record_id,
                                        )
                                        .await
                                }
                            }),
                    )
                    .await
            }
Collaborator:
I find this really difficult to read / digest / understand.

Is there anything we can do to improve the code readability here? Perhaps something that leads to less indentation? Maybe factor something out into another function?

Contributor Author:
Extracted the multiplication part.

@benjaminsavage (Collaborator) left a comment:

The tests pass, so the logic is clearly accurate, but I find it quite hard to follow. I would like to try to find a way to make it easier to read / digest / understand.

const INPUT: &[u32] = &[0, 0, 18, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 10, 0, 0, 6, 0];
const EXPECTED: &[u128] = &[28, 0, 0, 6, 1, 0, 0, 8];

const INPUT: &[(u32, u32)] = &[
Member:
I had trouble reading the ordering here. I read this as a k-v map, where it is a v-k map instead.

Collaborator:
I also got caught up on the same thing.

Contributor Author:
Swapped the order and added a comment.
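Something like the following (hypothetical values and comment, assuming the swap put the breakdown key first, as both readers expected; the actual constants are elided from this excerpt):

// Each entry is a (breakdown_key, contribution_value) pair.
const INPUT: &[(u32, u32)] = &[
    (0, 0),
    (3, 18),
    (7, 2),
];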

Comment on lines 108 to 109
#[tracing::instrument(name = "aggregate_values_per_bucket", skip_all)]
pub async fn aggregate_values_per_bucket<F, C, S>(
Collaborator:
Do you want to use the term "sparse" here for consistency?

async move {
    let equality_checks = check_everything(eq_ctx.clone(), i, &bk).await?;
    equality_bits_times_value(&c, equality_checks, num_buckets, v, num_records, i)
        .await
Collaborator:
My assumption is that we want to send work over to other streams right here (as opposed to in a second loop).

Right at this point (after this await) we should have N secret-shares. I think this is the place to just iterate over all of them and essentially just Send them over to other threads for processing.

I think creating a second loop might work, but it's confusing because it requires you to understand that the loop is being evaluated in a lazy way...

    .take(num_buckets)
    .enumerate()
    .map(|(check_idx, check)| {
        let step = BitOpStep::from(check_idx);
Collaborator:
We are kind of abusing "BitOpStep" now. This was designed for bitwise operations and cannot go above 64. There might be thousands of breakdowns - so maybe we should think more about this.

Contributor Author:
I have a PR in review that could solve this. We can create a step like BitOpStep but with a number of our choice.
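A hypothetical sketch of such a step (the name and representation are invented; the referenced PR may do this differently):

// Hypothetical: a step indexed by an arbitrary bucket number, without the
// 64-variant ceiling that BitOpStep was designed with.
#[derive(Clone, Copy, Debug)]
struct BucketStep(usize);

impl std::fmt::Display for BucketStep {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "bucket{}", self.0)
    }
}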

let v = &value_bits;
async move {
    check
        .multiply(&v.to_additive_sharing_in_large_field(), c, record_id)
Collaborator:
We should not be doing this in a loop. We should do it only once outside of the loop.

Member:
Funny. I had the same comment, but I hadn't even thought of the performance implication, just that this function had no business dealing with the conversion.
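A hedged fragment of the hoisting being asked for (names are taken from the surrounding diff; the plumbing around it is assumed):

// Convert to additive shares once per row, outside the per-bucket loop.
let value = value_bits.to_additive_sharing_in_large_field();
let value = &value;
ctx.try_join(equality_checks.into_iter().enumerate().map(|(idx, check)| {
    let c = mul_ctx.narrow(&BitOpStep::from(idx));
    async move { check.multiply(value, c, record_id).await }
}))
.await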

input_rows: &[SparseAggregateInputRow<CV, BK>],
) -> Vec<BitDecomposed<Replicated<Gf2>>>
num_bits: u32,
f: H,
Collaborator:
It's not clear from the name (f) or the comment (non-existent) what this parameter does.

Contributor Author:
Added the comment.
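Judging from the call site elsewhere in the diff (`BitDecomposed::decompose(num_bits, |i| f(row, i))`), the added comment presumably says something along these lines (paraphrased, not the exact wording):

/// `f` is called as `f(row, i)` for each bit position `i` in `0..num_bits`,
/// and must return the `i`-th bit of `row` as a Gf2 share.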

@martinthomson (Member) left a comment:

OK, I got thinking...

let eq_ctx = &equality_check_ctx;
let mul_ctx = &check_times_value_ctx;
async move {
let equality_checks = check_everything(eq_ctx.clone(), i, &bk).await?;
Member:
OK, I'm having real trouble with check_everything as a name. The function takes an array of arithmetic shares of the bits of a number and returns a sparse array with a share of 1 at the index indicated by the value of that number (and 0 elsewhere).

binary_number_to_sparse_index() or binary_to_onehot() ?

Collaborator:
I like binary_to_onehot() or bitwise_to_onehot()
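For readers of this thread, a plaintext analogue of what the function computes (the real version operates on arithmetic shares of the bits):

// Given the bits of x (least significant first), return a vector with a 1 at
// index x and 0 everywhere else: the one-hot encoding the names above suggest.
fn bitwise_to_onehot(bits: &[u8]) -> Vec<u8> {
    let x: usize = bits
        .iter()
        .enumerate()
        .map(|(i, &b)| (b as usize) << i)
        .sum();
    let mut onehot = vec![0u8; 1 << bits.len()];
    onehot[x] = 1;
    onehot
}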

// Generate N streams for each bucket specified by the breakdown key (N = |breakdown_keys|).
// A stream is a pipeline of contribution values multiplied by the "check bit". A check bit is
// a bit that is a share of 1 if the breakdown key matches the bucket, and 0 otherwise.
let streams = seq_join(
Member:
This is not a set of streams, but a single stream.


.collect::<Vec<_>>()
}
.map(|row| BitDecomposed::decompose(num_bits, |i| f(row, i)))
.collect::<Vec<_>>();
Member:
Please don't collect if you only iterate (or stream) the results.

stream_iter(gf2_bits),
0..num_bits,
)
.try_collect::<Vec<_>>()
Member:
IMPORTANT: This collect is worse because it covers all inputs. You should be able to pass streams into the aggregation function. This is probably the single biggest thing you can do to improve throughput.
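To illustrate the difference on plain futures (a generic sketch, not the ipa API):

use futures::stream::{Stream, TryStreamExt};

// Folding the stream directly makes progress as each row completes;
// try_collect::<Vec<_>>() would buffer every row before any downstream work.
async fn sum_rows<S>(rows: S) -> Result<u64, std::io::Error>
where
    S: Stream<Item = Result<u64, std::io::Error>>,
{
    rows.try_fold(0u64, |acc, v| async move { Ok(acc + v) }).await
}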

async move {
check
.multiply(
&value.to_additive_sharing_in_large_field(),
Member:
The depth of this indentation bothers me some. Factoring might help, but fundamentally you are just doing two operations:

  1. decomposition of the breakdown key into its one hot encoding
  2. multiplication of the elements of that vector by the value

Maybe the way to construct this is then approximately:

let decomposed = seq_join(
    eq_ctx.active_work(),
    breakdown_keys
        .enumerate()
        .map(|(row, bk)| check_everything(eq_ctx.clone(), row, &bk)),
);
let multiplied = seq_join(
    eq_ctx.active_work(),
    decomposed.zip(values).enumerate().map(|(row, (buckets, v))| {
        mul_ctx.parallel_join(buckets.enumerate().map(|(column, b)| {
            b.multiply(&v, mul_ctx.narrow(&BitOpStep::from(column)), RecordId::from(row))
        }))
    }),
);
let histogram = multiplied.try_fold(...); // as below

I'm not sure about the use of BitOpStep here. If we want a large histogram, we'll need something bigger.

async move {
check
.multiply(
&value.to_additive_sharing_in_large_field(),
Member:
This conversion to additive shares should be done outside this function.

debug_assert!(contribution_values.len() == breakdown_keys.len());
let num_records = contribution_values.len();
// for now, we assume that the bucket count is 2^BK::BITS
let num_buckets = 1 << breakdown_keys[0].len();
@martinthomson (Member) commented Sep 7, 2023:
This hides an important future optimization.

Our code for decomposing a value into a one hot encoding assumes that the value is in the range $[0..2^n)$. That's convenient, but it leaves out optimizations that apply to ranges that are not clean powers of two.

Take the case where you have 0..=4 (that is, 4 is a possible value). In this case, you will receive a 3 bit input as there are 5 possible values. Now, with those constraints in mind, you can determine if the value is 4 with a single check: you check if the high bit is set.

Our current implementation uses the following table for a three bit input:

    // 0: + | - | - | + | - | + | + | -
    // 1: . | + | . | - | . | - | . | +
    // 2: . | . | + | - | . | . | - | +
    // 3: . | . | . | + | . | . | . | -
    // 4: . | . | . | . | + | - | - | +
    // 5: . | . | . | . | . | + | . | -
    // 6: . | . | . | . | . | . | + | -
    // 7: . | . | . | . | . | . | . | +

Which does not lead to any savings. The comparison with 0 always involves every combination that we compute.

First idea

With a smaller range you could maybe invert the checks. Comparing 7-index puts you in the thin end of the table, which means that you only need to compute as many combinations as values that you need to compare.

That's not a win though, because the computations on the right are the expensive ones. You still need to compute combination of x1*x2 if you want to find x1*x2*x3. Also, the first three values are free to compute anyway, but it suggests that there are savings to be had as the table grows in size.

Note that while it seems like computing 7-index in bitwise logic is not cheap, you aren't really doing that here. You are just changing your expansions from xi*xj and similar to (1-xi)*(1-xj).

Second idea

The savings more likely come from not needing to distinguish between 5 and 1 if 5 is out of range. A value can be considered to be 1 if the 1 bit is set and the 2 bit is clear. There is no need to look at the 4 bit. So you get is_one = x1*(1-x2).

The first cut of that produces a table like this.

    // 0: + | - | - | + | - | + | + | -
    // 1: . | + | . | -
    // 2: . | . | + | -
    // 3: . | . | . | +
    // 4: . | . | . | . | +

That's awkward, because it doesn't save any effort. That 0 row is computing (1-x1)*(1-x2)*(1-x3), which has all of the combinations.

But we can still borrow the inversion idea from above. Then, rather than computing all the combinations of x3 with the previous combinations, we can compute the single combination (1-x1)*(1-x2)*(1-x3) directly. Rather than three additional multiplications (one of the values on the right half of the table is just x3, remember), there are now just two additional multiplications.

    //                         x   x
    // 0^: . | . | . | . | . | . | . | +
    // 1v: . | + | . | -
    // 2v: . | . | + | -
    // 3v: . | . | . | +
    // 4v: . | . | . | . | +

OK, so we saved a whole multiplication, so it is clear that there is potential here. But this savings will be eliminated when we have a larger range. Take 0..=5. The value we use for the 1 is an intermediate value, so that doesn't cost more. But the additional multiplication needed to distinguish 4 and 5 eliminates the saved multiplication.

    // 0^: . | . | . | . | . | . | . | +   = x1'*x2'*x3'
    // 1^: . | . | . | . | . | . | + | -   = (1-x1')*x2'*x3' = x2'*x3'-x1'*x2'*x3' {no extra}
    // 2v: . | . | + | -
    // 3v: . | . | . | +
    // 4v: . | . | . | . | + | -   = x3 * (1-x1) {one extra multiply}
    // 5v: . | . | . | . | . | +   = x3 * x1

Once we get to 0..=6, we might as well go back to the original model.
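A plaintext sketch of the "second idea" for the range 0..=4 (bit names follow the tables above; in the protocol each product would become an MPC multiplication):

// For values limited to 0..=4 (bits x1, x2, x4, with x4 the high bit): any
// input with x4 set must be exactly 4, so the low rows can ignore x4 entirely
// and is_four needs no multiplication at all.
fn onehot_0_to_4(x1: u8, x2: u8, x4: u8) -> [u8; 5] {
    let is_zero = (1 - x1) * (1 - x2) * (1 - x4);
    let is_one = x1 * (1 - x2); // 5 (binary 101) is out of range, so skip x4
    let is_two = (1 - x1) * x2; // 6 (binary 110) is out of range
    let is_three = x1 * x2;     // 7 (binary 111) is out of range
    let is_four = x4;           // the high bit alone identifies 4
    [is_zero, is_one, is_two, is_three, is_four]
}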

Collaborator:
My first impression is that the savings here may not justify the increased complexity. This seems like a very small potential saving compared to other things we can do.

@taikiy (Contributor Author) commented Sep 7, 2023:

Thanks for reviewing! I updated the PR according to the comments, but now I'm getting this error and don't know how to resolve it...

higher-ranked lifetime error
could not prove `[async block@src/protocol/aggregation/mod.rs:271:63: 279:14]: std::marker::Send`


I found a post on the web that solved a similar error by wrapping the function and simply returning it with a Send-bound signature, but that didn't seem to work. @martinthomson, @akoshelev, any idea how I can solve this? Thanks in advance.

@akoshelev (Collaborator) commented:

> Thanks for reviewing! I updated the PR according to the comments, but now I'm getting this error and don't know how to resolve it...
>
> higher-ranked lifetime error
> could not prove `[async block@src/protocol/aggregation/mod.rs:271:63: 279:14]: std::marker::Send`
>
> I found a post on the web that solved a similar error by wrapping the function and simply returning it with a Send-bound signature, but that didn't seem to work. @martinthomson, @akoshelev, any idea how I can solve this? Thanks in advance.

You could try this helper:

/// This helper function might be necessary to convince the compiler that ...

It seems to be the same issue.
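For readers hitting the same compiler limitation, the usual shape of such a helper is an identity function that pins down the `Send` bound (a generic sketch; the helper referenced above may differ):

// Routing an async block through this no-op function gives the compiler a
// concrete place to prove `Send`, sidestepping the higher-ranked lifetime error.
fn assert_send<'a, O>(
    fut: impl std::future::Future<Output = O> + Send + 'a,
) -> impl std::future::Future<Output = O> + Send + 'a {
    fut
}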

@benjaminsavage merged commit df2ae41 into private-attribution:main on Sep 25, 2023.
@taikiy deleted the that_thing branch on October 24, 2023.