Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prover functions in queue #2461

Merged
merged 31 commits into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
01fda86
evaluate eagerly
chriseth Feb 6, 2025
144e066
oe
chriseth Feb 6, 2025
eb79b42
x
chriseth Feb 6, 2025
7f9d358
fix
chriseth Feb 6, 2025
f549ab0
x
chriseth Feb 6, 2025
f3c0c10
x
chriseth Feb 6, 2025
d1bbde6
x
chriseth Feb 6, 2025
47604d0
x
chriseth Feb 6, 2025
13ea71a
x
chriseth Feb 6, 2025
7183a49
x
chriseth Feb 6, 2025
dc08751
x
chriseth Feb 6, 2025
ebf4612
x
chriseth Feb 6, 2025
eb78a9e
clippy
chriseth Feb 6, 2025
33fc52e
move
chriseth Feb 6, 2025
4583672
fix
chriseth Feb 6, 2025
04e9a9a
Handle assignments in identity queue as well
chriseth Feb 6, 2025
df24e90
fix
chriseth Feb 6, 2025
b705223
clippy
chriseth Feb 6, 2025
5347756
Remove process assignments.
chriseth Feb 6, 2025
160cf10
fix tests.
chriseth Feb 7, 2025
e4d29ee
remove assignments.
chriseth Feb 7, 2025
b1a34d8
fix tests.
chriseth Feb 7, 2025
e82e18f
Merge remote-tracking branch 'origin/main' into handle-assignments-in…
chriseth Feb 7, 2025
4564eac
Sort assignments first.
chriseth Feb 7, 2025
7063105
Prover functions in queue.
chriseth Feb 7, 2025
8cce9c8
Refactor reference computation.
chriseth Feb 7, 2025
bca4d9d
Merge remote-tracking branch 'origin/main' into prover_functions_in_q…
chriseth Feb 7, 2025
afd3830
Merge remote-tracking branch 'origin/main' into move_assignments_to_p…
chriseth Feb 7, 2025
2bf52ad
Merge branch 'move_assignments_to_processor' into prover_functions_in…
chriseth Feb 7, 2025
c63dc65
Prevent clippy suggestion.
chriseth Feb 7, 2025
93d49e4
Merge remote-tracking branch 'origin/main' into prover_functions_in_q…
chriseth Feb 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
296 changes: 159 additions & 137 deletions executor/src/witgen/jit/identity_queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ use std::{

use itertools::Itertools;
use powdr_ast::{
analyzed::{AlgebraicExpression as Expression, AlgebraicReferenceThin, PolynomialType},
analyzed::{
AlgebraicExpression as Expression, AlgebraicReferenceThin, PolynomialIdentity,
PolynomialType,
},
parsed::visitor::{AllChildren, Children},
};
use powdr_number::FieldElement;
Expand All @@ -15,6 +18,7 @@ use crate::witgen::{
};

use super::{
prover_function_heuristics::ProverFunction,
variable::Variable,
witgen_inference::{Assignment, VariableOrValue},
};
Expand All @@ -33,13 +37,31 @@ impl<'a, T: FieldElement> IdentityQueue<'a, T> {
fixed_data: &'a FixedData<'a, T>,
identities: &[(&'a Identity<T>, i32)],
assignments: &[Assignment<'a, T>],
prover_functions: &[(ProverFunction<'a, T>, i32)],
) -> Self {
let queue: BTreeSet<_> = identities
.iter()
.map(|(id, row)| QueueItem::Identity(id, *row))
.chain(assignments.iter().map(|a| QueueItem::Assignment(a.clone())))
.chain(
prover_functions
.iter()
.map(|(p, row)| QueueItem::ProverFunction(p.clone(), *row)),
)
.collect();
let occurrences = compute_occurrences_map(fixed_data, &queue).into();
let mut references = ReferencesComputer::new(fixed_data);
let occurrences = Rc::new(
queue
.iter()
.flat_map(|item| {
references
.references(item)
.iter()
.map(|v| (v.clone(), item.clone()))
.collect_vec()
})
.into_group_map(),
);
Self { queue, occurrences }
}

Expand Down Expand Up @@ -72,6 +94,7 @@ impl<'a, T: FieldElement> IdentityQueue<'a, T> {
/// A unit of work stored in the identity queue.
///
/// The `i32` carried by `Identity` and `ProverFunction` is the row that
/// `IdentityQueue::new` pairs with the respective item.
pub enum QueueItem<'a, T: FieldElement> {
/// An identity together with its row.
Identity(&'a Identity<T>, i32),
/// An assignment; it has no separate row field in this enum.
Assignment(Assignment<'a, T>),
/// A prover function together with its row.
ProverFunction(ProverFunction<'a, T>, i32),
}

/// Sorts assignments first, then identities by row and then by ID, and finally prover functions by row and then by index.
Expand All @@ -82,8 +105,13 @@ impl<T: FieldElement> Ord for QueueItem<'_, T> {
(row1, id1.id()).cmp(&(row2, id2.id()))
}
(QueueItem::Assignment(a1), QueueItem::Assignment(a2)) => a1.cmp(a2),
(QueueItem::Assignment(_), QueueItem::Identity(_, _)) => std::cmp::Ordering::Less,
(QueueItem::Identity(_, _), QueueItem::Assignment(_)) => std::cmp::Ordering::Greater,
(QueueItem::ProverFunction(p1, row1), QueueItem::ProverFunction(p2, row2)) => {
(row1, p1.index).cmp(&(row2, p2.index))
}
(QueueItem::Assignment(..), _) => std::cmp::Ordering::Less,
(QueueItem::Identity(..), QueueItem::Assignment(..)) => std::cmp::Ordering::Greater,
(QueueItem::Identity(..), QueueItem::ProverFunction(..)) => std::cmp::Ordering::Less,
(QueueItem::ProverFunction(..), _) => std::cmp::Ordering::Greater,
}
}
}
Expand All @@ -102,151 +130,145 @@ impl<T: FieldElement> PartialEq for QueueItem<'_, T> {

impl<T: FieldElement> Eq for QueueItem<'_, T> {}

/// Computes a map from each variable to the queue items it occurs in.
fn compute_occurrences_map<'b, 'a: 'b, T: FieldElement>(
/// Utility to compute the variables that occur in a queue item.
/// Follows intermediate column references and employs caches.
struct ReferencesComputer<'a, T: FieldElement> {
fixed_data: &'a FixedData<'a, T>,
items: &BTreeSet<QueueItem<'a, T>>,
) -> HashMap<Variable, Vec<QueueItem<'a, T>>> {
let mut intermediate_cache = HashMap::new();

// Compute references only once per identity.
let mut references_per_identity = HashMap::new();
for id in items
.iter()
.filter_map(|item| match item {
QueueItem::Identity(id, _) => Some(id),
_ => None,
})
.unique_by(|id| id.id())
{
references_per_identity.insert(
id.id(),
references_in_identity(id, fixed_data, &mut intermediate_cache),
);
}
intermediate_cache: HashMap<AlgebraicReferenceThin, Vec<AlgebraicReferenceThin>>,
/// A cache to store algebraic references in a polynomial identity, so that it
/// can be re-used on all rows.
references_per_identity: HashMap<u64, Vec<AlgebraicReferenceThin>>,
}

items
.iter()
.flat_map(|item| {
let variables = match item {
QueueItem::Identity(id, row) => {
let mut variables = references_per_identity[&id.id()]
.iter()
.map(|r| {
let name = fixed_data.column_name(&r.poly_id).to_string();
Variable::from_reference(&r.with_name(name), *row)
})
.collect_vec();
if let Identity::BusSend(bus_send) = id {
variables.extend((0..bus_send.selected_payload.expressions.len()).map(
|index| {
impl<'a, T: FieldElement> ReferencesComputer<'a, T> {
/// Creates a references computer for `fixed_data` with both caches empty.
pub fn new(fixed_data: &'a FixedData<'a, T>) -> Self {
    let intermediate_cache = HashMap::default();
    let references_per_identity = HashMap::default();
    Self {
        fixed_data,
        intermediate_cache,
        references_per_identity,
    }
}
pub fn references(&mut self, item: &QueueItem<'a, T>) -> Vec<Variable> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this can return an iterator; the usage seems to immediately call `iter()`.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried to make this return an iterator, but I got errors related either to references to `&mut self` or to temporaries in lambda functions.

let vars: Box<dyn Iterator<Item = _>> = match item {
QueueItem::Identity(id, row) => match id {
Identity::Polynomial(poly_id) => Box::new(
self.references_in_polynomial_identity(poly_id)
.into_iter()
.map(|r| self.reference_to_variable(&r, *row)),
),
Identity::BusSend(bus_send) => Box::new(
self.variables_in_expression(&bus_send.selected_payload.selector, *row)
.into_iter()
.chain(
(0..bus_send.selected_payload.expressions.len()).map(|index| {
Variable::MachineCallParam(MachineCallVariable {
identity_id: id.id(),
row_offset: *row,
identity_id: bus_send.identity_id,
index,
row_offset: *row,
})
},
));
};
variables
}
QueueItem::Assignment(a) => {
variables_in_assignment(a, fixed_data, &mut intermediate_cache)
}
};
variables.into_iter().map(move |v| (v, item.clone()))
})
.into_group_map()
}

/// Returns all references to witness column in the identity.
fn references_in_identity<T: FieldElement>(
identity: &Identity<T>,
fixed_data: &FixedData<T>,
intermediate_cache: &mut HashMap<AlgebraicReferenceThin, Vec<AlgebraicReferenceThin>>,
) -> Vec<AlgebraicReferenceThin> {
let mut result = BTreeSet::new();

match identity {
Identity::BusSend(bus_send) => result.extend(references_in_expression(
&bus_send.selected_payload.selector,
fixed_data,
intermediate_cache,
)),
_ => {
for e in identity.children() {
result.extend(references_in_expression(e, fixed_data, intermediate_cache));
}),
),
),
Identity::Connect(..) => Box::new(std::iter::empty()),
},
QueueItem::Assignment(a) => {
let vars_in_rhs = match &a.rhs {
VariableOrValue::Variable(v) => Some(v.clone()),
VariableOrValue::Value(_) => None,
};
Box::new(
self.variables_in_expression(a.lhs, a.row_offset)
.into_iter()
.chain(vars_in_rhs),
)
}
}
QueueItem::ProverFunction(p, row) => Box::new(
p.condition
.iter()
.flat_map(|c| self.variables_in_expression(c, *row))
.chain(
p.input_columns
.iter()
.map(|r| Variable::from_reference(r, *row)),
),
),
};
vars.unique().collect_vec()
}

result.into_iter().collect()
}
fn variables_in_expression(&mut self, expression: &Expression<T>, row: i32) -> Vec<Variable> {
self.references_in_expression(expression)
.iter()
.map(|r| {
let name = self.fixed_data.column_name(&r.poly_id).to_string();
Variable::from_reference(&r.with_name(name), row)
})
.collect()
}

/// Recursively resolves references in intermediate column definitions.
fn references_in_intermediate<T: FieldElement>(
fixed_data: &FixedData<T>,
intermediate: &AlgebraicReferenceThin,
intermediate_cache: &mut HashMap<AlgebraicReferenceThin, Vec<AlgebraicReferenceThin>>,
) -> Vec<AlgebraicReferenceThin> {
if let Some(references) = intermediate_cache.get(intermediate) {
return references.clone();
/// Converts a thin algebraic reference into a `Variable` at the given row.
/// The column name is looked up in the fixed data via the reference's polynomial ID.
fn reference_to_variable(&self, reference: &AlgebraicReferenceThin, row: i32) -> Variable {
    let column_name = self.fixed_data.column_name(&reference.poly_id).to_string();
    let named_reference = reference.with_name(column_name);
    Variable::from_reference(&named_reference, row)
}
let references = references_in_expression(
&fixed_data.intermediate_definitions[intermediate],
fixed_data,
intermediate_cache,
)
.collect_vec();
intermediate_cache.insert(intermediate.clone(), references.clone());
references
}

/// Returns all references to witness or intermediate column in the expression.
fn references_in_expression<'a, T: FieldElement>(
expression: &'a Expression<T>,
fixed_data: &'a FixedData<T>,
intermediate_cache: &'a mut HashMap<AlgebraicReferenceThin, Vec<AlgebraicReferenceThin>>,
) -> impl Iterator<Item = AlgebraicReferenceThin> + 'a {
expression
.all_children()
.flat_map(
move |e| -> Box<dyn Iterator<Item = AlgebraicReferenceThin> + 'a> {
match e {
Expression::Reference(r) => match r.poly_id.ptype {
PolynomialType::Constant => Box::new(std::iter::empty()),
PolynomialType::Committed => Box::new(std::iter::once(r.into())),
PolynomialType::Intermediate => Box::new(
references_in_intermediate(fixed_data, &r.into(), intermediate_cache)
.into_iter(),
),
},
Expression::PublicReference(_) | Expression::Challenge(_) => {
// TODO we need to introduce a variable type for those.
Box::new(std::iter::empty())
fn references_in_polynomial_identity(
&mut self,
identity: &PolynomialIdentity<T>,
) -> Vec<AlgebraicReferenceThin> {
// Clippy suggests to use `entry()...or_insert_with()`,
// but the code does not work, since we need `&mut self` in
// self.references_in_expression.
#[allow(clippy::map_entry)]
if !self.references_per_identity.contains_key(&identity.id) {
let mut result = BTreeSet::new();
for e in identity.children() {
result.extend(self.references_in_expression(e));
}
self.references_per_identity
.insert(identity.id, result.into_iter().collect_vec());
}
self.references_per_identity[&identity.id].clone()
}

/// Returns all references to witness columns in the expression, including indirect
/// references through intermediate columns.
fn references_in_expression(
&mut self,
expression: &Expression<T>,
) -> Vec<AlgebraicReferenceThin> {
let mut references = BTreeSet::new();
for e in expression.all_children() {
match e {
Expression::Reference(r) => match r.poly_id.ptype {
PolynomialType::Constant => {}
PolynomialType::Committed => {
references.insert(r.into());
}
_ => Box::new(std::iter::empty()),
PolynomialType::Intermediate => references
.extend(self.references_in_intermediate(&r.into()).iter().cloned()),
},
Expression::PublicReference(_) | Expression::Challenge(_) => {
// TODO we need to introduce a variable type for those.
}
},
)
.unique()
}
Expression::Number(_)
| Expression::BinaryOperation(..)
| Expression::UnaryOperation(..) => {}
}
}
references.into_iter().collect()
}

/// Returns a vector of all variables that occur in the assignment.
fn variables_in_assignment<'a, T: FieldElement>(
assignment: &Assignment<'a, T>,
fixed_data: &'a FixedData<'a, T>,
intermediate_cache: &mut HashMap<AlgebraicReferenceThin, Vec<AlgebraicReferenceThin>>,
) -> Vec<Variable> {
let rhs_var = match &assignment.rhs {
VariableOrValue::Variable(v) => Some(v.clone()),
VariableOrValue::Value(_) => None,
};
references_in_expression(assignment.lhs, fixed_data, intermediate_cache)
.map(|r| {
let name = fixed_data.column_name(&r.poly_id).to_string();
Variable::from_reference(&r.with_name(name), assignment.row_offset)
})
.chain(rhs_var)
.collect()
/// Returns the references found by `references_in_expression` in the definition
/// of the given intermediate column.
///
/// The result is cached per intermediate column in `intermediate_cache`, so the
/// definition is only traversed on the first request.
///
/// NOTE(review): the definition is fetched by indexing
/// `fixed_data.intermediate_definitions`; presumably this panics for an unknown
/// intermediate column - confirm against the container type.
fn references_in_intermediate(
    &mut self,
    intermediate: &AlgebraicReferenceThin,
) -> &Vec<AlgebraicReferenceThin> {
    if !self.intermediate_cache.contains_key(intermediate) {
        let definition = &self.fixed_data.intermediate_definitions[intermediate];
        let references = self.references_in_expression(definition);
        // `references` is not needed afterwards, so it is moved into the cache
        // rather than cloned.
        self.intermediate_cache
            .insert(intermediate.clone(), references);
    }
    &self.intermediate_cache[intermediate]
}
}
Loading
Loading