Skip to content

Commit

Permalink
Parallelize calls to prover functions (#2176)
Browse files Browse the repository at this point in the history
Cherry-picked from #2174

With this PR, we run all prover functions in parallel when solving for
the witness in `VmProcessor`. Interestingly, this didn't require any
changes to the order in which things are done: We already ran the
functions independently and applied the combined updates. So, this is a
classic map-reduce.

I think this change always makes sense, but is especially useful for the
prover functions we have to set bus accumulator values. For example, in
our RISC-V machine, the main machine has ~30 bus interactions, with a
fairly expensive prover function for each.

When used on top of #2173 and #2175, this accelerates second-stage
witness generation for the main machine from ~10s to ~6s for the example
mentioned in #2173.
  • Loading branch information
georgwiese authored and leonardoalt committed Nov 29, 2024
1 parent 89fc866 commit 90f29fe
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 12 deletions.
31 changes: 25 additions & 6 deletions executor/src/witgen/processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use powdr_ast::analyzed::PolynomialType;
use powdr_ast::analyzed::{AlgebraicExpression as Expression, AlgebraicReference, PolyID};

use powdr_number::{DegreeType, FieldElement};
use rayon::iter::{ParallelBridge, ParallelIterator};

use crate::witgen::affine_expression::AlgebraicVariable;
use crate::witgen::data_structures::mutable_state::MutableState;
Expand Down Expand Up @@ -220,7 +221,7 @@ impl<'a, 'c, T: FieldElement, Q: QueryCallback<T>> Processor<'a, 'c, T, Q> {
}

pub fn process_queries(&mut self, row_index: usize) -> Result<bool, EvalError<T>> {
let mut query_processor = QueryProcessor::new(
let query_processor = QueryProcessor::new(
self.fixed_data,
self.mutable_state.query_callback(),
self.size,
Expand All @@ -238,15 +239,33 @@ impl<'a, 'c, T: FieldElement, Q: QueryCallback<T>> Processor<'a, 'c, T, Q> {
);
let mut updates = EvalValue::complete(vec![]);

for (i, fun) in self.parts.prover_functions.iter().enumerate() {
if !self.processed_prover_functions.has_run(row_index, i) {
let r = query_processor.process_prover_function(&row_pair, fun)?;
self.parts
.prover_functions
.iter()
.enumerate()
// Run all prover functions in parallel
.par_bridge()
.filter_map(|(i, fun)| {
if !self.processed_prover_functions.has_run(row_index, i) {
query_processor
.process_prover_function(&row_pair, fun)
.map(|result| Some((result, i)))
.transpose()
} else {
// Skip already processed functions
None
}
})
// Fail if any of the prover functions failed
.collect::<Result<Vec<_>, EvalError<T>>>()?
// Combine results
.into_iter()
.for_each(|(r, i)| {
if r.is_complete() {
updates.combine(r);
self.processed_prover_functions.mark_as_run(row_index, i);
}
}
}
});

for poly_id in &self.prover_query_witnesses {
if let Some(r) = query_processor.process_query(&row_pair, poly_id) {
Expand Down
12 changes: 6 additions & 6 deletions executor/src/witgen/query_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ impl<'a, 'b, T: FieldElement, QueryCallback: super::QueryCallback<T>>
}
}

pub fn process_prover_function<'c>(
&'c mut self,
rows: &'c RowPair<'c, 'a, T>,
pub fn process_prover_function(
&self,
rows: &RowPair<'_, 'a, T>,
fun: &'a Expression,
) -> EvalResult<'a, T> {
let arguments = vec![Arc::new(Value::Integer(BigInt::from(u64::from(
Expand Down Expand Up @@ -77,7 +77,7 @@ impl<'a, 'b, T: FieldElement, QueryCallback: super::QueryCallback<T>>
/// Panics if the column does not have a query attached.
/// @returns None if the value for that column is already known.
pub fn process_query(
&mut self,
&self,
rows: &RowPair<'_, 'a, T>,
poly_id: &PolyID,
) -> Option<EvalResult<'a, T>> {
Expand All @@ -91,7 +91,7 @@ impl<'a, 'b, T: FieldElement, QueryCallback: super::QueryCallback<T>>
}

fn process_witness_query(
&mut self,
&self,
query: &'a Expression,
poly: &'a AlgebraicReference,
rows: &RowPair<'_, 'a, T>,
Expand Down Expand Up @@ -129,7 +129,7 @@ impl<'a, 'b, T: FieldElement, QueryCallback: super::QueryCallback<T>>
}

fn interpolate_query(
&mut self,
&self,
query: &'a Expression,
rows: &RowPair<'_, 'a, T>,
) -> Result<String, EvalError> {
Expand Down

0 comments on commit 90f29fe

Please sign in to comment.