Commit
chore: Add a more aggressive caching/precomputation strategy for first five elements in CRS (#98)
kevaundray authored Sep 23, 2024
1 parent ec29458 commit 52c6463
Show file tree
Hide file tree
Showing 5 changed files with 333 additions and 3 deletions.
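At a glance: the commit adds a signed-window (Booth-encoded) MSM precomputation, MSMPrecompWindowSigned, and wires it into DefaultCommitter for the first five CRS points. A minimal usage sketch, assuming only the API visible in this diff (the bases and scalars below are illustrative; window size 16 mirrors the committer change):

use banderwagon::{msm_windowed_sign::MSMPrecompWindowSigned, Element, Fr};

fn main() {
    // Precompute signed-window tables for five illustrative base points.
    let bases: Vec<Element> = (1..=5u64)
        .map(|i| Element::prime_subgroup_generator() * Fr::from(i))
        .collect();
    let precomp = MSMPrecompWindowSigned::new(&bases, 16);

    // result = 1*B0 + 2*B1 + 3*B2; passing fewer scalars than bases is fine.
    let scalars = vec![Fr::from(1u64), Fr::from(2u64), Fr::from(3u64)];
    let _commitment = precomp.mul(&scalars);
}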
8 changes: 8 additions & 0 deletions banderwagon/Cargo.toml
@@ -11,9 +11,17 @@ ark-ff = { version = "^0.4.2", default-features = false }
ark-ec = { version = "^0.4.2", default-features = false }
ark-serialize = { version = "^0.4.2", default-features = false }
rayon = "*"

[dev-dependencies]
hex = "0.4.3"
criterion = "0.5.1"
rand = "0.8.4"
sha3 = "0.10.8"

[features]
default = ["parallel"]
parallel = ["ark-ff/parallel", "ark-ff/asm", "ark-ec/parallel"]

[[bench]]
name = "benchmark"
harness = false
62 changes: 62 additions & 0 deletions banderwagon/benches/benchmark.rs
@@ -0,0 +1,62 @@
use banderwagon::{msm::MSMPrecompWnaf, msm_windowed_sign::MSMPrecompWindowSigned, Element, Fr};
use criterion::{criterion_group, criterion_main, Criterion};
use rand::RngCore;

pub fn msm_wnaf(c: &mut Criterion) {
const NUM_ELEMENTS: usize = 5;

let bases = random_point(120, NUM_ELEMENTS);
let scalars = random_scalars(NUM_ELEMENTS, 16);

let precomp = MSMPrecompWnaf::new(&bases, 12);

c.bench_function(&format!("msm wnaf: {}", NUM_ELEMENTS), |b| {
b.iter(|| precomp.mul(&scalars))
});

let precomp = MSMPrecompWindowSigned::new(&bases, 16);
c.bench_function(&format!("msm precomp 16: {}", NUM_ELEMENTS), |b| {
b.iter(|| precomp.mul(&scalars))
});
}

pub fn keccak_32bytes(c: &mut Criterion) {
use rand::Rng;
use sha3::{Digest, Keccak256};

c.bench_function("keccak 64 bytes", |b| {
b.iter_with_setup(
// Setup function: generates new random data for each iteration
|| {
let keccak = Keccak256::default();
let mut rand_buffer = [0u8; 64];
rand::thread_rng().fill(&mut rand_buffer);
(keccak, rand_buffer)
},
|(mut keccak, rand_buffer)| {
keccak.update(&rand_buffer);
keccak.finalize()
},
)
});
}

fn random_point(seed: u64, num_points: usize) -> Vec<Element> {
(0..num_points)
.map(|i| Element::prime_subgroup_generator() * Fr::from((seed + i as u64 + 1) as u64))
.collect()
}
fn random_scalars(num_points: usize, num_bytes: usize) -> Vec<Fr> {
use ark_ff::PrimeField;

(0..num_points)
.map(|_| {
let mut bytes = vec![0u8; num_bytes];
rand::thread_rng().fill_bytes(&mut bytes[..]);
Fr::from_le_bytes_mod_order(&bytes)
})
.collect()
}

criterion_group!(benches, msm_wnaf, keccak_32bytes);
criterion_main!(benches);
1 change: 1 addition & 0 deletions banderwagon/src/lib.rs
@@ -1,4 +1,5 @@
pub mod msm;
pub mod msm_windowed_sign;
pub mod trait_impls;

mod element;
241 changes: 241 additions & 0 deletions banderwagon/src/msm_windowed_sign.rs
@@ -0,0 +1,241 @@
use crate::Element;
use ark_ec::CurveGroup;
use ark_ed_on_bls12_381_bandersnatch::{EdwardsAffine, EdwardsProjective, Fr};
use ark_ff::Zero;
use ark_ff::{BigInteger, BigInteger256};
use std::ops::Neg;

#[derive(Debug, Clone)]
pub struct MSMPrecompWindowSigned {
tables: Vec<Vec<EdwardsAffine>>,
num_windows: usize,
window_size: usize,
}

impl MSMPrecompWindowSigned {
pub fn new(bases: &[Element], window_size: usize) -> MSMPrecompWindowSigned {
use ark_ff::PrimeField;

let number_of_windows = Fr::MODULUS_BIT_SIZE as usize / window_size + 1;

let precomputed_points: Vec<_> = bases
.iter()
.map(|point| {
Self::precompute_points(
window_size,
number_of_windows,
EdwardsAffine::from(point.0),
)
})
.collect();

MSMPrecompWindowSigned {
window_size,
tables: precomputed_points,
num_windows: number_of_windows,
}
}

fn precompute_points(
window_size: usize,
number_of_windows: usize,
point: EdwardsAffine,
) -> Vec<EdwardsAffine> {
let window_size_scalar = Fr::from(1 << window_size);
use ark_ff::Field;

use rayon::prelude::*;

let all_tables: Vec<_> = (0..number_of_windows)
.into_par_iter()
.flat_map(|window_index| {
let window_scalar = window_size_scalar.pow([window_index as u64]);
let mut lookup_table = Vec::with_capacity(1 << (window_size - 1));
let point = EdwardsProjective::from(point) * window_scalar;
let mut current = point;
// Compute and store multiples
for _ in 0..(1 << (window_size - 1)) {
lookup_table.push(current);
current += point;
}
EdwardsProjective::normalize_batch(&lookup_table)
})
.collect();

all_tables
}

pub fn mul(&self, scalars: &[Fr]) -> Element {
let scalars_bytes: Vec<_> = scalars
.iter()
.map(|a| {
let bigint: BigInteger256 = (*a).into();
bigint.to_bytes_le()
})
.collect();

let mut points_to_add = Vec::new();

for window_idx in 0..self.num_windows {
for (scalar_idx, scalar_bytes) in scalars_bytes.iter().enumerate() {
let sub_table = &self.tables[scalar_idx];
let point_idx =
get_booth_index(window_idx, self.window_size, scalar_bytes.as_ref());

if point_idx == 0 {
continue;
}
let sign = point_idx.is_positive();
let point_idx = point_idx.unsigned_abs() as usize - 1;

// Scale the point index by the window index to figure out whether
// we need P, 2^wP, 2^{2w}P, etc
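// (The flat table stores, for each window_idx, the multiples k * 2^(window_size * window_idx) * P
// for k = 1..=2^(window_size - 1); the index below selects the k-th entry of that chunk.)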
let scaled_point_index = window_idx * (1 << (self.window_size - 1)) + point_idx;
let mut point = sub_table[scaled_point_index];

if !sign {
point = -point;
}

points_to_add.push(point);
}
}

let mut result = EdwardsProjective::zero();
for point in points_to_add {
result += point;
}

Element(result)
}
}

// TODO: Link to halo2 file + docs + comments
pub fn get_booth_index(window_index: usize, window_size: usize, el: &[u8]) -> i32 {
// Booth encoding:
// * step by `window` size
// * slice by size of `window + 1`
// * each window overlap by 1 bit
// * append a zero bit to the least significant end
// Indexing rule, shown here for window size 3 where we slice by 4 bits:
// `[0, +1, +1, +2, +2, +3, +3, +4, -4, -3, -3, -2, -2, -1, -1, 0]`
// So we can reduce the bucket size without preprocessing scalars
// and remembering them as in classic signed digit encoding
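// Worked example for window_size = 3 (4-bit slices, matching the table above):
// slice 0b0111 (7) encodes +4, slice 0b1000 (8) encodes -4, and slice 0b1111 (15)
// encodes 0, because its set top bit reappears as the overlap (low) bit of the next,
// more significant slice and bumps that window's digit by one.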

let skip_bits = (window_index * window_size).saturating_sub(1);
let skip_bytes = skip_bits / 8;

// fill into a u32
let mut v: [u8; 4] = [0; 4];
for (dst, src) in v.iter_mut().zip(el.iter().skip(skip_bytes)) {
*dst = *src
}
let mut tmp = u32::from_le_bytes(v);

// pad with one 0 if slicing the least significant window
if window_index == 0 {
tmp <<= 1;
}

// remove further bits
tmp >>= skip_bits - (skip_bytes * 8);
// apply the booth window
tmp &= (1 << (window_size + 1)) - 1;

let sign = tmp & (1 << window_size) == 0;

// div ceil by 2
tmp = (tmp + 1) >> 1;

// find the booth action index
if sign {
tmp as i32
} else {
((!(tmp - 1) & ((1 << window_size) - 1)) as i32).neg()
}
}

#[test]
fn smoke_test_interop_strauss() {
use ark_ff::UniformRand;

let length = 5;
let scalars: Vec<_> = (0..length)
.map(|_| Fr::rand(&mut rand::thread_rng()))
.collect();
let points: Vec<_> = (0..length)
.map(|_| Element::prime_subgroup_generator() * Fr::rand(&mut rand::thread_rng()))
.collect();

let precomp = MSMPrecompWindowSigned::new(&points, 2);
let result = precomp.mul(&scalars);

let mut expected = Element::zero();
for (scalar, point) in scalars.into_iter().zip(points) {
expected += point * scalar
}

assert_eq!(expected, result)
}

#[cfg(test)]
mod booth_tests {
use std::ops::Neg;

use ark_ed_on_bls12_381_bandersnatch::Fr;
use ark_ff::{BigInteger, BigInteger256, Field, PrimeField};

use super::get_booth_index;
use crate::Element;

#[test]
fn smoke_scalar_mul() {
let gen = Element::prime_subgroup_generator();
let s = -Fr::ONE;

let res = gen * s;

let got = mul(&s, &gen, 4);

assert_eq!(Element::from(res), got)
}

fn mul(scalar: &Fr, point: &Element, window: usize) -> Element {
let u_bigint: BigInteger256 = (*scalar).into();
use ark_ff::Field;
let u = u_bigint.to_bytes_le();
let n = Fr::MODULUS_BIT_SIZE as usize / window + 1;

let table = (0..=1 << (window - 1))
.map(|i| point * &Fr::from(i as u64))
.collect::<Vec<_>>();

let table_scalars = (0..=1 << (window - 1))
.map(|i| Fr::from(i as u64))
.collect::<Vec<_>>();

let mut acc: Element = Element::zero();
let mut acc_scalar = Fr::ZERO;
for i in (0..n).rev() {
for _ in 0..window {
acc = acc + acc;
acc_scalar = acc_scalar + acc_scalar;
}

let idx = get_booth_index(i as usize, window, u.as_ref());

if idx.is_negative() {
acc += table[idx.unsigned_abs() as usize].neg();
acc_scalar -= table_scalars[idx.unsigned_abs() as usize];
}
if idx.is_positive() {
acc += table[idx.unsigned_abs() as usize];
acc_scalar += table_scalars[idx.unsigned_abs() as usize];
}
}

assert_eq!(acc_scalar, *scalar);

acc.into()
}
}
24 changes: 21 additions & 3 deletions ipa-multipoint/src/committer.rs
@@ -1,4 +1,4 @@
use banderwagon::{msm::MSMPrecompWnaf, Element, Fr};
use banderwagon::{msm::MSMPrecompWnaf, msm_windowed_sign::MSMPrecompWindowSigned, Element, Fr};

// This is the functionality that commits to the branch nodes and computes the delta optimization
// For consistency with the Pcs, ensure that this component uses the same CRS as the Pcs
@@ -24,19 +24,31 @@ pub trait Committer {

#[derive(Clone, Debug)]
pub struct DefaultCommitter {
precomp_first_five: MSMPrecompWindowSigned,
precomp: MSMPrecompWnaf,
}

impl DefaultCommitter {
pub fn new(points: &[Element]) -> Self {
// Take the first five elements and use a more aggressive optimization strategy
// since they are used for computing storage keys.
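// With window size 16 this stores 2^15 = 32768 points per window for each of the
// five bases (about 16 windows for the ~253-bit scalar field), trading memory for
// a single table lookup and point addition per window during `mul`.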

let (points_five, _) = points.split_at(5);
let precomp_first_five = MSMPrecompWindowSigned::new(points_five, 16);
let precomp = MSMPrecompWnaf::new(points, 12);

Self { precomp }
Self {
precomp,
precomp_first_five,
}
}
}

impl Committer for DefaultCommitter {
fn commit_lagrange(&self, evaluations: &[Fr]) -> Element {
if evaluations.len() <= 5 {
return self.precomp_first_five.mul(evaluations);
}
// Preliminary benchmarks indicate that the parallel version is faster
// for vectors of length 64 or more
if evaluations.len() >= 64 {
@@ -47,6 +59,12 @@ impl Committer for DefaultCommitter {
}

fn scalar_mul(&self, value: Fr, lagrange_index: usize) -> Element {
self.precomp.mul_index(value, lagrange_index)
if lagrange_index < 5 {
let mut arr = [Fr::from(0u64); 5];
arr[lagrange_index] = value;
self.precomp_first_five.mul(&arr)
} else {
self.precomp.mul_index(value, lagrange_index)
}
}
}
