From dec49a95fd483d8a90b7b602d13bfa8ea9485491 Mon Sep 17 00:00:00 2001 From: Sebastien Chapuis Date: Tue, 14 Jan 2025 13:12:34 +0100 Subject: [PATCH] Parallelize few stuff --- curves/src/pasta/fields/fp.rs | 2 +- curves/src/pasta/fields/fq.rs | 3 +- kimchi/src/circuits/expr.rs | 46 +++++++----- poly-commitment/src/combine.rs | 2 +- poly-commitment/src/evaluation_proof.rs | 94 ++++++++++++++----------- poly-commitment/src/msm.rs | 7 +- 6 files changed, 86 insertions(+), 68 deletions(-) diff --git a/curves/src/pasta/fields/fp.rs b/curves/src/pasta/fields/fp.rs index ed5980b1d1..5365cf232b 100644 --- a/curves/src/pasta/fields/fp.rs +++ b/curves/src/pasta/fields/fp.rs @@ -1,4 +1,4 @@ -use ark_ff::{biginteger::BigInteger256 as BigInteger, FftParameters, Fp256Parameters, Fp256}; +use ark_ff::{biginteger::BigInteger256 as BigInteger, FftParameters, Fp256, Fp256Parameters}; pub type Fp = Fp256; diff --git a/curves/src/pasta/fields/fq.rs b/curves/src/pasta/fields/fq.rs index 47731ed58c..80d027a9b7 100644 --- a/curves/src/pasta/fields/fq.rs +++ b/curves/src/pasta/fields/fq.rs @@ -1,8 +1,7 @@ -use ark_ff::{biginteger::BigInteger256 as BigInteger, FftParameters, Fp256Parameters, Fp256}; +use ark_ff::{biginteger::BigInteger256 as BigInteger, FftParameters, Fp256, Fp256Parameters}; pub type Fq = Fp256; - #[derive(Debug, Clone, Copy, Default, Eq, PartialEq, PartialOrd, Ord, Hash)] pub struct FqParameters; diff --git a/kimchi/src/circuits/expr.rs b/kimchi/src/circuits/expr.rs index f331d96b1b..cbe7fc28fc 100644 --- a/kimchi/src/circuits/expr.rs +++ b/kimchi/src/circuits/expr.rs @@ -1021,24 +1021,24 @@ fn unnormalized_lagrange_evals( impl<'a, F: FftField> EvalResult<'a, F> { fn init_ F>( - res_domain: (Domain, D), + res_domain: (Domain, &D), g: G, ) -> Evaluations> { let n = res_domain.1.size(); Evaluations::>::from_vec_and_domain( (0..n).into_par_iter().map(g).collect(), - res_domain.1, + res_domain.1.clone(), ) } - fn init F>(res_domain: (Domain, D), g: G) -> Self { + fn init F>(res_domain: (Domain, &D), g: G) -> Self { Self::Evals { domain: res_domain.0, evals: Self::init_(res_domain, g), } } - fn add<'c>(self, other: EvalResult<'_, F>, res_domain: (Domain, D)) -> EvalResult<'c, F> { + fn add<'c>(self, other: EvalResult<'_, F>, res_domain: (Domain, &D)) -> EvalResult<'c, F> { use EvalResult::*; match (self, other) { (Constant(x), Constant(y)) => Constant(x + y), @@ -1074,7 +1074,7 @@ impl<'a, F: FftField> EvalResult<'a, F> { .collect(); Evals { domain: res_domain.0, - evals: Evaluations::>::from_vec_and_domain(v, res_domain.1), + evals: Evaluations::>::from_vec_and_domain(v, res_domain.1.clone()), } } ( @@ -1151,13 +1151,13 @@ impl<'a, F: FftField> EvalResult<'a, F> { Evals { domain: res_domain.0, - evals: Evaluations::>::from_vec_and_domain(v, res_domain.1), + evals: Evaluations::>::from_vec_and_domain(v, res_domain.1.clone()), } } } } - fn sub<'c>(self, other: EvalResult<'_, F>, res_domain: (Domain, D)) -> EvalResult<'c, F> { + fn sub<'c>(self, other: EvalResult<'_, F>, res_domain: (Domain, &D)) -> EvalResult<'c, F> { use EvalResult::*; match (self, other) { (Constant(x), Constant(y)) => Constant(x - y), @@ -1275,7 +1275,7 @@ impl<'a, F: FftField> EvalResult<'a, F> { } } - fn pow<'b>(self, d: u64, res_domain: (Domain, D)) -> EvalResult<'b, F> { + fn pow<'b>(self, d: u64, res_domain: (Domain, &D)) -> EvalResult<'b, F> { let mut acc = EvalResult::Constant(F::one()); for i in (0..u64::BITS).rev() { acc = acc.square(res_domain); @@ -1288,7 +1288,7 @@ impl<'a, F: FftField> EvalResult<'a, F> { acc } - fn square<'b>(self, res_domain: (Domain, D)) -> EvalResult<'b, F> { + fn square<'b>(self, res_domain: (Domain, &D)) -> EvalResult<'b, F> { use EvalResult::*; match self { Constant(x) => Constant(x.square()), @@ -1312,7 +1312,7 @@ impl<'a, F: FftField> EvalResult<'a, F> { } } - fn mul<'c>(self, other: EvalResult<'_, F>, res_domain: (Domain, D)) -> EvalResult<'c, F> { + fn mul<'c>(self, other: EvalResult<'_, F>, res_domain: (Domain, &D)) -> EvalResult<'c, F> { use EvalResult::*; match (self, other) { (Constant(x), Constant(y)) => Constant(x * y), @@ -1424,6 +1424,15 @@ fn get_domain(d: Domain, env: &Environment) -> D { } } +fn get_domain_ref<'a, F: FftField>(d: Domain, env: &'a Environment) -> &'a D { + match d { + Domain::D1 => &env.domain.d1, + Domain::D2 => &env.domain.d2, + Domain::D4 => &env.domain.d4, + Domain::D8 => &env.domain.d8, + } +} + impl Expr> { /// Convenience function for constructing expressions from literal /// field elements. @@ -1713,13 +1722,13 @@ impl Expr { assert_eq!(domain, d); evals } - EvalResult::Constant(x) => EvalResult::init_((d, get_domain(d, env)), |_| x), + EvalResult::Constant(x) => EvalResult::init_((d, get_domain_ref(d, env)), |_| x), EvalResult::SubEvals { evals, domain: d_sub, shift: s, } => { - let res_domain = get_domain(d, env); + let res_domain = get_domain_ref(d, env); let scale = (d_sub as usize) / (d as usize); assert!(scale != 0); EvalResult::init_((d, res_domain), |i| { @@ -1738,7 +1747,7 @@ impl Expr { where 'a: 'b, { - let dom = (d, get_domain(d, env)); + let dom = (d, get_domain_ref(d, env)); let res: EvalResult<'a, F> = match self { Expr::Square(x) => match x.evaluations_helper(cache, d, env) { @@ -1800,10 +1809,11 @@ impl Expr { Expr::Pow(x, p) => { let x = x.evaluations_helper(cache, d, env); match x { - Either::Left(x) => x.pow(*p, (d, get_domain(d, env))), - Either::Right(id) => { - id.get_from(cache).unwrap().pow(*p, (d, get_domain(d, env))) - } + Either::Left(x) => x.pow(*p, (d, get_domain_ref(d, env))), + Either::Right(id) => id + .get_from(cache) + .unwrap() + .pow(*p, (d, get_domain_ref(d, env))), } } Expr::VanishesOnZeroKnowledgeAndPreviousRows => EvalResult::SubEvals { @@ -1837,7 +1847,7 @@ impl Expr { } } Expr::BinOp(op, e1, e2) => { - let dom = (d, get_domain(d, env)); + let dom = (d, get_domain_ref(d, env)); let f = |x: EvalResult, y: EvalResult| match op { Op2::Mul => x.mul(y, dom), Op2::Add => x.add(y, dom), diff --git a/poly-commitment/src/combine.rs b/poly-commitment/src/combine.rs index a891f36ef9..d772d28939 100644 --- a/poly-commitment/src/combine.rs +++ b/poly-commitment/src/combine.rs @@ -295,7 +295,7 @@ fn affine_window_combine_one_endo_base( ) -> Vec> { fn assign(dst: &mut [A], src: &[A]) { let n = dst.len(); - dst[..n].clone_from_slice(&src[..n]); + dst[..n].copy_from_slice(&src[..n]); } fn get_bit(limbs_lsb: &[u64], i: u64) -> u64 { diff --git a/poly-commitment/src/evaluation_proof.rs b/poly-commitment/src/evaluation_proof.rs index d66122d82a..4e7306844e 100644 --- a/poly-commitment/src/evaluation_proof.rs +++ b/poly-commitment/src/evaluation_proof.rs @@ -225,25 +225,31 @@ impl SRS { let rand_l = ::rand(rng); let rand_r = ::rand(rng); - let l = call_msm( - &[&g[0..n], &[self.h, u]].concat(), - &[&a[n..], &[rand_l, inner_prod(a_hi, b_lo)]] - .concat() - .iter() - .map(|x| x.into_repr()) - .collect::>(), - ) - .into_affine(); - - let r = call_msm( - &[&g[n..], &[self.h, u]].concat(), - &[&a[0..n], &[rand_r, inner_prod(a_lo, b_hi)]] - .concat() - .iter() - .map(|x| x.into_repr()) - .collect::>(), - ) - .into_affine(); + let call_l = || { + call_msm( + &[&g[0..n], &[self.h, u]].concat(), + &[&a[n..], &[rand_l, inner_prod(a_hi, b_lo)]] + .concat() + .iter() + .map(|x| x.into_repr()) + .collect::>(), + ) + .into_affine() + }; + + let call_r = || { + call_msm( + &[&g[n..], &[self.h, u]].concat(), + &[&a[0..n], &[rand_r, inner_prod(a_lo, b_hi)]] + .concat() + .iter() + .map(|x| x.into_repr()) + .collect::>(), + ) + .into_affine() + }; + + let (l, r) = rayon::join(call_l, call_r); lr.push((l, r)); blinders.push((rand_l, rand_r)); @@ -258,29 +264,33 @@ impl SRS { chals.push(u); chal_invs.push(u_inv); - a = a_hi - .par_iter() - .zip(a_lo) - .map(|(&hi, &lo)| { - // lo + u_inv * hi - let mut res = hi; - res *= u_inv; - res += &lo; - res - }) - .collect(); - - b = b_lo - .par_iter() - .zip(b_hi) - .map(|(&lo, &hi)| { - // lo + u * hi - let mut res = hi; - res *= u; - res += &lo; - res - }) - .collect(); + let call_a = || { + a_hi.par_iter() + .zip(a_lo) + .map(|(&hi, &lo)| { + // lo + u_inv * hi + let mut res = hi; + res *= u_inv; + res += &lo; + res + }) + .collect() + }; + + let call_b = || { + b_lo.par_iter() + .zip(b_hi) + .map(|(&lo, &hi)| { + // lo + u * hi + let mut res = hi; + res *= u; + res += &lo; + res + }) + .collect() + }; + + (a, b) = rayon::join(call_a, call_b); g = G::combine_one_endo(endo_r, endo_q, &g_lo, &g_hi, u_pre); } diff --git a/poly-commitment/src/msm.rs b/poly-commitment/src/msm.rs index ae6bd759ef..8901fe4552 100644 --- a/poly-commitment/src/msm.rs +++ b/poly-commitment/src/msm.rs @@ -123,10 +123,9 @@ pub fn call_msm_impl( // Safety: We're reinterpreting generic types to their concret types // proof-systems contains too much useless generic types // It's safe because we just asserted they are the same types - let result = my_msm::( - unsafe { std::mem::transmute(points) }, - unsafe { std::mem::transmute(scalars) }, - ); + let result = my_msm::(unsafe { std::mem::transmute(points) }, unsafe { + std::mem::transmute(scalars) + }); unsafe { *(&result as *const _ as *const G::Projective) } }