Skip to content

Commit

Permalink
Parallelize few stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
sebastiencs committed Jan 14, 2025
1 parent fd30bd6 commit dec49a9
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 68 deletions.
2 changes: 1 addition & 1 deletion curves/src/pasta/fields/fp.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use ark_ff::{biginteger::BigInteger256 as BigInteger, FftParameters, Fp256Parameters, Fp256};
use ark_ff::{biginteger::BigInteger256 as BigInteger, FftParameters, Fp256, Fp256Parameters};

pub type Fp = Fp256<FpParameters>;

Expand Down
3 changes: 1 addition & 2 deletions curves/src/pasta/fields/fq.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use ark_ff::{biginteger::BigInteger256 as BigInteger, FftParameters, Fp256Parameters, Fp256};
use ark_ff::{biginteger::BigInteger256 as BigInteger, FftParameters, Fp256, Fp256Parameters};

pub type Fq = Fp256<FqParameters>;


#[derive(Debug, Clone, Copy, Default, Eq, PartialEq, PartialOrd, Ord, Hash)]
pub struct FqParameters;

Expand Down
46 changes: 28 additions & 18 deletions kimchi/src/circuits/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1021,24 +1021,24 @@ fn unnormalized_lagrange_evals<F: FftField>(

impl<'a, F: FftField> EvalResult<'a, F> {
fn init_<G: Sync + Send + Fn(usize) -> F>(
res_domain: (Domain, D<F>),
res_domain: (Domain, &D<F>),
g: G,
) -> Evaluations<F, D<F>> {
let n = res_domain.1.size();
Evaluations::<F, D<F>>::from_vec_and_domain(
(0..n).into_par_iter().map(g).collect(),
res_domain.1,
res_domain.1.clone(),
)
}

fn init<G: Sync + Send + Fn(usize) -> F>(res_domain: (Domain, D<F>), g: G) -> Self {
fn init<G: Sync + Send + Fn(usize) -> F>(res_domain: (Domain, &D<F>), g: G) -> Self {
Self::Evals {
domain: res_domain.0,
evals: Self::init_(res_domain, g),
}
}

fn add<'c>(self, other: EvalResult<'_, F>, res_domain: (Domain, D<F>)) -> EvalResult<'c, F> {
fn add<'c>(self, other: EvalResult<'_, F>, res_domain: (Domain, &D<F>)) -> EvalResult<'c, F> {
use EvalResult::*;
match (self, other) {
(Constant(x), Constant(y)) => Constant(x + y),
Expand Down Expand Up @@ -1074,7 +1074,7 @@ impl<'a, F: FftField> EvalResult<'a, F> {
.collect();
Evals {
domain: res_domain.0,
evals: Evaluations::<F, D<F>>::from_vec_and_domain(v, res_domain.1),
evals: Evaluations::<F, D<F>>::from_vec_and_domain(v, res_domain.1.clone()),
}
}
(
Expand Down Expand Up @@ -1151,13 +1151,13 @@ impl<'a, F: FftField> EvalResult<'a, F> {

Evals {
domain: res_domain.0,
evals: Evaluations::<F, D<F>>::from_vec_and_domain(v, res_domain.1),
evals: Evaluations::<F, D<F>>::from_vec_and_domain(v, res_domain.1.clone()),
}
}
}
}

fn sub<'c>(self, other: EvalResult<'_, F>, res_domain: (Domain, D<F>)) -> EvalResult<'c, F> {
fn sub<'c>(self, other: EvalResult<'_, F>, res_domain: (Domain, &D<F>)) -> EvalResult<'c, F> {
use EvalResult::*;
match (self, other) {
(Constant(x), Constant(y)) => Constant(x - y),
Expand Down Expand Up @@ -1275,7 +1275,7 @@ impl<'a, F: FftField> EvalResult<'a, F> {
}
}

fn pow<'b>(self, d: u64, res_domain: (Domain, D<F>)) -> EvalResult<'b, F> {
fn pow<'b>(self, d: u64, res_domain: (Domain, &D<F>)) -> EvalResult<'b, F> {
let mut acc = EvalResult::Constant(F::one());
for i in (0..u64::BITS).rev() {
acc = acc.square(res_domain);
Expand All @@ -1288,7 +1288,7 @@ impl<'a, F: FftField> EvalResult<'a, F> {
acc
}

fn square<'b>(self, res_domain: (Domain, D<F>)) -> EvalResult<'b, F> {
fn square<'b>(self, res_domain: (Domain, &D<F>)) -> EvalResult<'b, F> {
use EvalResult::*;
match self {
Constant(x) => Constant(x.square()),
Expand All @@ -1312,7 +1312,7 @@ impl<'a, F: FftField> EvalResult<'a, F> {
}
}

fn mul<'c>(self, other: EvalResult<'_, F>, res_domain: (Domain, D<F>)) -> EvalResult<'c, F> {
fn mul<'c>(self, other: EvalResult<'_, F>, res_domain: (Domain, &D<F>)) -> EvalResult<'c, F> {
use EvalResult::*;
match (self, other) {
(Constant(x), Constant(y)) => Constant(x * y),
Expand Down Expand Up @@ -1424,6 +1424,15 @@ fn get_domain<F: FftField>(d: Domain, env: &Environment<F>) -> D<F> {
}
}

fn get_domain_ref<'a, F: FftField>(d: Domain, env: &'a Environment<F>) -> &'a D<F> {
match d {
Domain::D1 => &env.domain.d1,
Domain::D2 => &env.domain.d2,
Domain::D4 => &env.domain.d4,
Domain::D8 => &env.domain.d8,
}
}

impl<F: Field> Expr<ConstantExpr<F>> {
/// Convenience function for constructing expressions from literal
/// field elements.
Expand Down Expand Up @@ -1713,13 +1722,13 @@ impl<F: FftField> Expr<F> {
assert_eq!(domain, d);
evals
}
EvalResult::Constant(x) => EvalResult::init_((d, get_domain(d, env)), |_| x),
EvalResult::Constant(x) => EvalResult::init_((d, get_domain_ref(d, env)), |_| x),
EvalResult::SubEvals {
evals,
domain: d_sub,
shift: s,
} => {
let res_domain = get_domain(d, env);
let res_domain = get_domain_ref(d, env);
let scale = (d_sub as usize) / (d as usize);
assert!(scale != 0);
EvalResult::init_((d, res_domain), |i| {
Expand All @@ -1738,7 +1747,7 @@ impl<F: FftField> Expr<F> {
where
'a: 'b,
{
let dom = (d, get_domain(d, env));
let dom = (d, get_domain_ref(d, env));

let res: EvalResult<'a, F> = match self {
Expr::Square(x) => match x.evaluations_helper(cache, d, env) {
Expand Down Expand Up @@ -1800,10 +1809,11 @@ impl<F: FftField> Expr<F> {
Expr::Pow(x, p) => {
let x = x.evaluations_helper(cache, d, env);
match x {
Either::Left(x) => x.pow(*p, (d, get_domain(d, env))),
Either::Right(id) => {
id.get_from(cache).unwrap().pow(*p, (d, get_domain(d, env)))
}
Either::Left(x) => x.pow(*p, (d, get_domain_ref(d, env))),
Either::Right(id) => id
.get_from(cache)
.unwrap()
.pow(*p, (d, get_domain_ref(d, env))),
}
}
Expr::VanishesOnZeroKnowledgeAndPreviousRows => EvalResult::SubEvals {
Expand Down Expand Up @@ -1837,7 +1847,7 @@ impl<F: FftField> Expr<F> {
}
}
Expr::BinOp(op, e1, e2) => {
let dom = (d, get_domain(d, env));
let dom = (d, get_domain_ref(d, env));
let f = |x: EvalResult<F>, y: EvalResult<F>| match op {
Op2::Mul => x.mul(y, dom),
Op2::Add => x.add(y, dom),
Expand Down
2 changes: 1 addition & 1 deletion poly-commitment/src/combine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ fn affine_window_combine_one_endo_base<P: SWModelParameters>(
) -> Vec<SWJAffine<P>> {
fn assign<A: Copy>(dst: &mut [A], src: &[A]) {
let n = dst.len();
dst[..n].clone_from_slice(&src[..n]);
dst[..n].copy_from_slice(&src[..n]);
}

fn get_bit(limbs_lsb: &[u64], i: u64) -> u64 {
Expand Down
94 changes: 52 additions & 42 deletions poly-commitment/src/evaluation_proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,25 +225,31 @@ impl<G: CommitmentCurve> SRS<G> {
let rand_l = <G::ScalarField as UniformRand>::rand(rng);
let rand_r = <G::ScalarField as UniformRand>::rand(rng);

let l = call_msm(
&[&g[0..n], &[self.h, u]].concat(),
&[&a[n..], &[rand_l, inner_prod(a_hi, b_lo)]]
.concat()
.iter()
.map(|x| x.into_repr())
.collect::<Vec<_>>(),
)
.into_affine();

let r = call_msm(
&[&g[n..], &[self.h, u]].concat(),
&[&a[0..n], &[rand_r, inner_prod(a_lo, b_hi)]]
.concat()
.iter()
.map(|x| x.into_repr())
.collect::<Vec<_>>(),
)
.into_affine();
let call_l = || {
call_msm(
&[&g[0..n], &[self.h, u]].concat(),
&[&a[n..], &[rand_l, inner_prod(a_hi, b_lo)]]
.concat()
.iter()
.map(|x| x.into_repr())
.collect::<Vec<_>>(),
)
.into_affine()
};

let call_r = || {
call_msm(
&[&g[n..], &[self.h, u]].concat(),
&[&a[0..n], &[rand_r, inner_prod(a_lo, b_hi)]]
.concat()
.iter()
.map(|x| x.into_repr())
.collect::<Vec<_>>(),
)
.into_affine()
};

let (l, r) = rayon::join(call_l, call_r);

lr.push((l, r));
blinders.push((rand_l, rand_r));
Expand All @@ -258,29 +264,33 @@ impl<G: CommitmentCurve> SRS<G> {
chals.push(u);
chal_invs.push(u_inv);

a = a_hi
.par_iter()
.zip(a_lo)
.map(|(&hi, &lo)| {
// lo + u_inv * hi
let mut res = hi;
res *= u_inv;
res += &lo;
res
})
.collect();

b = b_lo
.par_iter()
.zip(b_hi)
.map(|(&lo, &hi)| {
// lo + u * hi
let mut res = hi;
res *= u;
res += &lo;
res
})
.collect();
let call_a = || {
a_hi.par_iter()
.zip(a_lo)
.map(|(&hi, &lo)| {
// lo + u_inv * hi
let mut res = hi;
res *= u_inv;
res += &lo;
res
})
.collect()
};

let call_b = || {
b_lo.par_iter()
.zip(b_hi)
.map(|(&lo, &hi)| {
// lo + u * hi
let mut res = hi;
res *= u;
res += &lo;
res
})
.collect()
};

(a, b) = rayon::join(call_a, call_b);

g = G::combine_one_endo(endo_r, endo_q, &g_lo, &g_hi, u_pre);
}
Expand Down
7 changes: 3 additions & 4 deletions poly-commitment/src/msm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,9 @@ pub fn call_msm_impl<G: CommitmentCurve>(
// Safety: We're reinterpreting generic types to their concret types
// proof-systems contains too much useless generic types
// It's safe because we just asserted they are the same types
let result = my_msm::<G::Params>(
unsafe { std::mem::transmute(points) },
unsafe { std::mem::transmute(scalars) },
);
let result = my_msm::<G::Params>(unsafe { std::mem::transmute(points) }, unsafe {
std::mem::transmute(scalars)
});
unsafe { *(&result as *const _ as *const G::Projective) }
}

Expand Down

0 comments on commit dec49a9

Please sign in to comment.