Skip to content

Commit

Permalink
fix variable with unsafe api
Browse files Browse the repository at this point in the history
  • Loading branch information
hero78119 committed Dec 12, 2024
1 parent 163c329 commit 31907d4
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 18 deletions.
1 change: 1 addition & 0 deletions multilinear_extensions/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#![deny(clippy::cargo)]
#![feature(sync_unsafe_cell)]
pub mod mle;
pub mod util;
pub mod virtual_poly;
Expand Down
74 changes: 56 additions & 18 deletions multilinear_extensions/src/mle.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
use std::{any::TypeId, borrow::Cow, mem, sync::Arc};
use std::{
any::TypeId,
borrow::Cow,
cell::SyncUnsafeCell,
mem::{self, MaybeUninit},
sync::Arc,
};

use crate::{op_mle, util::ceil_log2};
use crate::{
op_mle,
util::{ceil_log2, create_uninit_vec, max_usable_threads},
};
use ark_std::{end_timer, rand::RngCore, start_timer};
use core::hash::Hash;
use ff::Field;
use ff_ext::ExtensionField;
use rayon::iter::{
IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator,
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator,
IntoParallelRefMutIterator, ParallelIterator,
};
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
Expand Down Expand Up @@ -611,6 +621,7 @@ impl<E: ExtensionField> MultilinearExtension<E> for DenseMultilinearExtension<E>
/// Reduce the number of variables of `self` by fixing the
/// `partial_point.len()` variables at `partial_point`.
fn fix_variables_parallel(&self, partial_point: &[E]) -> Self {
let n_threads = max_usable_threads();
// TODO: return error.
assert!(
partial_point.len() <= self.num_vars(),
Expand All @@ -626,12 +637,25 @@ impl<E: ExtensionField> MultilinearExtension<E> for DenseMultilinearExtension<E>
*poly = op_mle!(self, |evaluations| {
Cow::Owned(DenseMultilinearExtension::from_evaluations_ext_vec(
self.num_vars() - 1,
evaluations
.par_iter()
.chunks(2)
.with_min_len(64)
.map(|buf| *point * (*buf[1] - *buf[0]) + *buf[0])
.collect(),
unsafe {
let data = create_uninit_vec::<E>(evaluations.len() / 2);
let vec_sync_unsafe = SyncUnsafeCell::new(data);
(0..n_threads).into_par_iter().for_each(|thread_id| {
let ptr = (*vec_sync_unsafe.get()).as_mut_ptr();
(0..evaluations.len())
.skip(2 * thread_id)
.step_by(2 * n_threads)
.for_each(|i| {
*ptr.add(i / 2) = MaybeUninit::new(
*point * (evaluations[i + 1] - evaluations[i])
+ evaluations[i],
);
});
});
let maybe_uninit_vec: Vec<MaybeUninit<E>> =
vec_sync_unsafe.into_inner();
std::mem::transmute::<Vec<MaybeUninit<E>>, Vec<E>>(maybe_uninit_vec)
},
))
});
}
Expand All @@ -653,20 +677,34 @@ impl<E: ExtensionField> MultilinearExtension<E> for DenseMultilinearExtension<E>
self.num_vars()
);
let nv = self.num_vars();
let n_threads = max_usable_threads();
// evaluate single variable of partial point from left to right
for (i, point) in partial_point.iter().enumerate() {
let max_log2_size = nv - i;
// override buf[b1, b2,..bt, 0] = (1-point) * buf[b1, b2,..bt, 0] + point * buf[b1, b2,..bt, 1] in parallel
match &mut self.evaluations {
FieldType::Base(evaluations) => {
let evaluations_ext = evaluations
.par_iter()
.chunks(2)
.with_min_len(64)
.map(|buf| *point * (*buf[1] - *buf[0]) + *buf[0])
.collect();
let _ = mem::replace(&mut self.evaluations, FieldType::Ext(evaluations_ext));
}
FieldType::Base(evaluations) => unsafe {
let data = create_uninit_vec::<E>(evaluations.len() / 2);
let vec_sync_unsafe = SyncUnsafeCell::new(data);
(0..n_threads).into_par_iter().for_each(|thread_id| {
let ptr = (*vec_sync_unsafe.get()).as_mut_ptr();
(0..evaluations.len())
.skip(2 * thread_id)
.step_by(2 * n_threads)
.for_each(|i| {
*ptr.add(i / 2) = MaybeUninit::new(
*point * (evaluations[i + 1] - evaluations[i]) + evaluations[i],
);
});
});
let maybe_uninit_vec: Vec<MaybeUninit<E>> = vec_sync_unsafe.into_inner();
let _ = mem::replace(
&mut self.evaluations,
FieldType::Ext(std::mem::transmute::<Vec<MaybeUninit<E>>, Vec<E>>(
maybe_uninit_vec,
)),
);
},
FieldType::Ext(evaluations) => {
evaluations
.par_iter_mut()
Expand Down

0 comments on commit 31907d4

Please sign in to comment.