Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Add a more aggressive caching/precomputation strategy for first five elements #98

Merged
merged 5 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions banderwagon/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,17 @@ ark-ff = { version = "^0.4.2", default-features = false }
ark-ec = { version = "^0.4.2", default-features = false }
ark-serialize = { version = "^0.4.2", default-features = false }
rayon = "*"

[dev-dependencies]
hex = "0.4.3"
criterion = "0.5.1"
rand = "0.8.4"
sha3 = "0.10.8"

[features]
default = ["parallel"]
parallel = ["ark-ff/parallel", "ark-ff/asm", "ark-ec/parallel"]

[[bench]]
name = "benchmark"
harness = false
62 changes: 62 additions & 0 deletions banderwagon/benches/benchmark.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
use banderwagon::{msm::MSMPrecompWnaf, msm_windowed_sign::MSMPrecompWindowSigned, Element, Fr};
use criterion::{criterion_group, criterion_main, Criterion};
use rand::RngCore;

/// Benchmarks a five-element multi-scalar multiplication with two precomputation
/// strategies over the same inputs.
///
/// The bases come from `random_point(120, 5)` and the scalars from
/// `random_scalars(5, 16)`, i.e. each scalar is built from 16 random bytes.
/// `MSMPrecompWnaf` is constructed with parameter 12 and `MSMPrecompWindowSigned`
/// with a window size of 16; each is timed under its own benchmark id.
pub fn msm_wnaf(c: &mut Criterion) {
    const NUM_ELEMENTS: usize = 5;

    let points = random_point(120, NUM_ELEMENTS);
    let coefficients = random_scalars(NUM_ELEMENTS, 16);

    // Strategy 1: wNAF-based precomputation.
    let wnaf_table = MSMPrecompWnaf::new(&points, 12);
    let wnaf_label = format!("msm wnaf: {}", NUM_ELEMENTS);
    c.bench_function(&wnaf_label, |bencher| {
        bencher.iter(|| wnaf_table.mul(&coefficients))
    });

    // Strategy 2: signed fixed-window precomputation with 16-bit windows.
    let signed_table = MSMPrecompWindowSigned::new(&points, 16);
    let signed_label = format!("msm precomp 16: {}", NUM_ELEMENTS);
    c.bench_function(&signed_label, |bencher| {
        bencher.iter(|| signed_table.mul(&coefficients))
    });
}

/// Benchmarks a single Keccak-256 hash over a freshly generated random message.
///
/// Despite the function name, the message is 64 bytes long, which is what the
/// benchmark id ("keccak 64 bytes") reports.
pub fn keccak_32bytes(c: &mut Criterion) {
    use rand::Rng;
    use sha3::{Digest, Keccak256};

    // Setup: a fresh hasher and a new random 64-byte message for each iteration.
    let fresh_input = || {
        let mut message = [0u8; 64];
        rand::thread_rng().fill(&mut message);
        (Keccak256::default(), message)
    };

    c.bench_function("keccak 64 bytes", |bencher| {
        bencher.iter_with_setup(fresh_input, |(mut hasher, message)| {
            hasher.update(&message);
            hasher.finalize()
        })
    });
}

/// Returns `num_points` group elements derived from `seed`.
///
/// Despite the name, the output is deterministic: the i-th element (0-based)
/// is the prime-subgroup generator multiplied by the scalar `seed + i + 1`.
fn random_point(seed: u64, num_points: usize) -> Vec<Element> {
    let generator = Element::prime_subgroup_generator();

    (1..=num_points as u64)
        .map(|offset| generator * Fr::from(seed + offset))
        .collect()
}
/// Returns `num_points` random scalars, each obtained by filling `num_bytes`
/// bytes from the thread-local RNG and interpreting them as a little-endian
/// integer reduced modulo the scalar field order.
fn random_scalars(num_points: usize, num_bytes: usize) -> Vec<Fr> {
    use ark_ff::PrimeField;

    let mut rng = rand::thread_rng();
    // `fill_bytes` overwrites the whole buffer, so it can be reused per scalar.
    let mut buffer = vec![0u8; num_bytes];
    let mut scalars = Vec::with_capacity(num_points);

    for _ in 0..num_points {
        rng.fill_bytes(&mut buffer);
        scalars.push(Fr::from_le_bytes_mod_order(&buffer));
    }

    scalars
}

// Register both benchmark functions under the `benches` group and generate the
// benchmark binary's entry point.
criterion_group!(benches, msm_wnaf, keccak_32bytes);
criterion_main!(benches);
1 change: 1 addition & 0 deletions banderwagon/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod msm;
pub mod msm_windowed_sign;
pub mod trait_impls;

mod element;
Expand Down
241 changes: 241 additions & 0 deletions banderwagon/src/msm_windowed_sign.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
use crate::Element;
use ark_ec::CurveGroup;
use ark_ed_on_bls12_381_bandersnatch::{EdwardsAffine, EdwardsProjective, Fr};
use ark_ff::Zero;
use ark_ff::{BigInteger, BigInteger256};
use std::ops::Neg;

/// Precomputed lookup tables for a fixed-base multi-scalar multiplication that
/// uses signed (Booth-encoded) windows of `window_size` bits.
#[derive(Debug, Clone)]
pub struct MSMPrecompWindowSigned {
/// One flat table per base point, in the order the bases were supplied.
/// Within a table, window `j` occupies the `2^(window_size - 1)` entries
/// starting at `j * 2^(window_size - 1)`; entry `k` of that run holds
/// `(k + 1) * 2^(window_size * j) * P`.
tables: Vec<Vec<EdwardsAffine>>,
/// Number of windows processed per scalar
/// (`Fr::MODULUS_BIT_SIZE / window_size + 1`).
num_windows: usize,
/// Width of each window in bits.
window_size: usize,
}

impl MSMPrecompWindowSigned {
    /// Largest supported window size.
    ///
    /// `get_booth_index` extracts each digit from a 32-bit word after shifting
    /// it right by up to 7 bits, so the `window_size + 1` bits of a Booth slice
    /// must fit in the remaining 25 bits.
    const MAX_WINDOW_SIZE: usize = 24;

    /// Precomputes signed-window lookup tables for every point in `bases`.
    ///
    /// For each base `P` and each window `j` the table stores the multiples
    /// `k * 2^(window_size * j) * P` for `k = 1..=2^(window_size - 1)`, so that
    /// [`Self::mul`] only has to perform additions. Memory use grows as
    /// `bases.len() * num_windows * 2^(window_size - 1)` affine points.
    ///
    /// # Panics
    ///
    /// Panics if `window_size` is zero or greater than 24.
    pub fn new(bases: &[Element], window_size: usize) -> MSMPrecompWindowSigned {
        use ark_ff::PrimeField;

        assert!(
            (1..=Self::MAX_WINDOW_SIZE).contains(&window_size),
            "window_size must be in 1..={}, got {}",
            Self::MAX_WINDOW_SIZE,
            window_size
        );

        // One extra window so the Booth encoding can absorb the carry out of
        // the most significant bits of the scalar.
        let num_windows = Fr::MODULUS_BIT_SIZE as usize / window_size + 1;

        let tables: Vec<_> = bases
            .iter()
            .map(|base| {
                Self::precompute_points(window_size, num_windows, EdwardsAffine::from(base.0))
            })
            .collect();

        MSMPrecompWindowSigned {
            tables,
            num_windows,
            window_size,
        }
    }

    /// Builds the flat lookup table for a single base point.
    ///
    /// The result is the concatenation, over windows `j = 0..number_of_windows`,
    /// of the `2^(window_size - 1)` points `k * 2^(window_size * j) * point`
    /// for `k = 1..=2^(window_size - 1)`. Windows are computed in parallel and
    /// each window is batch-normalized to affine form.
    fn precompute_points(
        window_size: usize,
        number_of_windows: usize,
        point: EdwardsAffine,
    ) -> Vec<EdwardsAffine> {
        use ark_ff::Field;
        use rayon::prelude::*;

        // Explicit u64: an untyped `1 << window_size` would fall back to a
        // 32-bit signed literal, which is needlessly fragile for a shift.
        let window_size_scalar = Fr::from(1u64 << window_size);
        let entries_per_window = 1usize << (window_size - 1);
        let base = EdwardsProjective::from(point);

        (0..number_of_windows)
            .into_par_iter()
            .flat_map(|window_index| {
                // 2^(window_size * window_index) * point
                let window_base = base * window_size_scalar.pow([window_index as u64]);

                // Successive multiples 1x, 2x, ..., 2^(window_size - 1)x.
                let mut lookup_table = Vec::with_capacity(entries_per_window);
                let mut current = window_base;
                for _ in 0..entries_per_window {
                    lookup_table.push(current);
                    current += window_base;
                }

                EdwardsProjective::normalize_batch(&lookup_table)
            })
            .collect()
    }

    /// Computes `sum(scalars[i] * bases[i])` using the precomputed tables.
    ///
    /// `scalars[i]` is paired with the i-th base passed to [`Self::new`];
    /// passing fewer scalars than bases is allowed. Each scalar is split into
    /// Booth digits and every non-zero digit selects one table entry, negated
    /// when the digit is negative.
    ///
    /// # Panics
    ///
    /// Panics if `scalars` is longer than the number of precomputed bases.
    pub fn mul(&self, scalars: &[Fr]) -> Element {
        assert!(
            scalars.len() <= self.tables.len(),
            "received {} scalars but only {} bases were precomputed",
            scalars.len(),
            self.tables.len()
        );

        let scalars_bytes: Vec<_> = scalars
            .iter()
            .map(|a| {
                let bigint: BigInteger256 = (*a).into();
                bigint.to_bytes_le()
            })
            .collect();

        let entries_per_window = 1usize << (self.window_size - 1);
        let mut result = EdwardsProjective::zero();

        for window_idx in 0..self.num_windows {
            // Offset of this window's run inside each flat per-base table,
            // i.e. whether we need multiples of P, 2^w P, 2^{2w} P, etc.
            let window_offset = window_idx * entries_per_window;

            for (sub_table, scalar_bytes) in self.tables.iter().zip(scalars_bytes.iter()) {
                let digit = get_booth_index(window_idx, self.window_size, scalar_bytes.as_slice());
                if digit == 0 {
                    continue;
                }

                // Tables start at the 1x multiple, hence the `- 1`.
                let entry = window_offset + digit.unsigned_abs() as usize - 1;
                let mut point = sub_table[entry];
                if digit.is_negative() {
                    point = -point;
                }

                result += point;
            }
        }

        Element(result)
    }
}

// TODO: Link to halo2 file + docs + comments
/// Returns the signed Booth digit of the little-endian scalar `el` for the
/// window at `window_index`, using windows of `window_size` bits.
///
/// Booth encoding:
/// * step through the scalar by `window_size` bits
/// * slice out `window_size + 1` bits, so neighbouring windows overlap by one bit
/// * the least significant window is padded with a zero bit below it
///
/// For window size 3 (4-bit slices) the slice value maps to the digit
/// `[0, +1, +1, +2, +2, +3, +3, +4, -4, -3, -3, -2, -2, -1, -1, 0]`.
/// This halves the lookup table size without having to preprocess the scalar
/// as in classic signed-digit encoding.
///
/// Bytes beyond the end of `el` are treated as zero. The slice is read through
/// a 32-bit word, so the result is only meaningful while the slice plus its
/// in-byte offset (at most 7 bits) fits in 32 bits.
pub fn get_booth_index(window_index: usize, window_size: usize, el: &[u8]) -> i32 {
    // Position of the overlap bit just below this window (none for window 0).
    let bit_offset = (window_index * window_size).saturating_sub(1);
    let byte_offset = bit_offset / 8;

    // Load up to four bytes starting at `byte_offset`, zero-padded.
    let mut word_bytes = [0u8; 4];
    let tail = el.get(byte_offset..).unwrap_or(&[]);
    let copied = tail.len().min(word_bytes.len());
    word_bytes[..copied].copy_from_slice(&tail[..copied]);
    let mut word = u32::from_le_bytes(word_bytes);

    // The least significant window has no bit below it: pad with a zero.
    if window_index == 0 {
        word <<= 1;
    }

    // Drop the bits below the slice, then keep `window_size + 1` bits.
    word >>= bit_offset % 8;
    word &= (1 << (window_size + 1)) - 1;

    // The top bit of the slice marks a negative digit.
    let is_negative = (word & (1 << window_size)) != 0;

    // Ceiling division by two.
    let half = (word + 1) >> 1;

    if is_negative {
        ((!(half - 1) & ((1 << window_size) - 1)) as i32).neg()
    } else {
        half as i32
    }
}

/// Checks that the precomputed signed-window MSM (window size 2) agrees with a
/// naive sum of individual scalar multiplications on five random inputs.
#[test]
fn smoke_test_interop_strauss() {
    use ark_ff::UniformRand;

    const LENGTH: usize = 5;

    let mut scalars = Vec::with_capacity(LENGTH);
    for _ in 0..LENGTH {
        scalars.push(Fr::rand(&mut rand::thread_rng()));
    }

    let mut points = Vec::with_capacity(LENGTH);
    for _ in 0..LENGTH {
        points.push(Element::prime_subgroup_generator() * Fr::rand(&mut rand::thread_rng()));
    }

    let result = MSMPrecompWindowSigned::new(&points, 2).mul(&scalars);

    // Reference value: accumulate each term one at a time.
    let mut expected = Element::zero();
    for (point, scalar) in points.into_iter().zip(scalars) {
        expected += point * scalar;
    }

    assert_eq!(expected, result)
}

#[cfg(test)]
mod booth_tests {
    use std::ops::Neg;

    use ark_ed_on_bls12_381_bandersnatch::Fr;
    use ark_ff::{BigInteger, BigInteger256, Field, PrimeField};

    use super::get_booth_index;
    use crate::Element;

    /// Multiplying the generator by `-1` through the Booth-digit reference
    /// implementation must match the native scalar multiplication.
    #[test]
    fn smoke_scalar_mul() {
        let generator = Element::prime_subgroup_generator();
        let scalar = -Fr::ONE;

        let expected = generator * scalar;
        let actual = mul(&scalar, &generator, 4);

        assert_eq!(expected, actual)
    }

    /// Reference windowed double-and-add driven by `get_booth_index`.
    ///
    /// Alongside the point accumulator it tracks the same computation on plain
    /// scalars and asserts that the Booth digits recombine to `scalar`.
    fn mul(scalar: &Fr, point: &Element, window: usize) -> Element {
        let scalar_bytes = {
            let repr: BigInteger256 = (*scalar).into();
            repr.to_bytes_le()
        };
        let num_windows = Fr::MODULUS_BIT_SIZE as usize / window + 1;

        // Multiples 0, 1, ..., 2^(window - 1), as scalars and as points.
        let table_scalars = (0..=1 << (window - 1))
            .map(|i| Fr::from(i as u64))
            .collect::<Vec<_>>();
        let table = table_scalars
            .iter()
            .map(|multiple| point * multiple)
            .collect::<Vec<_>>();

        let mut acc: Element = Element::zero();
        let mut acc_scalar = Fr::ZERO;

        // Most significant window first: shift by `window` bits, then add the digit.
        for window_index in (0..num_windows).rev() {
            for _ in 0..window {
                acc = acc + acc;
                acc_scalar = acc_scalar + acc_scalar;
            }

            let digit = get_booth_index(window_index, window, &scalar_bytes);
            let slot = digit.unsigned_abs() as usize;

            match digit.signum() {
                -1 => {
                    acc += table[slot].neg();
                    acc_scalar -= table_scalars[slot];
                }
                1 => {
                    acc += table[slot];
                    acc_scalar += table_scalars[slot];
                }
                _ => {}
            }
        }

        assert_eq!(acc_scalar, *scalar);

        acc
    }
}
24 changes: 21 additions & 3 deletions ipa-multipoint/src/committer.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use banderwagon::{msm::MSMPrecompWnaf, Element, Fr};
use banderwagon::{msm::MSMPrecompWnaf, msm_windowed_sign::MSMPrecompWindowSigned, Element, Fr};

// This is the functionality that commits to the branch nodes and computes the delta optimization
// For consistency with the Pcs, ensure that this component uses the same CRS as the Pcs
Expand All @@ -24,19 +24,31 @@ pub trait Committer {

/// Committer that holds two precomputations over the same set of points.
#[derive(Clone, Debug)]
pub struct DefaultCommitter {
/// Signed-window tables (window size 16) built from only the first five points.
precomp_first_five: MSMPrecompWindowSigned,
/// wNAF precomputation built from all points.
precomp: MSMPrecompWnaf,
}

impl DefaultCommitter {
/// Builds both precomputations from `points`.
///
/// # Panics
///
/// Panics if `points` has fewer than five elements, because `split_at(5)`
/// requires the split index to be within the slice.
pub fn new(points: &[Element]) -> Self {
// Take the first five elements and use a more aggressive optimization strategy
// (signed windows of 16 bits) since they are used for computing storage keys.

let (points_five, _) = points.split_at(5);
let precomp_first_five = MSMPrecompWindowSigned::new(points_five, 16);
// The wNAF precomputation still covers every point, including the first five.
let precomp = MSMPrecompWnaf::new(points, 12);

// NOTE(review): this span comes from a rendered diff; the single-field literal
// below looks like the removed side and the two-field literal the added side.
Self { precomp }
Self {
precomp,
precomp_first_five,
}
}
}

impl Committer for DefaultCommitter {
fn commit_lagrange(&self, evaluations: &[Fr]) -> Element {
if evaluations.len() <= 5 {
return self.precomp_first_five.mul(evaluations);
}
// Preliminary benchmarks indicate that the parallel version is faster
// for vectors of length 64 or more
if evaluations.len() >= 64 {
Expand All @@ -47,6 +59,12 @@ impl Committer for DefaultCommitter {
}

/// Multiplies the point at `lagrange_index` by `value`.
///
/// Indices below five go through the five-point signed-window tables;
/// all other indices use the wNAF precomputation.
// NOTE(review): this span comes from a rendered diff; the bare `mul_index` call
// on the next line looks like the removed side of the change.
fn scalar_mul(&self, value: Fr, lagrange_index: usize) -> Element {
self.precomp.mul_index(value, lagrange_index)
if lagrange_index < 5 {
// Express the single multiplication as a five-term MSM in which every
// scalar except the one at `lagrange_index` is zero.
let mut arr = [Fr::from(0u64); 5];
arr[lagrange_index] = value;
self.precomp_first_five.mul(&arr)
} else {
self.precomp.mul_index(value, lagrange_index)
}
}
}
Loading