Skip to content

Commit

Permalink
Fix clippy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
taikiy committed Sep 19, 2023
1 parent 905f997 commit 834616c
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 15 deletions.
17 changes: 9 additions & 8 deletions src/protocol/aggregation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ pub use input::SparseAggregateInputRow;
use ipa_macros::step;
use strum::AsRefStr;

use super::{context::Context, sort::check_everything, step::BitOpStep, RecordId};
use super::{context::Context, sort::bitwise_to_onehot, step::BitOpStep, RecordId};
use crate::{
error::Error,
ff::{Field, GaloisField, Gf2, PrimeField, Serializable},
Expand Down Expand Up @@ -71,9 +71,13 @@ where

// convert the input from `[Z2]^u` into `[Zp]^u`
let (converted_value_bits, converted_breakdown_key_bits) = (
upgrade_bit_shares(ctx.narrow(&Step::ConvertValueBits), contributions, CV::BITS),
upgrade_bit_shares(
ctx.narrow(&Step::ConvertBreakdownKeyBits),
&ctx.narrow(&Step::ConvertValueBits),
contributions,
CV::BITS,
),
upgrade_bit_shares(
&ctx.narrow(&Step::ConvertBreakdownKeyBits),
breakdowns,
BK::BITS,
),
Expand Down Expand Up @@ -109,9 +113,6 @@ where
C: UpgradedContext<F, Share = S>,
S: LinearSecretSharing<F> + BasicProtocols<C, F> + Serializable + 'static,
{
// TODO: use exactsizestream trait
// debug_assert!(contribution_values.len() == breakdown_keys.len());

let equality_check_ctx = ctx.narrow(&Step::ComputeEqualityChecks);

// Generate N streams for each bucket specified by the `num_buckets`.
Expand All @@ -126,7 +127,7 @@ where
let eq_ctx = &equality_check_ctx;
let c = ctx.clone();
async move {
let equality_checks = check_everything(eq_ctx.clone(), i, &bk?).await?;
let equality_checks = bitwise_to_onehot(eq_ctx.clone(), i, &bk?).await?;
equality_bits_times_value(&c, equality_checks, num_buckets, v?, i).await
}
}),
Expand Down Expand Up @@ -173,7 +174,7 @@ where
}

fn upgrade_bit_shares<'a, F, C, S, I, G>(
ctx: C,
ctx: &C,
input_rows: I,
num_bits: u32,
) -> impl Stream<Item = Result<BitDecomposed<S>, Error>> + 'a
Expand Down
4 changes: 2 additions & 2 deletions src/protocol/attribution/aggregate_credit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::{
protocol::{
context::{UpgradableContext, UpgradedContext, Validator},
modulus_conversion::convert_bits,
sort::{check_everything, generate_permutation::ShuffledPermutationWrapper},
sort::{bitwise_to_onehot, generate_permutation::ShuffledPermutationWrapper},
step::BitOpStep,
BasicProtocols, RecordId,
},
Expand Down Expand Up @@ -114,7 +114,7 @@ where
let ceq = &equality_check_context;
let cmul = &check_times_credit_context;
async move {
let equality_checks = check_everything(ceq.clone(), i, &bk?).await?;
let equality_checks = bitwise_to_onehot(ceq.clone(), i, &bk?).await?;
ceq.try_join(equality_checks.into_iter().take(to_take).enumerate().map(
|(check_idx, check)| {
let step = BitOpStep::from(check_idx);
Expand Down
2 changes: 1 addition & 1 deletion src/protocol/sort/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ pub(crate) enum ReshareStep {
///
/// # Errors
/// If any multiplication fails, or if the record is too long (e.g. more than 64 multiplications required)
pub async fn check_everything<F, C, S>(
pub async fn bitwise_to_onehot<F, C, S>(
ctx: C,
record_idx: usize,
record: &[S],
Expand Down
8 changes: 4 additions & 4 deletions src/protocol/sort/multi_bit_permutation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::{
error::Error,
ff::PrimeField,
protocol::{
basics::SumOfProducts, context::UpgradedContext, sort::check_everything, BasicProtocols,
basics::SumOfProducts, context::UpgradedContext, sort::bitwise_to_onehot, BasicProtocols,
RecordId,
},
secret_sharing::{
Expand Down Expand Up @@ -66,7 +66,7 @@ where
.iter()
.zip(repeat(ctx.set_total_records(num_records)))
.enumerate()
.map(|(idx, (record, ctx))| check_everything(ctx, idx, record)),
.map(|(idx, (record, ctx))| bitwise_to_onehot(ctx, idx, record)),
)
.await?;

Expand Down Expand Up @@ -117,7 +117,7 @@ mod tests {
ff::{Field, Fp31},
protocol::{
context::{Context, UpgradableContext, Validator},
sort::check_everything,
sort::bitwise_to_onehot,
},
secret_sharing::{BitDecomposed, SharedValue},
seq_join::SeqJoin,
Expand Down Expand Up @@ -170,7 +170,7 @@ mod tests {
let ctx = ctx.set_total_records(num_records);
let mut equality_check_futures = Vec::with_capacity(num_records);
for (i, record) in m_shares.iter().enumerate() {
equality_check_futures.push(check_everything(ctx.clone(), i, record));
equality_check_futures.push(bitwise_to_onehot(ctx.clone(), i, record));
}
ctx.try_join(equality_check_futures).await.unwrap()
})
Expand Down

0 comments on commit 834616c

Please sign in to comment.