Skip to content

Commit

Permalink
optimize parallel version sumcheck
Browse files Browse the repository at this point in the history
  • Loading branch information
hero78119 committed Dec 12, 2024
1 parent 0cd9258 commit 163c329
Showing 1 changed file with 47 additions and 44 deletions.
91 changes: 47 additions & 44 deletions sumcheck/src/prover_v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ use multilinear_extensions::{
commutative_op_mle_pair,
mle::{DenseMultilinearExtension, MultilinearExtension},
op_mle, op_mle_product_3, op_mle3_range,
util::largest_even_below,
util::{largest_even_below, max_usable_threads},
virtual_poly_v2::VirtualPolynomialV2,
};
use rayon::{
Scope,
iter::{IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator},
iter::IntoParallelRefMutIterator,
prelude::{IntoParallelIterator, ParallelIterator},
};
use transcript::{Challenge, Transcript, TranscriptSyncronized};
Expand Down Expand Up @@ -722,6 +722,7 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> {
//
// eval g over r_m, and mutate g to g(r_1, ... r_m,, x_{m+1}... x_n)
let span = entered_span!("fix_variables");
let n_threads = max_usable_threads();
if self.round == 0 {
assert!(challenge.is_none(), "first round should be prover first.");
} else {
Expand Down Expand Up @@ -769,8 +770,8 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> {
let AdditiveVec(products_sum) = self
.poly
.products
.par_iter()
.fold_with(
.iter()
.fold(
AdditiveVec::new(self.poly.aux_info.max_degree + 1),
|mut products_sum, (coefficient, products)| {
let span = entered_span!("sum");
Expand All @@ -780,17 +781,17 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> {
let f = &self.poly.flattened_ml_extensions[products[0]];
op_mle! {
|f| {
let res = (0..largest_even_below(f.len()))
.into_par_iter()
.step_by(2)
.with_min_len(64)
.map(|b| {
let res = (0..n_threads).into_par_iter().map(|thread_id| {
(0..largest_even_below(f.len()))
.skip(2*thread_id)
.step_by(2*n_threads)
.map(|b| {
AdditiveArray([
f[b],
f[b + 1]
])
})
.sum::<AdditiveArray<_, 2>>();
}).sum::<AdditiveArray<_, 2>>()
}).sum::<AdditiveArray<_, 2>>();
let res = if f.len() == 1 {
AdditiveArray::<_, 2>([f[0]; 2])
} else {
Expand All @@ -814,19 +815,20 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> {
);
commutative_op_mle_pair!(
|f, g| {
let res = (0..largest_even_below(f.len()))
.into_par_iter()
.step_by(2)
.with_min_len(64)
.map(|b| {
AdditiveArray([
f[b] * g[b],
f[b + 1] * g[b + 1],
(f[b + 1] + f[b + 1] - f[b])
* (g[b + 1] + g[b + 1] - g[b]),
])
})
.sum::<AdditiveArray<_, 3>>();
let res = (0..n_threads).into_par_iter().map(|thread_id| {
(0..largest_even_below(f.len()))
.skip(2*thread_id)
.step_by(2*n_threads)
.map(|b| {
AdditiveArray([
f[b] * g[b],
f[b + 1] * g[b + 1],
(f[b + 1] + f[b + 1] - f[b])
* (g[b + 1] + g[b + 1] - g[b]),
])
}).sum::<AdditiveArray<_, 3>>()
}).sum::<AdditiveArray<_, 3>>();

let res = if f.len() == 1 {
AdditiveArray::<_, 3>([f[0] * g[0]; 3])
} else {
Expand All @@ -851,23 +853,26 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> {
);
op_mle_product_3!(
|f1, f2, f3| {
let res = (0..largest_even_below(f1.len()))
.step_by(2)
.map(|b| {
// f = c x + d
let c1 = f1[b + 1] - f1[b];
let c2 = f2[b + 1] - f2[b];
let c3 = f3[b + 1] - f3[b];
AdditiveArray([
f1[b] * (f2[b] * f3[b]),
f1[b + 1] * (f2[b + 1] * f3[b + 1]),
(c1 + f1[b + 1])
* ((c2 + f2[b + 1]) * (c3 + f3[b + 1])),
(c1 + c1 + f1[b + 1])
* ((c2 + c2 + f2[b + 1]) * (c3 + c3 + f3[b + 1])),
])
})
.sum::<AdditiveArray<_, 4>>();
let res = (0..n_threads).into_par_iter().map(|thread_id| {
(0..largest_even_below(f1.len()))
.skip(2*thread_id)
.step_by(2*n_threads)
.map(|b| {
// f = c x + d
let c1 = f1[b + 1] - f1[b];
let c2 = f2[b + 1] - f2[b];
let c3 = f3[b + 1] - f3[b];
AdditiveArray([
f1[b] * (f2[b] * f3[b]),
f1[b + 1] * (f2[b + 1] * f3[b + 1]),
(c1 + f1[b + 1])
* ((c2 + f2[b + 1]) * (c3 + f3[b + 1])),
(c1 + c1 + f1[b + 1])
* ((c2 + c2 + f2[b + 1]) * (c3 + c3 + f3[b + 1])),
])
}).sum::<AdditiveArray<_, 4>>()
}).sum::<AdditiveArray<_, 4>>();

let res = if f1.len() == 1 {
AdditiveArray::<_, 4>([f1[0] * f2[0] * f3[0]; 4])
} else {
Expand Down Expand Up @@ -905,9 +910,7 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> {
exit_span!(span);
products_sum
},
)
.reduce_with(|acc, item| acc + item)
.unwrap();
);
exit_span!(span);

end_timer!(start);
Expand Down

0 comments on commit 163c329

Please sign in to comment.