diff --git a/multilinear_extensions/src/lib.rs b/multilinear_extensions/src/lib.rs index 9f669e348..63d0f0449 100644 --- a/multilinear_extensions/src/lib.rs +++ b/multilinear_extensions/src/lib.rs @@ -1,4 +1,5 @@ #![deny(clippy::cargo)] +#![feature(sync_unsafe_cell)] pub mod mle; pub mod util; pub mod virtual_poly; diff --git a/multilinear_extensions/src/mle.rs b/multilinear_extensions/src/mle.rs index 8a182e645..f3732297d 100644 --- a/multilinear_extensions/src/mle.rs +++ b/multilinear_extensions/src/mle.rs @@ -1,12 +1,22 @@ -use std::{any::TypeId, borrow::Cow, mem, sync::Arc}; +use std::{ + any::TypeId, + borrow::Cow, + cell::SyncUnsafeCell, + mem::{self, MaybeUninit}, + sync::Arc, +}; -use crate::{op_mle, util::ceil_log2}; +use crate::{ + op_mle, + util::{ceil_log2, create_uninit_vec, max_usable_threads}, +}; use ark_std::{end_timer, rand::RngCore, start_timer}; use core::hash::Hash; use ff::Field; use ff_ext::ExtensionField; use rayon::iter::{ - IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, + IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, + IntoParallelRefMutIterator, ParallelIterator, }; use serde::{Deserialize, Serialize}; use std::fmt::Debug; @@ -611,6 +621,7 @@ impl MultilinearExtension for DenseMultilinearExtension /// Reduce the number of variables of `self` by fixing the /// `partial_point.len()` variables at `partial_point`. fn fix_variables_parallel(&self, partial_point: &[E]) -> Self { + let n_threads = max_usable_threads(); // TODO: return error. assert!( partial_point.len() <= self.num_vars(), @@ -626,12 +637,25 @@ impl MultilinearExtension for DenseMultilinearExtension *poly = op_mle!(self, |evaluations| { Cow::Owned(DenseMultilinearExtension::from_evaluations_ext_vec( self.num_vars() - 1, - evaluations - .par_iter() - .chunks(2) - .with_min_len(64) - .map(|buf| *point * (*buf[1] - *buf[0]) + *buf[0]) - .collect(), + unsafe { + let data = create_uninit_vec::(evaluations.len() / 2); + let vec_sync_unsafe = SyncUnsafeCell::new(data); + (0..n_threads).into_par_iter().for_each(|thread_id| { + let ptr = (*vec_sync_unsafe.get()).as_mut_ptr(); + (0..evaluations.len()) + .skip(2 * thread_id) + .step_by(2 * n_threads) + .for_each(|i| { + *ptr.add(i / 2) = MaybeUninit::new( + *point * (evaluations[i + 1] - evaluations[i]) + + evaluations[i], + ); + }); + }); + let maybe_uninit_vec: Vec> = + vec_sync_unsafe.into_inner(); + std::mem::transmute::>, Vec>(maybe_uninit_vec) + }, )) }); } @@ -653,20 +677,34 @@ impl MultilinearExtension for DenseMultilinearExtension self.num_vars() ); let nv = self.num_vars(); + let n_threads = max_usable_threads(); // evaluate single variable of partial point from left to right for (i, point) in partial_point.iter().enumerate() { let max_log2_size = nv - i; // override buf[b1, b2,..bt, 0] = (1-point) * buf[b1, b2,..bt, 0] + point * buf[b1, b2,..bt, 1] in parallel match &mut self.evaluations { - FieldType::Base(evaluations) => { - let evaluations_ext = evaluations - .par_iter() - .chunks(2) - .with_min_len(64) - .map(|buf| *point * (*buf[1] - *buf[0]) + *buf[0]) - .collect(); - let _ = mem::replace(&mut self.evaluations, FieldType::Ext(evaluations_ext)); - } + FieldType::Base(evaluations) => unsafe { + let data = create_uninit_vec::(evaluations.len() / 2); + let vec_sync_unsafe = SyncUnsafeCell::new(data); + (0..n_threads).into_par_iter().for_each(|thread_id| { + let ptr = (*vec_sync_unsafe.get()).as_mut_ptr(); + (0..evaluations.len()) + .skip(2 * thread_id) + .step_by(2 * n_threads) + .for_each(|i| { + *ptr.add(i / 2) = MaybeUninit::new( + *point * (evaluations[i + 1] - evaluations[i]) + evaluations[i], + ); + }); + }); + let maybe_uninit_vec: Vec> = vec_sync_unsafe.into_inner(); + let _ = mem::replace( + &mut self.evaluations, + FieldType::Ext(std::mem::transmute::>, Vec>( + maybe_uninit_vec, + )), + ); + }, FieldType::Ext(evaluations) => { evaluations .par_iter_mut()