From 163c329de84e8654e66837f5e86d94f3de343119 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 12 Dec 2024 15:37:15 +0800 Subject: [PATCH] optimize parallel version sumcheck --- sumcheck/src/prover_v2.rs | 91 ++++++++++++++++++++------------------- 1 file changed, 47 insertions(+), 44 deletions(-) diff --git a/sumcheck/src/prover_v2.rs b/sumcheck/src/prover_v2.rs index b4021b77d..7f0b6b832 100644 --- a/sumcheck/src/prover_v2.rs +++ b/sumcheck/src/prover_v2.rs @@ -8,12 +8,12 @@ use multilinear_extensions::{ commutative_op_mle_pair, mle::{DenseMultilinearExtension, MultilinearExtension}, op_mle, op_mle_product_3, op_mle3_range, - util::largest_even_below, + util::{largest_even_below, max_usable_threads}, virtual_poly_v2::VirtualPolynomialV2, }; use rayon::{ Scope, - iter::{IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator}, + iter::IntoParallelRefMutIterator, prelude::{IntoParallelIterator, ParallelIterator}, }; use transcript::{Challenge, Transcript, TranscriptSyncronized}; @@ -722,6 +722,7 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { // // eval g over r_m, and mutate g to g(r_1, ... r_m,, x_{m+1}... x_n) let span = entered_span!("fix_variables"); + let n_threads = max_usable_threads(); if self.round == 0 { assert!(challenge.is_none(), "first round should be prover first."); } else { @@ -769,8 +770,8 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { let AdditiveVec(products_sum) = self .poly .products - .par_iter() - .fold_with( + .iter() + .fold( AdditiveVec::new(self.poly.aux_info.max_degree + 1), |mut products_sum, (coefficient, products)| { let span = entered_span!("sum"); @@ -780,17 +781,17 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { let f = &self.poly.flattened_ml_extensions[products[0]]; op_mle! { |f| { - let res = (0..largest_even_below(f.len())) - .into_par_iter() - .step_by(2) - .with_min_len(64) - .map(|b| { + let res = (0..n_threads).into_par_iter().map(|thread_id| { + (0..largest_even_below(f.len())) + .skip(2*thread_id) + .step_by(2*n_threads) + .map(|b| { AdditiveArray([ f[b], f[b + 1] ]) - }) - .sum::>(); + }).sum::>() + }).sum::>(); let res = if f.len() == 1 { AdditiveArray::<_, 2>([f[0]; 2]) } else { @@ -814,19 +815,20 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { ); commutative_op_mle_pair!( |f, g| { - let res = (0..largest_even_below(f.len())) - .into_par_iter() - .step_by(2) - .with_min_len(64) - .map(|b| { - AdditiveArray([ - f[b] * g[b], - f[b + 1] * g[b + 1], - (f[b + 1] + f[b + 1] - f[b]) - * (g[b + 1] + g[b + 1] - g[b]), - ]) - }) - .sum::>(); + let res = (0..n_threads).into_par_iter().map(|thread_id| { + (0..largest_even_below(f.len())) + .skip(2*thread_id) + .step_by(2*n_threads) + .map(|b| { + AdditiveArray([ + f[b] * g[b], + f[b + 1] * g[b + 1], + (f[b + 1] + f[b + 1] - f[b]) + * (g[b + 1] + g[b + 1] - g[b]), + ]) + }).sum::>() + }).sum::>(); + let res = if f.len() == 1 { AdditiveArray::<_, 3>([f[0] * g[0]; 3]) } else { @@ -851,23 +853,26 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { ); op_mle_product_3!( |f1, f2, f3| { - let res = (0..largest_even_below(f1.len())) - .step_by(2) - .map(|b| { - // f = c x + d - let c1 = f1[b + 1] - f1[b]; - let c2 = f2[b + 1] - f2[b]; - let c3 = f3[b + 1] - f3[b]; - AdditiveArray([ - f1[b] * (f2[b] * f3[b]), - f1[b + 1] * (f2[b + 1] * f3[b + 1]), - (c1 + f1[b + 1]) - * ((c2 + f2[b + 1]) * (c3 + f3[b + 1])), - (c1 + c1 + f1[b + 1]) - * ((c2 + c2 + f2[b + 1]) * (c3 + c3 + f3[b + 1])), - ]) - }) - .sum::>(); + let res = (0..n_threads).into_par_iter().map(|thread_id| { + (0..largest_even_below(f1.len())) + .skip(2*thread_id) + .step_by(2*n_threads) + .map(|b| { + // f = c x + d + let c1 = f1[b + 1] - f1[b]; + let c2 = f2[b + 1] - f2[b]; + let c3 = f3[b + 1] - f3[b]; + AdditiveArray([ + f1[b] * (f2[b] * f3[b]), + f1[b + 1] * (f2[b + 1] * f3[b + 1]), + (c1 + f1[b + 1]) + * ((c2 + f2[b + 1]) * (c3 + f3[b + 1])), + (c1 + c1 + f1[b + 1]) + * ((c2 + c2 + f2[b + 1]) * (c3 + c3 + f3[b + 1])), + ]) + }).sum::>() + }).sum::>(); + let res = if f1.len() == 1 { AdditiveArray::<_, 4>([f1[0] * f2[0] * f3[0]; 4]) } else { @@ -905,9 +910,7 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { exit_span!(span); products_sum }, - ) - .reduce_with(|acc, item| acc + item) - .unwrap(); + ); exit_span!(span); end_timer!(start);