From 402fa0b66ee9ae99292407510bbebcc2115d10fd Mon Sep 17 00:00:00 2001 From: hmuro andrej Date: Thu, 9 Feb 2023 00:20:19 +0300 Subject: [PATCH 1/4] feat: add GCE codegen --- codegen/gce/Cargo.toml | 17 ++ codegen/gce/README.md | 2 + codegen/gce/src/error.rs | 51 +++++ codegen/gce/src/expressions.rs | 359 +++++++++++++++++++++++++++++++++ codegen/gce/src/lib.rs | 248 +++++++++++++++++++++++ codegen/gce/src/utils.rs | 86 ++++++++ ir/src/constraints/graph.rs | 10 + ir/src/lib.rs | 8 +- ir/src/symbol_table.rs | 5 +- 9 files changed, 784 insertions(+), 2 deletions(-) create mode 100644 codegen/gce/Cargo.toml create mode 100644 codegen/gce/README.md create mode 100644 codegen/gce/src/error.rs create mode 100644 codegen/gce/src/expressions.rs create mode 100644 codegen/gce/src/lib.rs create mode 100644 codegen/gce/src/utils.rs diff --git a/codegen/gce/Cargo.toml b/codegen/gce/Cargo.toml new file mode 100644 index 00000000..8ba246e3 --- /dev/null +++ b/codegen/gce/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "air-codegen-gce" +version = "0.1.0" +description="Code generation for the generic constraint evaluation format." +authors = ["miden contributors"] +readme="README.md" +license = "MIT" +repository = "https://github.com/0xPolygonMiden/air-script" +categories = ["compilers", "cryptography"] +keywords = ["air", "stark", "winterfell", "zero-knowledge", "zkp"] +edition = "2021" +rust-version = "1.65" + +[dependencies] +air-script-core = {package = "air-script-core", path="../../air-script-core", version="0.1.0" } +ir = {package = "air-ir", path="../../ir", version="0.1.0" } +codegen = "0.2.0" \ No newline at end of file diff --git a/codegen/gce/README.md b/codegen/gce/README.md new file mode 100644 index 00000000..26d5766f --- /dev/null +++ b/codegen/gce/README.md @@ -0,0 +1,2 @@ +# Generic Constraint Evaluation Generator + diff --git a/codegen/gce/src/error.rs b/codegen/gce/src/error.rs new file mode 100644 index 00000000..b14da5c3 --- /dev/null +++ b/codegen/gce/src/error.rs @@ -0,0 +1,51 @@ +#[derive(Debug)] +pub enum ConstraintEvaluationError { + InvalidTraceSegment(String), + InvalidOperation(String), + IdentifierNotFound(String), + ConstantNotFound(String), + PublicInputNotFound(String), + OperationNotFound(String), + InvalidConstantType(String), +} + +impl ConstraintEvaluationError { + pub fn invalid_trace_segment(segment: u8) -> Self { + ConstraintEvaluationError::InvalidTraceSegment(format!( + "Trace segment {} is invalid", + segment + )) + } + + pub fn identifier_not_found(name: &str) -> Self { + ConstraintEvaluationError::IdentifierNotFound(format!( + "Identifier {} not found in JSON arrays", + name + )) + } + + pub fn constant_not_found(name: &str) -> Self { + ConstraintEvaluationError::ConstantNotFound(format!("Constant \"{}\" not found", name)) + } + + pub fn public_input_not_found(name: &str) -> Self { + ConstraintEvaluationError::PublicInputNotFound(format!( + "Public Input \"{}\" not found", + name + )) + } + + pub fn invalid_constant_type(name: &str, constant_type: &str) -> Self { + ConstraintEvaluationError::InvalidConstantType(format!( + "Invalid type of constant \"{}\". 
{} exprected.", + name, constant_type + )) + } + + pub fn operation_not_found(index: usize) -> Self { + ConstraintEvaluationError::OperationNotFound(format!( + "Operation with index {} does not match the expression in the expressions JSON array", + index + )) + } +} diff --git a/codegen/gce/src/expressions.rs b/codegen/gce/src/expressions.rs new file mode 100644 index 00000000..7d569731 --- /dev/null +++ b/codegen/gce/src/expressions.rs @@ -0,0 +1,359 @@ +use super::error::ConstraintEvaluationError; +use super::{ + utils::{ + get_constant_index_by_matrix_access, get_constant_index_by_name, + get_constant_index_by_value, get_constant_index_by_vector_access, get_random_value_index, + }, + ExpressionJson, ExpressionOperation, NodeReference, NodeType, +}; +use ir::{ + constraints::{ConstantValue, Operation}, + AirIR, NodeIndex, +}; +use std::collections::BTreeMap; + +const MAIN_TRACE_SEGMENT_INDEX: u8 = 0; + +pub struct ExpressionsHandler<'a> { + ir: &'a AirIR, + constants: &'a [u64], + // maps indexes in Node vector in AlgebraicGraph and in `expressions` JSON array + expressions_map: &'a mut BTreeMap, +} + +impl<'a> ExpressionsHandler<'a> { + pub fn new( + ir: &'a AirIR, + constants: &'a [u64], + expressions_map: &'a mut BTreeMap, + ) -> Self { + ExpressionsHandler { + ir, + constants, + expressions_map, + } + } + + /// Parses expressions in transition graph's Node vector, creates [Expression] instances and pushes + /// them to the `expressions` vector. + pub fn get_expressions(&mut self) -> Result, ConstraintEvaluationError> { + // TODO: currently we can't create a node reference to the last row (which is required for + // main.last and aux.last boundary constraints). Working in assumption that first reference to + // the column is .first constraint and second is .last constraint (in the boundary section, not + // entire array) + let mut expressions = Vec::new(); + + for (index, node) in self.ir.constraint_graph().nodes().iter().enumerate() { + match node.op() { + Operation::Add(l, r) => { + expressions.push(self.handle_transition_expression( + ExpressionOperation::Add, + *l, + *r, + )?); + // create mapping (index in node graph: index in expressions vector) + self.expressions_map.insert(index, expressions.len() - 1); + } + Operation::Sub(l, r) => { + expressions.push(self.handle_transition_expression( + ExpressionOperation::Sub, + *l, + *r, + )?); + self.expressions_map.insert(index, expressions.len() - 1); + } + Operation::Mul(l, r) => { + expressions.push(self.handle_transition_expression( + ExpressionOperation::Mul, + *l, + *r, + )?); + self.expressions_map.insert(index, expressions.len() - 1); + } + Operation::Exp(i, degree) => { + match degree { + 0 => { + // I decided that node^0 could be emulated using the product of 1*1, but perhaps there are better ways + let index_of_1 = get_constant_index_by_value(1, self.constants)?; + let const_1_node = NodeReference { + node_type: NodeType::Const, + index: index_of_1, + }; + expressions.push(ExpressionJson { + op: ExpressionOperation::Mul, + lhs: const_1_node.clone(), + rhs: const_1_node, + }); + } + 1 => { + let lhs = self.handle_node_reference(*i)?; + let degree_index = get_constant_index_by_value(1, self.constants)?; + let rhs = NodeReference { + node_type: NodeType::Const, + index: degree_index, + }; + expressions.push(ExpressionJson { + op: ExpressionOperation::Mul, + lhs, + rhs, + }); + } + _ => self.handle_exponentiation(&mut expressions, *i, *degree)?, + } + self.expressions_map.insert(index, expressions.len() - 1); + } + _ => {} + } + 
} + Ok(expressions) + } + + /// Fills the `outputs` vector with indexes from `expressions` vector according to the `expressions_map`. + pub fn get_outputs( + &self, + expressions: &mut Vec, + ) -> Result, ConstraintEvaluationError> { + let mut outputs = Vec::new(); + + for i in 0..self.ir.segment_widths().len() { + for root in self.ir.boundary_constraints(i as u8) { + let index = self + .expressions_map + .get(&root.node_index().index()) + .ok_or_else(|| { + ConstraintEvaluationError::operation_not_found(root.node_index().index()) + })?; + if outputs.contains(index) { + expressions.push(expressions[*index].clone()); + outputs.push(expressions.len() - 1); + } else { + outputs.push(*index); + } + } + + for root in self.ir.validity_constraints(i as u8) { + let index = self + .expressions_map + .get(&root.node_index().index()) + .ok_or_else(|| { + ConstraintEvaluationError::operation_not_found(root.node_index().index()) + })?; + outputs.push(*index); + } + + for root in self.ir.transition_constraints(i as u8) { + let index = self + .expressions_map + .get(&root.node_index().index()) + .ok_or_else(|| { + ConstraintEvaluationError::operation_not_found(root.node_index().index()) + })?; + outputs.push(*index); + } + } + Ok(outputs) + } + + // --- HELPERS -------------------------------------------------------------------------------- + + /// Parses expression in transition graph Node vector and returns related [Expression] instance. + fn handle_transition_expression( + &self, + op: ExpressionOperation, + l: NodeIndex, + r: NodeIndex, + ) -> Result { + let lhs = self.handle_node_reference(l)?; + let rhs = self.handle_node_reference(r)?; + Ok(ExpressionJson { op, lhs, rhs }) + } + + /// Parses expression in transition graph Node vector by [NodeIndex] and returns related + /// [NodeReference] instance. + fn handle_node_reference( + &self, + i: NodeIndex, + ) -> Result { + use Operation::*; + match self.ir.constraint_graph().node(&i).op() { + Add(_, _) | Sub(_, _) | Mul(_, _) | Exp(_, _) => { + let index = self + .expressions_map + .get(&i.index()) + .ok_or_else(|| ConstraintEvaluationError::operation_not_found(i.index()))?; + Ok(NodeReference { + node_type: NodeType::Expr, + index: *index, + }) + } + Constant(constant_value) => { + match constant_value { + ConstantValue::Inline(v) => { + let index = get_constant_index_by_value(*v, self.constants)?; + Ok(NodeReference { + node_type: NodeType::Const, + index, + }) + } + ConstantValue::Scalar(name) => { + let index = get_constant_index_by_name(self.ir, name, self.constants)?; + Ok(NodeReference { + node_type: NodeType::Const, + index, + }) + } + ConstantValue::Vector(vector_access) => { + // why Constant.name() returns Identifier and VectorAccess.name() works like + // VectorAccess.name.name() and returns &str? 
(same with MatrixAccess) + let index = get_constant_index_by_vector_access( + self.ir, + vector_access, + self.constants, + )?; + Ok(NodeReference { + node_type: NodeType::Const, + index, + }) + } + ConstantValue::Matrix(matrix_access) => { + let index = get_constant_index_by_matrix_access( + self.ir, + matrix_access, + self.constants, + )?; + Ok(NodeReference { + node_type: NodeType::Const, + index, + }) + } + } + } + TraceElement(trace_access) => { + // Working in assumption that segment 0 is main columns, and others are aux columns + match trace_access.trace_segment() { + MAIN_TRACE_SEGMENT_INDEX => { + // TODO: handle other offsets (not only 1) + if trace_access.row_offset() == 0 { + Ok(NodeReference { + node_type: NodeType::Pol, + index: trace_access.col_idx(), + }) + } else { + Ok(NodeReference { + node_type: NodeType::PolNext, + index: trace_access.col_idx(), + }) + } + } + i if i < self.ir.segment_widths().len() as u8 => { + let col_index = self.ir.segment_widths()[0..i as usize].iter().sum::() + as usize + + trace_access.col_idx(); + if trace_access.row_offset() == 0 { + Ok(NodeReference { + node_type: NodeType::Pol, + index: col_index, + }) + } else { + Ok(NodeReference { + node_type: NodeType::PolNext, + index: col_index, + }) + } + } + _ => Err(ConstraintEvaluationError::invalid_trace_segment( + trace_access.trace_segment(), + )), + } + } + RandomValue(rand_index) => { + let index = get_random_value_index(self.ir, *rand_index); + Ok(NodeReference { + node_type: NodeType::Var, + index, + }) + } + + PeriodicColumn(_column, _length) => todo!(), + + // Currently it can only be `Neg` + _ => Err(ConstraintEvaluationError::InvalidOperation( + "Invalid transition constraint operation".to_string(), + )), + } + } + + /// Replaces the exponentiation operation with multiplication operations, adding them to the + /// expressions vector. 
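+    ///
+    /// The lowering follows a square-and-multiply scheme: it first pushes `node * node`, keeps
+    /// squaring the last expression as long as the resulting power stays within `degree`, and
+    /// then multiplies the remaining powers of two (and, for a remainder of one, the base node
+    /// itself) back in, reusing the squares that were already pushed to `expressions`.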
+ fn handle_exponentiation( + &self, + expressions: &mut Vec, + i: NodeIndex, + degree: usize, + ) -> Result<(), ConstraintEvaluationError> { + // base node that we want to raise to a degree + let base_node = self.handle_node_reference(i)?; + // push node^2 expression + expressions.push(ExpressionJson { + op: ExpressionOperation::Mul, + lhs: base_node.clone(), + rhs: base_node.clone(), + }); + let square_node_index = expressions.len() - 1; + + // square the previous expression while there is such an opportunity + let mut cur_degree_of_2 = 1; // currently we have node^(2^cur_degree_of_2) = node^(2^1) = node^2 + while 2_usize.pow(cur_degree_of_2) <= degree / 2 { + // the last node that we want to square + let last_node = NodeReference { + node_type: NodeType::Expr, + index: expressions.len() - 1, + }; + expressions.push(ExpressionJson { + op: ExpressionOperation::Mul, + lhs: last_node.clone(), + rhs: last_node, + }); + cur_degree_of_2 += 1; + } + + // add the largest available powers of two to the current degree + let mut cur_max_degree = 2_usize.pow(cur_degree_of_2); // currently we have node^(2^cur_max_degree) + while cur_max_degree != degree { + let diff = degree - cur_max_degree; + if diff == 1 { + // if we need to add first degree (base node) + let last_node = NodeReference { + node_type: NodeType::Expr, + index: expressions.len() - 1, + }; + expressions.push(ExpressionJson { + op: ExpressionOperation::Mul, + lhs: last_node, + rhs: base_node, + }); + break; + } + if 2_usize.pow(cur_degree_of_2 - 1) <= diff { + let last_node = NodeReference { + node_type: NodeType::Expr, + index: expressions.len() - 1, + }; + let fitting_degree_of_2_node = NodeReference { + node_type: NodeType::Expr, + // cur_degree_of_2 shows how many indexes we need to add to reach the largest fitting degree of 2 + index: square_node_index + cur_degree_of_2 as usize - 2, + }; + expressions.push(ExpressionJson { + op: ExpressionOperation::Mul, + lhs: last_node, + rhs: fitting_degree_of_2_node, + }); + cur_max_degree += 2_usize.pow(cur_degree_of_2 - 1); + } + cur_degree_of_2 -= 1; + } + + Ok(()) + } +} diff --git a/codegen/gce/src/lib.rs b/codegen/gce/src/lib.rs new file mode 100644 index 00000000..45959b78 --- /dev/null +++ b/codegen/gce/src/lib.rs @@ -0,0 +1,248 @@ +use ir::{ + constraints::{ConstantValue, Operation}, + AirIR, +}; + +pub use air_script_core::{ + Constant, ConstantType, Expression, Identifier, IndexedTraceAccess, MatrixAccess, + NamedTraceAccess, TraceSegment, Variable, VariableType, VectorAccess, +}; +use std::fmt::Display; + +mod error; +use error::ConstraintEvaluationError; + +mod utils; + +mod expressions; +use expressions::ExpressionsHandler; + +use std::collections::BTreeMap; +use std::fs::File; +use std::io::Write; + +/// Holds data for JSON generation +#[derive(Default, Debug)] +pub struct CodeGenerator { + num_polys: u16, + num_variables: usize, + constants: Vec, + expressions: Vec, + outputs: Vec, +} + +impl CodeGenerator { + pub fn new(ir: &AirIR, extension_degree: u8) -> Result { + // maps indexes in Node vector in AlgebraicGraph and in `expressions` JSON array + let mut expressions_map = BTreeMap::new(); + + let num_polys = set_num_polys(ir, extension_degree); + let num_variables = set_num_variables(ir); + let constants = set_constants(ir); + + let mut expressions_handler = ExpressionsHandler::new(ir, &constants, &mut expressions_map); + + let mut expressions = expressions_handler.get_expressions()?; + // vector of `expressions` indexes + let outputs = expressions_handler.get_outputs(&mut 
expressions)?; + + Ok(CodeGenerator { + num_polys, + num_variables, + constants, + expressions, + outputs, + }) + } + + /// Generates constraint evaluation JSON file + pub fn generate(&self, path: &str) -> std::io::Result<()> { + let mut file = File::create(path)?; + file.write_all("{\n".as_bytes())?; + file.write_all(format!("\t\"num_polys\": {},\n", self.num_polys).as_bytes())?; + file.write_all(format!("\t\"num_variables\": {},\n", self.num_variables).as_bytes())?; + file.write_all(format!("\t\"constants\": {:?},\n", self.constants).as_bytes())?; + file.write_all(format!("\t\"expressions\": [\n\t\t{}", self.expressions[0]).as_bytes())?; + for expr in self.expressions.iter().skip(1) { + file.write_all(format!(",\n\t\t{}", expr).as_bytes())?; + } + file.write_all("\n\t],\n".as_bytes())?; + file.write_all(format!("\t\"outputs\": {:?}\n", self.outputs).as_bytes())?; + + file.write_all("}\n".as_bytes())?; + Ok(()) + } +} + +// HELPER FUNCTIONS +// ================================================================================================ + +/// Returns total number of trace columns according to provided extension degree. +/// The result is calculated as `number of main columns + (number of aux columns) * extension +/// degree`. +fn set_num_polys(ir: &AirIR, extension_degree: u8) -> u16 { + // TODO: Should all aux columns be extended to be quadratic or cubic? + let num_polys_vec = ir.segment_widths(); + num_polys_vec + .iter() + .skip(1) + .fold(num_polys_vec[0], |acc, &x| { + acc + x * extension_degree as u16 + }) +} + +/// Returns total number of public inputs and random values. +fn set_num_variables(ir: &AirIR) -> usize { + let mut num_variables = 0; + + // public inputs + for input in ir.public_inputs() { + num_variables += input.1; + } + + num_variables + ir.num_random_values() as usize +} + +/// Returns a vector of all unique constants: named ones defined in `constants` section and inline +/// ones used in constraints calculation. Every value in vector or matrix considered as new +/// constant. 
+/// +/// # Examples +/// +/// Fragment of AIR script: +/// +/// ```airscript +/// const A = 1 +/// const B = [0, 1] +/// const C = [[1, 2], [2, 0]] +/// +/// boundary_constraints: +/// enf a.first = 1 +/// enf a.last = 5 +/// ``` +/// +/// Result vector: `[1, 0, 2, 5]` +fn set_constants(ir: &AirIR) -> Vec { + //named constants + let mut constants = Vec::new(); + for constant in ir.constants() { + match constant.value() { + ConstantType::Scalar(value) => { + if !constants.contains(value) { + constants.push(*value); + } + } + ConstantType::Vector(values) => { + for elem in values { + if !constants.contains(elem) { + constants.push(*elem); + } + } + } + ConstantType::Matrix(values) => { + for elem in values.iter().flatten() { + if !constants.contains(elem) { + constants.push(*elem); + } + } + } + } + } + + // inline constants + for node in ir.constraint_graph().nodes() { + match node.op() { + Operation::Constant(ConstantValue::Inline(value)) => { + if !constants.contains(value) { + constants.push(*value); + } + } + Operation::Exp(_, degree) => { + if *degree == 0 { + if !constants.contains(&1) { + constants.push(1); // constant needed for optimization, since node^0 is Const(1) + } + } else if !constants.contains(&(*degree as u64)) { + constants.push(*degree as u64) + } + } + _ => {} + } + } + + constants +} + +/// Stroes node type required in [NodeReference] struct +#[derive(Debug, Clone)] +pub enum NodeType { + Pol, + PolNext, + Var, + Const, + Expr, +} + +impl Display for NodeType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Pol => write!(f, "POL"), + Self::PolNext => write!(f, "POL_NEXT"), + Self::Var => write!(f, "VAR"), + Self::Const => write!(f, "CONST"), + Self::Expr => write!(f, "EXPR"), + } + } +} + +#[derive(Clone, Debug)] +pub enum ExpressionOperation { + Add, + Sub, + Mul, +} + +impl Display for ExpressionOperation { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Add => write!(f, "ADD"), + Self::Sub => write!(f, "SUB"), + Self::Mul => write!(f, "MUL"), + } + } +} + +/// Stores data used in JSON generation +#[derive(Debug, Clone)] +pub struct NodeReference { + pub node_type: NodeType, + pub index: usize, +} + +impl Display for NodeReference { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{{\"type\": \"{}\", \"index\": {}}}", + self.node_type, self.index + ) + } +} + +/// Stores data used in JSON generation +#[derive(Clone, Debug)] +pub struct ExpressionJson { + pub op: ExpressionOperation, + pub lhs: NodeReference, + pub rhs: NodeReference, +} + +impl Display for ExpressionJson { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{{\"op\": \"{}\", \"lhs\": {}, \"rhs\": {}}}", + self.op, self.lhs, self.rhs + ) + } +} diff --git a/codegen/gce/src/utils.rs b/codegen/gce/src/utils.rs new file mode 100644 index 00000000..403dde1b --- /dev/null +++ b/codegen/gce/src/utils.rs @@ -0,0 +1,86 @@ +use super::error::ConstraintEvaluationError; +pub use air_script_core::{ + Constant, ConstantType, Expression, Identifier, IndexedTraceAccess, MatrixAccess, + NamedTraceAccess, TraceSegment, Variable, VariableType, VectorAccess, +}; +use ir::AirIR; + +/// Returns index of the constant found in the `constants` array by its value +pub fn get_constant_index_by_value( + v: u64, + constants: &[u64], +) -> Result { + constants + .iter() + .position(|&x| x == v) + .ok_or_else(|| 
ConstraintEvaluationError::constant_not_found(&v.to_string())) +} + +/// Returns index of the constant found in the `constants` array by its `name` +pub fn get_constant_index_by_name( + ir: &AirIR, + name: &String, + constants: &[u64], +) -> Result { + let constant = ir + .constants() + .iter() + .find(|v| v.name().name() == name) + .ok_or_else(|| ConstraintEvaluationError::constant_not_found(name))?; + let value = match constant.value() { + ConstantType::Scalar(s) => Ok(*s), + _ => Err(ConstraintEvaluationError::invalid_constant_type( + name, "Scalar", + )), + }?; + get_constant_index_by_value(value, constants) +} + +/// Returns index of the constant found in the `constants` array by its vector access (name and +/// index) +pub fn get_constant_index_by_vector_access( + ir: &AirIR, + vector_access: &VectorAccess, + constants: &[u64], +) -> Result { + let constant = ir + .constants() + .iter() + .find(|v| v.name().name() == vector_access.name()) + .ok_or_else(|| ConstraintEvaluationError::constant_not_found(vector_access.name()))?; + let value = match constant.value() { + ConstantType::Vector(v) => Ok(v[vector_access.idx()]), + _ => Err(ConstraintEvaluationError::invalid_constant_type( + vector_access.name(), + "Vector", + )), + }?; + get_constant_index_by_value(value, constants) +} + +/// Returns index of the constant found in the `constants` array by its matrix access (name and +/// indexes) +pub fn get_constant_index_by_matrix_access( + ir: &AirIR, + matrix_access: &MatrixAccess, + constants: &[u64], +) -> Result { + let constant = ir + .constants() + .iter() + .find(|v| v.name().name() == matrix_access.name()) + .ok_or_else(|| ConstraintEvaluationError::constant_not_found(matrix_access.name()))?; + + let value = match constant.value() { + ConstantType::Matrix(m) => Ok(m[matrix_access.row_idx()][matrix_access.col_idx()]), + _ => Err(ConstraintEvaluationError::invalid_constant_type( + matrix_access.name(), + "Matrix", + )), + }?; + get_constant_index_by_value(value, constants) +} + +pub fn get_random_value_index(ir: &AirIR, rand_index: usize) -> usize { + ir.public_inputs().iter().map(|v| v.1).sum::() + rand_index +} diff --git a/ir/src/constraints/graph.rs b/ir/src/constraints/graph.rs index 6cd69e83..70f1df84 100644 --- a/ir/src/constraints/graph.rs +++ b/ir/src/constraints/graph.rs @@ -44,6 +44,10 @@ impl AlgebraicGraph { &self.nodes[index.0] } + pub fn nodes(&self) -> &Vec { + &self.nodes + } + /// Returns the degree of the subgraph which has the specified node as its tip. 
pub fn degree(&self, index: &NodeIndex) -> IntegrityConstraintDegree { let mut cycles: BTreeMap = BTreeMap::new(); @@ -568,6 +572,12 @@ impl AlgebraicGraph { #[derive(Debug, Default, Clone, Copy, Eq, PartialEq)] pub struct NodeIndex(usize); +impl NodeIndex { + pub fn index(&self) -> usize { + self.0 + } +} + #[derive(Debug)] pub struct Node { /// The operation represented by this node diff --git a/ir/src/lib.rs b/ir/src/lib.rs index f40952d6..92f4bc54 100644 --- a/ir/src/lib.rs +++ b/ir/src/lib.rs @@ -40,6 +40,7 @@ pub type BoundaryConstraintsMap = BTreeMap; pub struct AirIR { air_name: String, segment_widths: Vec, + num_random_values: u16, constants: Constants, public_inputs: PublicInputs, periodic_columns: PeriodicColumns, @@ -120,7 +121,7 @@ impl AirIR { } } - let (segment_widths, constants, public_inputs, periodic_columns) = + let (segment_widths, num_random_values, constants, public_inputs, periodic_columns) = symbol_table.into_declarations(); // validate sections @@ -129,6 +130,7 @@ impl AirIR { Ok(Self { air_name: air_name.to_string(), segment_widths, + num_random_values, constants, public_inputs, periodic_columns, @@ -146,6 +148,10 @@ impl AirIR { &self.constants } + pub fn num_random_values(&self) -> u16 { + self.num_random_values + } + pub fn segment_widths(&self) -> &Vec { &self.segment_widths } diff --git a/ir/src/symbol_table.rs b/ir/src/symbol_table.rs index a0487f2e..b1ca5fa2 100644 --- a/ir/src/symbol_table.rs +++ b/ir/src/symbol_table.rs @@ -204,9 +204,12 @@ impl SymbolTable { /// Consumes this symbol table and returns the information required for declaring constants, /// public inputs, periodic columns and columns amount for the AIR. - pub(super) fn into_declarations(self) -> (Vec, Constants, PublicInputs, PeriodicColumns) { + pub(super) fn into_declarations( + self, + ) -> (Vec, u16, Constants, PublicInputs, PeriodicColumns) { ( self.segment_widths, + self.num_random_values, self.constants, self.public_inputs, self.periodic_columns, From 2422925895e07db5dd0f7de2f086a4c361eb5343 Mon Sep 17 00:00:00 2001 From: hmuro andrej Date: Thu, 9 Feb 2023 00:23:03 +0300 Subject: [PATCH 2/4] test: add GCE testing --- Cargo.toml | 3 +- air-script/Cargo.toml | 1 + air-script/src/lib.rs | 5 +- air-script/tests/aux_trace/aux_trace.json | 30 ++ air-script/tests/binary/binary.json | 15 + air-script/tests/constants/constants.json | 39 +++ air-script/tests/helpers.rs | 29 +- .../indexed_trace_access.json | 13 + air-script/tests/main.rs | 302 +++++++++++++++--- air-script/tests/pub_inputs/pub_inputs.json | 18 ++ .../tests/random_values/random_values.json | 17 + air-script/tests/system/system.json | 11 + .../trace_col_groups/trace_col_groups.json | 13 + codegen/gce/Cargo.toml | 4 +- codegen/gce/src/expressions.rs | 21 +- codegen/gce/src/utils.rs | 11 +- 16 files changed, 461 insertions(+), 71 deletions(-) create mode 100644 air-script/tests/aux_trace/aux_trace.json create mode 100644 air-script/tests/binary/binary.json create mode 100644 air-script/tests/constants/constants.json create mode 100644 air-script/tests/indexed_trace_access/indexed_trace_access.json create mode 100644 air-script/tests/pub_inputs/pub_inputs.json create mode 100644 air-script/tests/random_values/random_values.json create mode 100644 air-script/tests/system/system.json create mode 100644 air-script/tests/trace_col_groups/trace_col_groups.json diff --git a/Cargo.toml b/Cargo.toml index 86d40ddb..d57ea966 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,5 +3,6 @@ members = [ "air-script", "parser", "ir", - 
"codegen/winterfell" + "codegen/winterfell", + "codegen/gce" ] \ No newline at end of file diff --git a/air-script/Cargo.toml b/air-script/Cargo.toml index 5b9234b5..42b4bb80 100644 --- a/air-script/Cargo.toml +++ b/air-script/Cargo.toml @@ -18,6 +18,7 @@ path = "src/main.rs" [dependencies] codegen-winter = { package = "air-codegen-winter", path = "../codegen/winterfell", version = "0.2.0" } +codegen-gce = { package = "air-codegen-gce", path = "../codegen/gce", version = "0.1.0" } env_logger = "0.10.0" ir = { package = "air-ir", path = "../ir", version = "0.2.0" } log = { version = "0.4", default-features = false } diff --git a/air-script/src/lib.rs b/air-script/src/lib.rs index e04db514..0042685d 100644 --- a/air-script/src/lib.rs +++ b/air-script/src/lib.rs @@ -7,5 +7,8 @@ pub use parser::parse; /// AirScript intermediate representation pub use ir::AirIR; +/// JSON file generation in generic constraint evaluation format +pub use codegen_gce::CodeGenerator as GceCodeGenerator; + /// Code generation targeting Rust for the Winterfell prover -pub use codegen_winter::CodeGenerator; +pub use codegen_winter::CodeGenerator as WinterfellCodeGenerator; diff --git a/air-script/tests/aux_trace/aux_trace.json b/air-script/tests/aux_trace/aux_trace.json new file mode 100644 index 00000000..956baeb3 --- /dev/null +++ b/air-script/tests/aux_trace/aux_trace.json @@ -0,0 +1,30 @@ +{ + "num_polys": 7, + "num_variables": 18, + "constants": [1], + "expressions": [ + {"op": "SUB", "lhs": {"type": "POL", "index": 0}, "rhs": {"type": "CONST", "index": 0}}, + {"op": "SUB", "lhs": {"type": "POL", "index": 1}, "rhs": {"type": "CONST", "index": 0}}, + {"op": "SUB", "lhs": {"type": "POL", "index": 3}, "rhs": {"type": "CONST", "index": 0}}, + {"op": "SUB", "lhs": {"type": "POL", "index": 4}, "rhs": {"type": "VAR", "index": 16}}, + {"op": "SUB", "lhs": {"type": "POL", "index": 4}, "rhs": {"type": "CONST", "index": 0}}, + {"op": "MUL", "lhs": {"type": "POL", "index": 0}, "rhs": {"type": "POL", "index": 1}}, + {"op": "MUL", "lhs": {"type": "EXPR", "index": 5}, "rhs": {"type": "POL", "index": 2}}, + {"op": "ADD", "lhs": {"type": "POL", "index": 1}, "rhs": {"type": "EXPR", "index": 6}}, + {"op": "SUB", "lhs": {"type": "POL_NEXT", "index": 0}, "rhs": {"type": "EXPR", "index": 7}}, + {"op": "ADD", "lhs": {"type": "POL", "index": 2}, "rhs": {"type": "POL_NEXT", "index": 0}}, + {"op": "SUB", "lhs": {"type": "POL_NEXT", "index": 1}, "rhs": {"type": "EXPR", "index": 9}}, + {"op": "ADD", "lhs": {"type": "POL", "index": 0}, "rhs": {"type": "POL", "index": 1}}, + {"op": "SUB", "lhs": {"type": "POL", "index": 2}, "rhs": {"type": "EXPR", "index": 11}}, + {"op": "ADD", "lhs": {"type": "POL", "index": 0}, "rhs": {"type": "VAR", "index": 16}}, + {"op": "ADD", "lhs": {"type": "EXPR", "index": 13}, "rhs": {"type": "POL", "index": 1}}, + {"op": "ADD", "lhs": {"type": "EXPR", "index": 14}, "rhs": {"type": "VAR", "index": 17}}, + {"op": "MUL", "lhs": {"type": "POL", "index": 3}, "rhs": {"type": "EXPR", "index": 15}}, + {"op": "SUB", "lhs": {"type": "POL_NEXT", "index": 3}, "rhs": {"type": "EXPR", "index": 16}}, + {"op": "ADD", "lhs": {"type": "POL", "index": 2}, "rhs": {"type": "VAR", "index": 16}}, + {"op": "MUL", "lhs": {"type": "POL_NEXT", "index": 4}, "rhs": {"type": "EXPR", "index": 18}}, + {"op": "SUB", "lhs": {"type": "POL", "index": 4}, "rhs": {"type": "EXPR", "index": 19}}, + {"op": "SUB", "lhs": {"type": "POL", "index": 3}, "rhs": {"type": "CONST", "index": 0}} + ], + "outputs": [0, 1, 8, 10, 12, 2, 21, 3, 4, 17, 20] +} diff --git 
a/air-script/tests/binary/binary.json b/air-script/tests/binary/binary.json new file mode 100644 index 00000000..e67a993d --- /dev/null +++ b/air-script/tests/binary/binary.json @@ -0,0 +1,15 @@ +{ + "num_polys": 2, + "num_variables": 16, + "constants": [0, 2], + "expressions": [ + {"op": "SUB", "lhs": {"type": "POL", "index": 0}, "rhs": {"type": "CONST", "index": 0}}, + {"op": "MUL", "lhs": {"type": "POL", "index": 0}, "rhs": {"type": "POL", "index": 0}}, + {"op": "SUB", "lhs": {"type": "EXPR", "index": 1}, "rhs": {"type": "POL", "index": 0}}, + {"op": "SUB", "lhs": {"type": "EXPR", "index": 2}, "rhs": {"type": "CONST", "index": 0}}, + {"op": "MUL", "lhs": {"type": "POL", "index": 1}, "rhs": {"type": "POL", "index": 1}}, + {"op": "SUB", "lhs": {"type": "EXPR", "index": 4}, "rhs": {"type": "POL", "index": 1}}, + {"op": "SUB", "lhs": {"type": "EXPR", "index": 5}, "rhs": {"type": "CONST", "index": 0}} + ], + "outputs": [0, 3, 6] +} diff --git a/air-script/tests/constants/constants.json b/air-script/tests/constants/constants.json new file mode 100644 index 00000000..71a074c9 --- /dev/null +++ b/air-script/tests/constants/constants.json @@ -0,0 +1,39 @@ +{ + "num_polys": 10, + "num_variables": 32, + "constants": [1, 0, 2], + "expressions": [ + {"op": "SUB", "lhs": {"type": "POL", "index": 0}, "rhs": {"type": "CONST", "index": 0}}, + {"op": "MUL", "lhs": {"type": "CONST", "index": 1}, "rhs": {"type": "CONST", "index": 2}}, + {"op": "ADD", "lhs": {"type": "CONST", "index": 0}, "rhs": {"type": "EXPR", "index": 1}}, + {"op": "SUB", "lhs": {"type": "POL", "index": 1}, "rhs": {"type": "EXPR", "index": 2}}, + {"op": "SUB", "lhs": {"type": "CONST", "index": 1}, "rhs": {"type": "CONST", "index": 1}}, + {"op": "MUL", "lhs": {"type": "EXPR", "index": 4}, "rhs": {"type": "CONST", "index": 0}}, + {"op": "SUB", "lhs": {"type": "POL", "index": 2}, "rhs": {"type": "EXPR", "index": 5}}, + {"op": "ADD", "lhs": {"type": "CONST", "index": 0}, "rhs": {"type": "CONST", "index": 1}}, + {"op": "SUB", "lhs": {"type": "EXPR", "index": 7}, "rhs": {"type": "CONST", "index": 0}}, + {"op": "ADD", "lhs": {"type": "EXPR", "index": 8}, "rhs": {"type": "CONST", "index": 0}}, + {"op": "SUB", "lhs": {"type": "EXPR", "index": 9}, "rhs": {"type": "CONST", "index": 2}}, + {"op": "ADD", "lhs": {"type": "EXPR", "index": 10}, "rhs": {"type": "CONST", "index": 2}}, + {"op": "SUB", "lhs": {"type": "EXPR", "index": 11}, "rhs": {"type": "CONST", "index": 1}}, + {"op": "SUB", "lhs": {"type": "POL", "index": 3}, "rhs": {"type": "EXPR", "index": 12}}, + {"op": "SUB", "lhs": {"type": "POL", "index": 4}, "rhs": {"type": "EXPR", "index": 2}}, + {"op": "MUL", "lhs": {"type": "CONST", "index": 0}, "rhs": {"type": "CONST", "index": 0}}, + {"op": "SUB", "lhs": {"type": "CONST", "index": 0}, "rhs": {"type": "EXPR", "index": 15}}, + {"op": "SUB", "lhs": {"type": "POL", "index": 4}, "rhs": {"type": "EXPR", "index": 16}}, + {"op": "ADD", "lhs": {"type": "POL", "index": 0}, "rhs": {"type": "CONST", "index": 0}}, + {"op": "SUB", "lhs": {"type": "POL_NEXT", "index": 0}, "rhs": {"type": "EXPR", "index": 18}}, + {"op": "MUL", "lhs": {"type": "CONST", "index": 1}, "rhs": {"type": "POL", "index": 1}}, + {"op": "SUB", "lhs": {"type": "POL_NEXT", "index": 1}, "rhs": {"type": "EXPR", "index": 20}}, + {"op": "ADD", "lhs": {"type": "CONST", "index": 0}, "rhs": {"type": "CONST", "index": 1}}, + {"op": "MUL", "lhs": {"type": "EXPR", "index": 22}, "rhs": {"type": "POL", "index": 2}}, + {"op": "SUB", "lhs": {"type": "POL_NEXT", "index": 2}, "rhs": {"type": "EXPR", 
"index": 23}}, + {"op": "ADD", "lhs": {"type": "POL", "index": 4}, "rhs": {"type": "CONST", "index": 0}}, + {"op": "ADD", "lhs": {"type": "EXPR", "index": 25}, "rhs": {"type": "EXPR", "index": 1}}, + {"op": "SUB", "lhs": {"type": "POL_NEXT", "index": 4}, "rhs": {"type": "EXPR", "index": 26}}, + {"op": "MUL", "lhs": {"type": "CONST", "index": 0}, "rhs": {"type": "CONST", "index": 1}}, + {"op": "ADD", "lhs": {"type": "CONST", "index": 0}, "rhs": {"type": "EXPR", "index": 28}}, + {"op": "SUB", "lhs": {"type": "POL", "index": 4}, "rhs": {"type": "EXPR", "index": 29}} + ], + "outputs": [0, 3, 6, 13, 19, 21, 24, 14, 17, 27, 30] +} diff --git a/air-script/tests/helpers.rs b/air-script/tests/helpers.rs index 1481d339..7e6e7a87 100644 --- a/air-script/tests/helpers.rs +++ b/air-script/tests/helpers.rs @@ -1,6 +1,4 @@ -use codegen_winter::CodeGenerator; -use ir::AirIR; -use parser::parse; +use air_script::{parse, AirIR, GceCodeGenerator, WinterfellCodeGenerator}; use std::fs; #[derive(Debug)] @@ -8,6 +6,7 @@ pub enum TestError { IO(String), Parse(String), IR(String), + Codegen(String), } pub struct Test { @@ -19,12 +18,13 @@ impl Test { Test { input_path } } - pub fn transpile(&self) -> Result { + /// Parse data in file at `input_path` and return [AirIR] with this data + fn generate_ir(&self) -> Result { // load source input from file let source = fs::read_to_string(&self.input_path).map_err(|err| { TestError::IO(format!( "Failed to open input file `{:?}` - {}", - self.input_path, err + &self.input_path, err )) })?; @@ -43,8 +43,25 @@ impl Test { )) })?; + Ok(ir) + } + + /// Generate Rust code containing a Winterfell Air implementation for the AirIR + pub fn generate_winterfell(&self) -> Result { + let ir = Self::generate_ir(self)?; // generate Rust code targeting Winterfell - let codegen = CodeGenerator::new(&ir); + let codegen = WinterfellCodeGenerator::new(&ir); Ok(codegen.generate()) } + + /// Generate JSON file in generic constraint evaluation format + pub fn generate_gce(&self, extension_degree: u8, path: &str) -> Result<(), TestError> { + let ir = Self::generate_ir(self)?; + let codegen = GceCodeGenerator::new(&ir, extension_degree).map_err(|err| { + TestError::Codegen(format!("Failed to create GCECodeGenerator: {:?}", err)) + })?; + codegen + .generate(path) + .map_err(|err| TestError::Codegen(format!("Failed to generate JSON file: {:?}", err))) + } } diff --git a/air-script/tests/indexed_trace_access/indexed_trace_access.json b/air-script/tests/indexed_trace_access/indexed_trace_access.json new file mode 100644 index 00000000..06fdd25e --- /dev/null +++ b/air-script/tests/indexed_trace_access/indexed_trace_access.json @@ -0,0 +1,13 @@ +{ + "num_polys": 6, + "num_variables": 16, + "constants": [1, 0], + "expressions": [ + {"op": "ADD", "lhs": {"type": "POL", "index": 1}, "rhs": {"type": "CONST", "index": 0}}, + {"op": "SUB", "lhs": {"type": "POL_NEXT", "index": 0}, "rhs": {"type": "EXPR", "index": 0}}, + {"op": "ADD", "lhs": {"type": "POL", "index": 3}, "rhs": {"type": "CONST", "index": 0}}, + {"op": "SUB", "lhs": {"type": "POL_NEXT", "index": 2}, "rhs": {"type": "EXPR", "index": 2}}, + {"op": "SUB", "lhs": {"type": "POL", "index": 0}, "rhs": {"type": "CONST", "index": 1}} + ], + "outputs": [4, 1, 3] +} diff --git a/air-script/tests/main.rs b/air-script/tests/main.rs index c6175242..8b7e7f92 100644 --- a/air-script/tests/main.rs +++ b/air-script/tests/main.rs @@ -1,5 +1,6 @@ use expect_test::expect_file; - +use std::fs::{self, File}; +use std::io::prelude::*; mod helpers; use helpers::Test; 
@@ -7,129 +8,326 @@ use helpers::Test; // ================================================================================================ #[test] -fn aux_trace() { - let generated_air = Test::new("tests/aux_trace/aux_trace.air".to_string()) - .transpile() - .unwrap(); +fn winterfell_aux_trace() { + let test = Test::new("tests/aux_trace/aux_trace.air".to_string()); + let generated_air = test + .generate_winterfell() + .expect("Failed to generate a Winterfell Air implementation"); let expected = expect_file!["aux_trace/aux_trace.rs"]; expected.assert_eq(&generated_air); } #[test] -fn binary() { - let generated_air = Test::new("tests/binary/binary.air".to_string()) - .transpile() - .unwrap(); +fn gce_aux_trace() { + let test_path = "tests/aux_trace/aux_trace"; + let result_file = "tests/aux_trace/generated_aux_trace.json"; + + let test = Test::new([test_path, "air"].join(".")); + test.generate_gce(2, result_file) + .expect("GCE generation failed"); + + let expected = expect_file![[test_path, "json"].join(".").trim_start_matches("tests/")]; + + let mut file = File::open(result_file).expect("Failed to open file"); + let mut contents = String::new(); + file.read_to_string(&mut contents) + .expect("Failed to read form file"); + + expected.assert_eq(&contents); + + fs::remove_file(result_file).expect("Failed to remove file"); +} + +#[test] +fn winterfell_binary() { + let test = Test::new("tests/binary/binary.air".to_string()); + let generated_air = test + .generate_winterfell() + .expect("Failed to generate a Winterfell Air implementation"); let expected = expect_file!["binary/binary.rs"]; expected.assert_eq(&generated_air); } #[test] -fn periodic_columns() { - let generated_air = Test::new("tests/periodic_columns/periodic_columns.air".to_string()) - .transpile() - .unwrap(); +fn gce_binary() { + let test_path = "tests/binary/binary"; + let result_file = "tests/binary/generated_binary.json"; + + let test = Test::new([test_path, "air"].join(".")); + test.generate_gce(2, result_file) + .expect("GCE generation failed"); + + let expected = expect_file![[test_path, "json"].join(".").trim_start_matches("tests/")]; + + let mut file = File::open(result_file).expect("Failed to open file"); + let mut contents = String::new(); + file.read_to_string(&mut contents) + .expect("Failed to read form file"); + + expected.assert_eq(&contents); + + fs::remove_file(result_file).expect("Failed to remove file"); +} + +#[test] +fn winterfell_periodic_columns() { + let test = Test::new("tests/periodic_columns/periodic_columns.air".to_string()); + let generated_air = test + .generate_winterfell() + .expect("Failed to generate a Winterfell Air implementation"); let expected = expect_file!["periodic_columns/periodic_columns.rs"]; expected.assert_eq(&generated_air); } #[test] -fn pub_inputs() { - let generated_air = Test::new("tests/pub_inputs/pub_inputs.air".to_string()) - .transpile() - .unwrap(); +fn winterfell_pub_inputs() { + let test = Test::new("tests/pub_inputs/pub_inputs.air".to_string()); + let generated_air = test + .generate_winterfell() + .expect("Failed to generate a Winterfell Air implementation"); let expected = expect_file!["pub_inputs/pub_inputs.rs"]; expected.assert_eq(&generated_air); } #[test] -fn system() { - let generated_air = Test::new("tests/system/system.air".to_string()) - .transpile() - .unwrap(); +fn gce_pub_inputs() { + let test_path = "tests/pub_inputs/pub_inputs"; + let result_file = "tests/pub_inputs/generated_pub_inputs.json"; + + let test = Test::new([test_path, "air"].join(".")); + 
test.generate_gce(2, result_file) + .expect("GCE generation failed"); + + let expected = expect_file![[test_path, "json"].join(".").trim_start_matches("tests/")]; + + let mut file = File::open(result_file).expect("Failed to open file"); + let mut contents = String::new(); + file.read_to_string(&mut contents) + .expect("Failed to read form file"); + + expected.assert_eq(&contents); + + fs::remove_file(result_file).expect("Failed to remove file"); +} + +#[test] +fn winterfell_system() { + let test = Test::new("tests/system/system.air".to_string()); + let generated_air = test + .generate_winterfell() + .expect("Failed to generate a Winterfell Air implementation"); let expected = expect_file!["system/system.rs"]; expected.assert_eq(&generated_air); } #[test] -fn bitwise() { - let generated_air = Test::new("tests/bitwise/bitwise.air".to_string()) - .transpile() - .unwrap(); +fn gce_system() { + let test_path = "tests/system/system"; + let result_file = "tests/system/generated_system.json"; + + let test = Test::new([test_path, "air"].join(".")); + test.generate_gce(2, result_file) + .expect("GCE generation failed"); + + let expected = expect_file![[test_path, "json"].join(".").trim_start_matches("tests/")]; + + let mut file = File::open(result_file).expect("Failed to open file"); + let mut contents = String::new(); + file.read_to_string(&mut contents) + .expect("Failed to read form file"); + + expected.assert_eq(&contents); + + fs::remove_file(result_file).expect("Failed to remove file"); +} + +#[test] +fn winterfell_bitwise() { + let test = Test::new("tests/bitwise/bitwise.air".to_string()); + let generated_air = test + .generate_winterfell() + .expect("Failed to generate a Winterfell Air implementation"); let expected = expect_file!["bitwise/bitwise.rs"]; expected.assert_eq(&generated_air); } #[test] -fn constants() { - let generated_air = Test::new("tests/constants/constants.air".to_string()) - .transpile() - .unwrap(); +fn winterfell_constants() { + let test = Test::new("tests/constants/constants.air".to_string()); + let generated_air = test + .generate_winterfell() + .expect("Failed to generate a Winterfell Air implementation"); let expected = expect_file!["constants/constants.rs"]; expected.assert_eq(&generated_air); } #[test] -fn variables() { - let generated_air = Test::new("tests/variables/variables.air".to_string()) - .transpile() - .unwrap(); +fn gce_constants() { + let test_path = "tests/constants/constants"; + let result_file = "tests/constants/generated_constants.json"; + + let test = Test::new([test_path, "air"].join(".")); + test.generate_gce(2, result_file) + .expect("GCE generation failed"); + + let expected = expect_file![[test_path, "json"].join(".").trim_start_matches("tests/")]; + + let mut file = File::open(result_file).expect("Failed to open file"); + let mut contents = String::new(); + file.read_to_string(&mut contents) + .expect("Failed to read form file"); + + expected.assert_eq(&contents); + + fs::remove_file(result_file).expect("Failed to remove file"); +} + +#[test] +fn winterfell_variables() { + let test = Test::new("tests/variables/variables.air".to_string()); + let generated_air = test + .generate_winterfell() + .expect("Failed to generate a Winterfell Air implementation"); let expected = expect_file!["variables/variables.rs"]; expected.assert_eq(&generated_air); } #[test] -fn trace_col_groups() { - let generated_air = Test::new("tests/trace_col_groups/trace_col_groups.air".to_string()) - .transpile() - .unwrap(); +fn winterfell_trace_col_groups() { + let test = 
Test::new("tests/trace_col_groups/trace_col_groups.air".to_string()); + let generated_air = test + .generate_winterfell() + .expect("Failed to generate a Winterfell Air implementation"); let expected = expect_file!["trace_col_groups/trace_col_groups.rs"]; expected.assert_eq(&generated_air); } #[test] -fn indexed_trace_access() { - let generated_air = - Test::new("tests/indexed_trace_access/indexed_trace_access.air".to_string()) - .transpile() - .unwrap(); +fn gce_trace_col_groups() { + let test_path = "tests/trace_col_groups/trace_col_groups"; + let result_file = "tests/trace_col_groups/generated_trace_col_groups.json"; + + let test = Test::new([test_path, "air"].join(".")); + test.generate_gce(2, result_file) + .expect("GCE generation failed"); + + let expected = expect_file![[test_path, "json"].join(".").trim_start_matches("tests/")]; + + let mut file = File::open(result_file).expect("Failed to open file"); + let mut contents = String::new(); + file.read_to_string(&mut contents) + .expect("Failed to read form file"); + + expected.assert_eq(&contents); + + fs::remove_file(result_file).expect("Failed to remove file"); +} + +#[test] +fn winterfell_indexed_trace_access() { + let test = Test::new("tests/indexed_trace_access/indexed_trace_access.air".to_string()); + let generated_air = test + .generate_winterfell() + .expect("Failed to generate a Winterfell Air implementation"); let expected = expect_file!["indexed_trace_access/indexed_trace_access.rs"]; expected.assert_eq(&generated_air); } #[test] -fn random_values() { - let generated_air = Test::new("tests/random_values/random_values_simple.air".to_string()) - .transpile() - .unwrap(); +fn gce_indexed_trace_access() { + let test_path = "tests/indexed_trace_access/indexed_trace_access"; + let result_file = "tests/indexed_trace_access/generated_indexed_trace_access.json"; + + let test = Test::new([test_path, "air"].join(".")); + test.generate_gce(2, result_file) + .expect("GCE generation failed"); + + let expected = expect_file![[test_path, "json"].join(".").trim_start_matches("tests/")]; + + let mut file = File::open(result_file).expect("Failed to open file"); + let mut contents = String::new(); + file.read_to_string(&mut contents) + .expect("Failed to read form file"); + + expected.assert_eq(&contents); + + fs::remove_file(result_file).expect("Failed to remove file"); +} + +#[test] +fn winterfell_random_values() { + let test = Test::new("tests/random_values/random_values_simple.air".to_string()); + let generated_air = test + .generate_winterfell() + .expect("Failed to generate a Winterfell Air implementation"); let expected = expect_file!["random_values/random_values.rs"]; expected.assert_eq(&generated_air); - let generated_air = Test::new("tests/random_values/random_values_bindings.air".to_string()) - .transpile() - .unwrap(); + let test = Test::new("tests/random_values/random_values_bindings.air".to_string()); + let generated_air = test + .generate_winterfell() + .expect("Failed to generate a Winterfell Air implementation"); let expected = expect_file!["random_values/random_values.rs"]; expected.assert_eq(&generated_air); } +#[test] +fn gce_random_values() { + let test_path = "tests/random_values/random_values"; + let result_file = "tests/random_values/generated_random_values.json"; + + let test_path_simple = &[test_path, "simple"].join("_"); + let test = Test::new([test_path_simple, "air"].join(".")); + test.generate_gce(2, result_file) + .expect("GCE generation failed"); + + let expected = expect_file![[test_path, 
"json"].join(".").trim_start_matches("tests/")]; + + let mut file = File::open(result_file).expect("Failed to open file"); + let mut contents = String::new(); + file.read_to_string(&mut contents) + .expect("Failed to read form file"); + + expected.assert_eq(&contents); + + fs::remove_file(result_file).expect("Failed to remove file"); + + let test_path_bindings = &[test_path, "simple"].join("_"); + let test = Test::new([test_path_bindings, "air"].join(".")); + test.generate_gce(2, result_file) + .expect("GCE generation failed"); + + let expected = expect_file![[test_path, "json"].join(".").trim_start_matches("tests/")]; + + let mut file = File::open(result_file).expect("Failed to open file"); + let mut contents = String::new(); + file.read_to_string(&mut contents) + .expect("Failed to read form file"); + + expected.assert_eq(&contents); + + fs::remove_file(result_file).expect("Failed to remove file"); +} + #[test] fn list_comprehension() { // TODO: Improve this test to include more complicated expressions - let generated_air = Test::new("tests/list_comprehension/list_comprehension.air".to_string()) - .transpile() - .unwrap(); + let test = Test::new("tests/list_comprehension/list_comprehension.air".to_string()); + let generated_air = test + .generate_winterfell() + .expect("Failed to generate a Winterfell Air implementation"); let expected = expect_file!["list_comprehension/list_comprehension.rs"]; expected.assert_eq(&generated_air); diff --git a/air-script/tests/pub_inputs/pub_inputs.json b/air-script/tests/pub_inputs/pub_inputs.json new file mode 100644 index 00000000..95d86df2 --- /dev/null +++ b/air-script/tests/pub_inputs/pub_inputs.json @@ -0,0 +1,18 @@ +{ + "num_polys": 4, + "num_variables": 32, + "constants": [], + "expressions": [ + {"op": "SUB", "lhs": {"type": "POL", "index": 0}, "rhs": {"type": "VAR", "index": 4}}, + {"op": "SUB", "lhs": {"type": "POL", "index": 1}, "rhs": {"type": "VAR", "index": 5}}, + {"op": "SUB", "lhs": {"type": "POL", "index": 2}, "rhs": {"type": "VAR", "index": 6}}, + {"op": "SUB", "lhs": {"type": "POL", "index": 3}, "rhs": {"type": "VAR", "index": 7}}, + {"op": "SUB", "lhs": {"type": "POL", "index": 0}, "rhs": {"type": "VAR", "index": 8}}, + {"op": "SUB", "lhs": {"type": "POL", "index": 1}, "rhs": {"type": "VAR", "index": 9}}, + {"op": "SUB", "lhs": {"type": "POL", "index": 2}, "rhs": {"type": "VAR", "index": 10}}, + {"op": "SUB", "lhs": {"type": "POL", "index": 3}, "rhs": {"type": "VAR", "index": 11}}, + {"op": "ADD", "lhs": {"type": "POL", "index": 1}, "rhs": {"type": "POL", "index": 2}}, + {"op": "SUB", "lhs": {"type": "POL_NEXT", "index": 0}, "rhs": {"type": "EXPR", "index": 8}} + ], + "outputs": [0, 1, 2, 3, 4, 5, 6, 7, 9] +} diff --git a/air-script/tests/random_values/random_values.json b/air-script/tests/random_values/random_values.json new file mode 100644 index 00000000..d3074978 --- /dev/null +++ b/air-script/tests/random_values/random_values.json @@ -0,0 +1,17 @@ +{ + "num_polys": 6, + "num_variables": 32, + "constants": [], + "expressions": [ + {"op": "ADD", "lhs": {"type": "VAR", "index": 21}, "rhs": {"type": "VAR", "index": 19}}, + {"op": "ADD", "lhs": {"type": "EXPR", "index": 0}, "rhs": {"type": "VAR", "index": 31}}, + {"op": "SUB", "lhs": {"type": "POL", "index": 2}, "rhs": {"type": "EXPR", "index": 1}}, + {"op": "ADD", "lhs": {"type": "VAR", "index": 16}, "rhs": {"type": "VAR", "index": 31}}, + {"op": "ADD", "lhs": {"type": "EXPR", "index": 3}, "rhs": {"type": "VAR", "index": 27}}, + {"op": "SUB", "lhs": {"type": "POL", "index": 2}, 
"rhs": {"type": "EXPR", "index": 4}}, + {"op": "SUB", "lhs": {"type": "VAR", "index": 31}, "rhs": {"type": "VAR", "index": 16}}, + {"op": "ADD", "lhs": {"type": "EXPR", "index": 6}, "rhs": {"type": "VAR", "index": 19}}, + {"op": "SUB", "lhs": {"type": "POL_NEXT", "index": 2}, "rhs": {"type": "EXPR", "index": 7}} + ], + "outputs": [2, 5, 8] +} diff --git a/air-script/tests/system/system.json b/air-script/tests/system/system.json new file mode 100644 index 00000000..ad213f2b --- /dev/null +++ b/air-script/tests/system/system.json @@ -0,0 +1,11 @@ +{ + "num_polys": 3, + "num_variables": 16, + "constants": [1, 0], + "expressions": [ + {"op": "ADD", "lhs": {"type": "POL", "index": 0}, "rhs": {"type": "CONST", "index": 0}}, + {"op": "SUB", "lhs": {"type": "POL_NEXT", "index": 0}, "rhs": {"type": "EXPR", "index": 0}}, + {"op": "SUB", "lhs": {"type": "POL", "index": 0}, "rhs": {"type": "CONST", "index": 1}} + ], + "outputs": [2, 1] +} diff --git a/air-script/tests/trace_col_groups/trace_col_groups.json b/air-script/tests/trace_col_groups/trace_col_groups.json new file mode 100644 index 00000000..26b1cc0c --- /dev/null +++ b/air-script/tests/trace_col_groups/trace_col_groups.json @@ -0,0 +1,13 @@ +{ + "num_polys": 14, + "num_variables": 16, + "constants": [1, 0], + "expressions": [ + {"op": "ADD", "lhs": {"type": "POL", "index": 2}, "rhs": {"type": "CONST", "index": 0}}, + {"op": "SUB", "lhs": {"type": "POL_NEXT", "index": 2}, "rhs": {"type": "EXPR", "index": 0}}, + {"op": "SUB", "lhs": {"type": "POL", "index": 1}, "rhs": {"type": "CONST", "index": 0}}, + {"op": "SUB", "lhs": {"type": "POL_NEXT", "index": 1}, "rhs": {"type": "EXPR", "index": 2}}, + {"op": "SUB", "lhs": {"type": "POL", "index": 8}, "rhs": {"type": "CONST", "index": 1}} + ], + "outputs": [1, 3, 4] +} diff --git a/codegen/gce/Cargo.toml b/codegen/gce/Cargo.toml index 8ba246e3..ec8562ff 100644 --- a/codegen/gce/Cargo.toml +++ b/codegen/gce/Cargo.toml @@ -12,6 +12,6 @@ edition = "2021" rust-version = "1.65" [dependencies] -air-script-core = {package = "air-script-core", path="../../air-script-core", version="0.1.0" } -ir = {package = "air-ir", path="../../ir", version="0.1.0" } +air-script-core = { package = "air-script-core", path="../../air-script-core", version="0.2.0" } +ir = { package = "air-ir", path="../../ir", version="0.2.0" } codegen = "0.2.0" \ No newline at end of file diff --git a/codegen/gce/src/expressions.rs b/codegen/gce/src/expressions.rs index 7d569731..ce4606d2 100644 --- a/codegen/gce/src/expressions.rs +++ b/codegen/gce/src/expressions.rs @@ -2,7 +2,8 @@ use super::error::ConstraintEvaluationError; use super::{ utils::{ get_constant_index_by_matrix_access, get_constant_index_by_name, - get_constant_index_by_value, get_constant_index_by_vector_access, get_random_value_index, + get_constant_index_by_value, get_constant_index_by_vector_access, get_public_input_index, + get_random_value_index, }, ExpressionJson, ExpressionOperation, NodeReference, NodeType, }; @@ -42,7 +43,6 @@ impl<'a> ExpressionsHandler<'a> { // the column is .first constraint and second is .last constraint (in the boundary section, not // entire array) let mut expressions = Vec::new(); - for (index, node) in self.ir.constraint_graph().nodes().iter().enumerate() { match node.op() { Operation::Add(l, r) => { @@ -123,6 +123,9 @@ impl<'a> ExpressionsHandler<'a> { .ok_or_else(|| { ConstraintEvaluationError::operation_not_found(root.node_index().index()) })?; + // if a same constraint is found twice, this means that it is used for both first + // and last 
row of the column, so we should add this expression to the expressions + // array again. if outputs.contains(index) { expressions.push(expressions[*index].clone()); outputs.push(expressions.len() - 1); @@ -267,7 +270,14 @@ impl<'a> ExpressionsHandler<'a> { } } RandomValue(rand_index) => { - let index = get_random_value_index(self.ir, *rand_index); + let index = get_random_value_index(self.ir, rand_index); + Ok(NodeReference { + node_type: NodeType::Var, + index, + }) + } + PublicInput(name, public_index) => { + let index = get_public_input_index(self.ir, name, public_index); Ok(NodeReference { node_type: NodeType::Var, index, @@ -275,11 +285,6 @@ impl<'a> ExpressionsHandler<'a> { } PeriodicColumn(_column, _length) => todo!(), - - // Currently it can only be `Neg` - _ => Err(ConstraintEvaluationError::InvalidOperation( - "Invalid transition constraint operation".to_string(), - )), } } diff --git a/codegen/gce/src/utils.rs b/codegen/gce/src/utils.rs index 403dde1b..d895987c 100644 --- a/codegen/gce/src/utils.rs +++ b/codegen/gce/src/utils.rs @@ -81,6 +81,15 @@ pub fn get_constant_index_by_matrix_access( get_constant_index_by_value(value, constants) } -pub fn get_random_value_index(ir: &AirIR, rand_index: usize) -> usize { +pub fn get_random_value_index(ir: &AirIR, rand_index: &usize) -> usize { ir.public_inputs().iter().map(|v| v.1).sum::() + rand_index } + +pub fn get_public_input_index(ir: &AirIR, name: &String, public_index: &usize) -> usize { + ir.public_inputs() + .iter() + .take_while(|v| v.0 != *name) + .map(|v| v.1) + .sum::() + + public_index +} From 6bc503842089025850c85a326e32ca6a054c7cc3 Mon Sep 17 00:00:00 2001 From: hmuro andrej Date: Thu, 9 Feb 2023 00:37:39 +0300 Subject: [PATCH 3/4] chore: remove unused errors, update error fromat --- codegen/gce/src/error.rs | 33 +++++++-------------------------- codegen/gce/src/expressions.rs | 6 +++--- codegen/gce/src/lib.rs | 2 +- 3 files changed, 11 insertions(+), 30 deletions(-) diff --git a/codegen/gce/src/error.rs b/codegen/gce/src/error.rs index b14da5c3..f96aae79 100644 --- a/codegen/gce/src/error.rs +++ b/codegen/gce/src/error.rs @@ -1,51 +1,32 @@ #[derive(Debug)] pub enum ConstraintEvaluationError { - InvalidTraceSegment(String), - InvalidOperation(String), - IdentifierNotFound(String), ConstantNotFound(String), - PublicInputNotFound(String), - OperationNotFound(String), InvalidConstantType(String), + InvalidOperation(String), + InvalidTraceSegment(String), + OperationNotFound(String), } impl ConstraintEvaluationError { pub fn invalid_trace_segment(segment: u8) -> Self { ConstraintEvaluationError::InvalidTraceSegment(format!( - "Trace segment {} is invalid", - segment - )) - } - - pub fn identifier_not_found(name: &str) -> Self { - ConstraintEvaluationError::IdentifierNotFound(format!( - "Identifier {} not found in JSON arrays", - name + "Trace segment {segment} is invalid" )) } pub fn constant_not_found(name: &str) -> Self { - ConstraintEvaluationError::ConstantNotFound(format!("Constant \"{}\" not found", name)) - } - - pub fn public_input_not_found(name: &str) -> Self { - ConstraintEvaluationError::PublicInputNotFound(format!( - "Public Input \"{}\" not found", - name - )) + ConstraintEvaluationError::ConstantNotFound(format!("Constant \"{name}\" not found")) } pub fn invalid_constant_type(name: &str, constant_type: &str) -> Self { ConstraintEvaluationError::InvalidConstantType(format!( - "Invalid type of constant \"{}\". {} exprected.", - name, constant_type + "Invalid type of constant \"{name}\". {constant_type} exprected." 
)) } pub fn operation_not_found(index: usize) -> Self { ConstraintEvaluationError::OperationNotFound(format!( - "Operation with index {} does not match the expression in the expressions JSON array", - index + "Operation with index {index} does not match the expression in the expressions JSON array" )) } } diff --git a/codegen/gce/src/expressions.rs b/codegen/gce/src/expressions.rs index ce4606d2..b5027236 100644 --- a/codegen/gce/src/expressions.rs +++ b/codegen/gce/src/expressions.rs @@ -123,9 +123,9 @@ impl<'a> ExpressionsHandler<'a> { .ok_or_else(|| { ConstraintEvaluationError::operation_not_found(root.node_index().index()) })?; - // if a same constraint is found twice, this means that it is used for both first - // and last row of the column, so we should add this expression to the expressions - // array again. + // if we found index twice, put the corresponding expression in the expressions + // array again. It means that we have equal boundary constraints for both first + // and last domains (e.g. a.first = 1 and a.last = 1) if outputs.contains(index) { expressions.push(expressions[*index].clone()); outputs.push(expressions.len() - 1); diff --git a/codegen/gce/src/lib.rs b/codegen/gce/src/lib.rs index 45959b78..bc97f306 100644 --- a/codegen/gce/src/lib.rs +++ b/codegen/gce/src/lib.rs @@ -64,7 +64,7 @@ impl CodeGenerator { file.write_all(format!("\t\"constants\": {:?},\n", self.constants).as_bytes())?; file.write_all(format!("\t\"expressions\": [\n\t\t{}", self.expressions[0]).as_bytes())?; for expr in self.expressions.iter().skip(1) { - file.write_all(format!(",\n\t\t{}", expr).as_bytes())?; + file.write_all(format!(",\n\t\t{expr}").as_bytes())?; } file.write_all("\n\t],\n".as_bytes())?; file.write_all(format!("\t\"outputs\": {:?}\n", self.outputs).as_bytes())?; From e839cfa9def95beb84fab5b2021c5348191d5496 Mon Sep 17 00:00:00 2001 From: hmuro andrej Date: Sat, 25 Feb 2023 19:06:22 +0300 Subject: [PATCH 4/4] refactor: update ExpressionHandler to match builder pattern --- air-script/Cargo.toml | 2 +- air-script/tests/main.rs | 1 + codegen/gce/Cargo.toml | 2 +- codegen/gce/src/error.rs | 8 +- codegen/gce/src/expressions.rs | 173 ++++++++++++++++++--------------- codegen/gce/src/lib.rs | 44 ++++----- 6 files changed, 124 insertions(+), 106 deletions(-) diff --git a/air-script/Cargo.toml b/air-script/Cargo.toml index 42b4bb80..c6b82f7a 100644 --- a/air-script/Cargo.toml +++ b/air-script/Cargo.toml @@ -18,7 +18,7 @@ path = "src/main.rs" [dependencies] codegen-winter = { package = "air-codegen-winter", path = "../codegen/winterfell", version = "0.2.0" } -codegen-gce = { package = "air-codegen-gce", path = "../codegen/gce", version = "0.1.0" } +codegen-gce = { package = "air-codegen-gce", path = "../codegen/gce", version = "0.2.0" } env_logger = "0.10.0" ir = { package = "air-ir", path = "../ir", version = "0.2.0" } log = { version = "0.4", default-features = false } diff --git a/air-script/tests/main.rs b/air-script/tests/main.rs index 8b7e7f92..8e0db34a 100644 --- a/air-script/tests/main.rs +++ b/air-script/tests/main.rs @@ -1,6 +1,7 @@ use expect_test::expect_file; use std::fs::{self, File}; use std::io::prelude::*; + mod helpers; use helpers::Test; diff --git a/codegen/gce/Cargo.toml b/codegen/gce/Cargo.toml index ec8562ff..08b493aa 100644 --- a/codegen/gce/Cargo.toml +++ b/codegen/gce/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "air-codegen-gce" -version = "0.1.0" +version = "0.2.0" description="Code generation for the generic constraint evaluation format." 
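For orientation, the file this crate emits is a single JSON object whose top-level keys can be read off the writer in lib.rs: a "constants" array of u64 values, an "expressions" array with one entry per lowered operation, and an "outputs" array of indexes into "expressions", one per boundary, validity, and transition constraint. A schematic sketch only (illustrative values; the per-expression serialization comes from the Display impls later in this patch, and the counts carried by CodeGenerator are omitted here):

    {
        "constants": [0, 1, 2],
        "expressions": [
            { ... }, { ... }
        ],
        "outputs": [0, 1]
    }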
authors = ["miden contributors"] readme="README.md" diff --git a/codegen/gce/src/error.rs b/codegen/gce/src/error.rs index f96aae79..f841767f 100644 --- a/codegen/gce/src/error.rs +++ b/codegen/gce/src/error.rs @@ -10,23 +10,23 @@ pub enum ConstraintEvaluationError { impl ConstraintEvaluationError { pub fn invalid_trace_segment(segment: u8) -> Self { ConstraintEvaluationError::InvalidTraceSegment(format!( - "Trace segment {segment} is invalid" + "Trace segment {segment} is invalid." )) } pub fn constant_not_found(name: &str) -> Self { - ConstraintEvaluationError::ConstantNotFound(format!("Constant \"{name}\" not found")) + ConstraintEvaluationError::ConstantNotFound(format!("Constant \"{name}\" not found.")) } pub fn invalid_constant_type(name: &str, constant_type: &str) -> Self { ConstraintEvaluationError::InvalidConstantType(format!( - "Invalid type of constant \"{name}\". {constant_type} exprected." + "Invalid type of constant \"{name}\". Expected \"{constant_type}\"." )) } pub fn operation_not_found(index: usize) -> Self { ConstraintEvaluationError::OperationNotFound(format!( - "Operation with index {index} does not match the expression in the expressions JSON array" + "Operation with index {index} does not match the expression in the JSON expressions array." )) } } diff --git a/codegen/gce/src/expressions.rs b/codegen/gce/src/expressions.rs index b5027236..6d9ebf4f 100644 --- a/codegen/gce/src/expressions.rs +++ b/codegen/gce/src/expressions.rs @@ -15,108 +15,126 @@ use std::collections::BTreeMap; const MAIN_TRACE_SEGMENT_INDEX: u8 = 0; -pub struct ExpressionsHandler<'a> { - ir: &'a AirIR, - constants: &'a [u64], +pub struct GceBuilder { // maps indexes in Node vector in AlgebraicGraph and in `expressions` JSON array - expressions_map: &'a mut BTreeMap, + expressions_map: BTreeMap, + expressions: Vec, + outputs: Vec, } -impl<'a> ExpressionsHandler<'a> { - pub fn new( - ir: &'a AirIR, - constants: &'a [u64], - expressions_map: &'a mut BTreeMap, - ) -> Self { - ExpressionsHandler { - ir, - constants, - expressions_map, +impl GceBuilder { + pub fn new() -> Self { + GceBuilder { + expressions_map: BTreeMap::new(), + expressions: Vec::new(), + outputs: Vec::new(), } } + pub fn build( + &mut self, + ir: &AirIR, + constants: &[u64], + ) -> Result<(), ConstraintEvaluationError> { + self.build_expressions(ir, constants)?; + self.build_outputs(ir)?; + Ok(()) + } + + pub fn into_gce(self) -> Result<(Vec, Vec), ConstraintEvaluationError> { + Ok((self.expressions, self.outputs)) + } + /// Parses expressions in transition graph's Node vector, creates [Expression] instances and pushes /// them to the `expressions` vector. - pub fn get_expressions(&mut self) -> Result, ConstraintEvaluationError> { + fn build_expressions( + &mut self, + ir: &AirIR, + constants: &[u64], + ) -> Result<(), ConstraintEvaluationError> { // TODO: currently we can't create a node reference to the last row (which is required for // main.last and aux.last boundary constraints). 
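As a usage sketch, the builder is driven in three steps, exactly as the lib.rs hunk later in this patch consumes it: construct it, lower the IR, then take ownership of the resulting vectors.

    // Sketch of the intended call sequence (mirrors CodeGenerator::new below).
    let mut gce_builder = GceBuilder::new();
    gce_builder.build(ir, &constants)?;                    // walks the AlgebraicGraph
    let (expressions, outputs) = gce_builder.into_gce()?;  // consumes the builder

One apparent benefit of passing `ir` and `constants` into `build` instead of storing them is that `GceBuilder` no longer needs the `'a` lifetime of the old `ExpressionsHandler`, so `into_gce(self)` can simply move the vectors out.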
Working in assumption that first reference to // the column is .first constraint and second is .last constraint (in the boundary section, not // entire array) - let mut expressions = Vec::new(); - for (index, node) in self.ir.constraint_graph().nodes().iter().enumerate() { + for (index, node) in ir.constraint_graph().nodes().iter().enumerate() { match node.op() { Operation::Add(l, r) => { - expressions.push(self.handle_transition_expression( + self.expressions.push(self.handle_transition_expression( + ir, + constants, ExpressionOperation::Add, *l, *r, )?); // create mapping (index in node graph: index in expressions vector) - self.expressions_map.insert(index, expressions.len() - 1); + self.expressions_map + .insert(index, self.expressions.len() - 1); } Operation::Sub(l, r) => { - expressions.push(self.handle_transition_expression( + self.expressions.push(self.handle_transition_expression( + ir, + constants, ExpressionOperation::Sub, *l, *r, )?); - self.expressions_map.insert(index, expressions.len() - 1); + self.expressions_map + .insert(index, self.expressions.len() - 1); } Operation::Mul(l, r) => { - expressions.push(self.handle_transition_expression( + self.expressions.push(self.handle_transition_expression( + ir, + constants, ExpressionOperation::Mul, *l, *r, )?); - self.expressions_map.insert(index, expressions.len() - 1); + self.expressions_map + .insert(index, self.expressions.len() - 1); } Operation::Exp(i, degree) => { match degree { 0 => { // I decided that node^0 could be emulated using the product of 1*1, but perhaps there are better ways - let index_of_1 = get_constant_index_by_value(1, self.constants)?; + let index_of_1 = get_constant_index_by_value(1, constants)?; let const_1_node = NodeReference { node_type: NodeType::Const, index: index_of_1, }; - expressions.push(ExpressionJson { + self.expressions.push(ExpressionJson { op: ExpressionOperation::Mul, lhs: const_1_node.clone(), rhs: const_1_node, }); } 1 => { - let lhs = self.handle_node_reference(*i)?; - let degree_index = get_constant_index_by_value(1, self.constants)?; + let lhs = self.handle_node_reference(ir, constants, *i)?; + let degree_index = get_constant_index_by_value(1, constants)?; let rhs = NodeReference { node_type: NodeType::Const, index: degree_index, }; - expressions.push(ExpressionJson { + self.expressions.push(ExpressionJson { op: ExpressionOperation::Mul, lhs, rhs, }); } - _ => self.handle_exponentiation(&mut expressions, *i, *degree)?, + _ => self.handle_exponentiation(ir, constants, *i, *degree)?, } - self.expressions_map.insert(index, expressions.len() - 1); + self.expressions_map + .insert(index, self.expressions.len() - 1); } _ => {} } } - Ok(expressions) + Ok(()) } /// Fills the `outputs` vector with indexes from `expressions` vector according to the `expressions_map`. - pub fn get_outputs( - &self, - expressions: &mut Vec, - ) -> Result, ConstraintEvaluationError> { - let mut outputs = Vec::new(); - - for i in 0..self.ir.segment_widths().len() { - for root in self.ir.boundary_constraints(i as u8) { + fn build_outputs(&mut self, ir: &AirIR) -> Result<(), ConstraintEvaluationError> { + for i in 0..ir.segment_widths().len() { + for root in ir.boundary_constraints(i as u8) { let index = self .expressions_map .get(&root.node_index().index()) @@ -126,35 +144,35 @@ impl<'a> ExpressionsHandler<'a> { // if we found index twice, put the corresponding expression in the expressions // array again. It means that we have equal boundary constraints for both first // and last domains (e.g. 
a.first = 1 and a.last = 1) - if outputs.contains(index) { - expressions.push(expressions[*index].clone()); - outputs.push(expressions.len() - 1); + if self.outputs.contains(index) { + self.expressions.push(self.expressions[*index].clone()); + self.outputs.push(self.expressions.len() - 1); } else { - outputs.push(*index); + self.outputs.push(*index); } } - for root in self.ir.validity_constraints(i as u8) { + for root in ir.validity_constraints(i as u8) { let index = self .expressions_map .get(&root.node_index().index()) .ok_or_else(|| { ConstraintEvaluationError::operation_not_found(root.node_index().index()) })?; - outputs.push(*index); + self.outputs.push(*index); } - for root in self.ir.transition_constraints(i as u8) { + for root in ir.transition_constraints(i as u8) { let index = self .expressions_map .get(&root.node_index().index()) .ok_or_else(|| { ConstraintEvaluationError::operation_not_found(root.node_index().index()) })?; - outputs.push(*index); + self.outputs.push(*index); } } - Ok(outputs) + Ok(()) } // --- HELPERS -------------------------------------------------------------------------------- @@ -162,12 +180,14 @@ impl<'a> ExpressionsHandler<'a> { /// Parses expression in transition graph Node vector and returns related [Expression] instance. fn handle_transition_expression( &self, + ir: &AirIR, + constants: &[u64], op: ExpressionOperation, l: NodeIndex, r: NodeIndex, ) -> Result { - let lhs = self.handle_node_reference(l)?; - let rhs = self.handle_node_reference(r)?; + let lhs = self.handle_node_reference(ir, constants, l)?; + let rhs = self.handle_node_reference(ir, constants, r)?; Ok(ExpressionJson { op, lhs, rhs }) } @@ -175,10 +195,12 @@ impl<'a> ExpressionsHandler<'a> { /// [NodeReference] instance. fn handle_node_reference( &self, + ir: &AirIR, + constants: &[u64], i: NodeIndex, ) -> Result { use Operation::*; - match self.ir.constraint_graph().node(&i).op() { + match ir.constraint_graph().node(&i).op() { Add(_, _) | Sub(_, _) | Mul(_, _) | Exp(_, _) => { let index = self .expressions_map @@ -192,14 +214,14 @@ impl<'a> ExpressionsHandler<'a> { Constant(constant_value) => { match constant_value { ConstantValue::Inline(v) => { - let index = get_constant_index_by_value(*v, self.constants)?; + let index = get_constant_index_by_value(*v, constants)?; Ok(NodeReference { node_type: NodeType::Const, index, }) } ConstantValue::Scalar(name) => { - let index = get_constant_index_by_name(self.ir, name, self.constants)?; + let index = get_constant_index_by_name(ir, name, constants)?; Ok(NodeReference { node_type: NodeType::Const, index, @@ -208,22 +230,16 @@ impl<'a> ExpressionsHandler<'a> { ConstantValue::Vector(vector_access) => { // why Constant.name() returns Identifier and VectorAccess.name() works like // VectorAccess.name.name() and returns &str? 
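Setting the naming question aside, the point of this match is that every ConstantValue variant (Inline, Scalar, Vector, Matrix) collapses to a position in the flat constants array, so an expression in the emitted JSON never carries a constant value, only its index. A minimal sketch (hypothetical constants, and assuming the helper performs a plain position lookup as its name suggests):

    let constants: Vec<u64> = vec![0, 1, 42];
    let index = get_constant_index_by_value(42, &constants)?; // -> 2
    let node = NodeReference { node_type: NodeType::Const, index };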
(same with MatrixAccess) - let index = get_constant_index_by_vector_access( - self.ir, - vector_access, - self.constants, - )?; + let index = + get_constant_index_by_vector_access(ir, vector_access, constants)?; Ok(NodeReference { node_type: NodeType::Const, index, }) } ConstantValue::Matrix(matrix_access) => { - let index = get_constant_index_by_matrix_access( - self.ir, - matrix_access, - self.constants, - )?; + let index = + get_constant_index_by_matrix_access(ir, matrix_access, constants)?; Ok(NodeReference { node_type: NodeType::Const, index, @@ -248,8 +264,8 @@ impl<'a> ExpressionsHandler<'a> { }) } } - i if i < self.ir.segment_widths().len() as u8 => { - let col_index = self.ir.segment_widths()[0..i as usize].iter().sum::() + i if i < ir.segment_widths().len() as u8 => { + let col_index = ir.segment_widths()[0..i as usize].iter().sum::() as usize + trace_access.col_idx(); if trace_access.row_offset() == 0 { @@ -270,14 +286,14 @@ impl<'a> ExpressionsHandler<'a> { } } RandomValue(rand_index) => { - let index = get_random_value_index(self.ir, rand_index); + let index = get_random_value_index(ir, rand_index); Ok(NodeReference { node_type: NodeType::Var, index, }) } PublicInput(name, public_index) => { - let index = get_public_input_index(self.ir, name, public_index); + let index = get_public_input_index(ir, name, public_index); Ok(NodeReference { node_type: NodeType::Var, index, @@ -291,20 +307,21 @@ impl<'a> ExpressionsHandler<'a> { /// Replaces the exponentiation operation with multiplication operations, adding them to the /// expressions vector. fn handle_exponentiation( - &self, - expressions: &mut Vec, + &mut self, + ir: &AirIR, + constants: &[u64], i: NodeIndex, degree: usize, ) -> Result<(), ConstraintEvaluationError> { // base node that we want to raise to a degree - let base_node = self.handle_node_reference(i)?; + let base_node = self.handle_node_reference(ir, constants, i)?; // push node^2 expression - expressions.push(ExpressionJson { + self.expressions.push(ExpressionJson { op: ExpressionOperation::Mul, lhs: base_node.clone(), rhs: base_node.clone(), }); - let square_node_index = expressions.len() - 1; + let square_node_index = self.expressions.len() - 1; // square the previous expression while there is such an opportunity let mut cur_degree_of_2 = 1; // currently we have node^(2^cur_degree_of_2) = node^(2^1) = node^2 @@ -312,9 +329,9 @@ impl<'a> ExpressionsHandler<'a> { // the last node that we want to square let last_node = NodeReference { node_type: NodeType::Expr, - index: expressions.len() - 1, + index: self.expressions.len() - 1, }; - expressions.push(ExpressionJson { + self.expressions.push(ExpressionJson { op: ExpressionOperation::Mul, lhs: last_node.clone(), rhs: last_node, @@ -330,9 +347,9 @@ impl<'a> ExpressionsHandler<'a> { // if we need to add first degree (base node) let last_node = NodeReference { node_type: NodeType::Expr, - index: expressions.len() - 1, + index: self.expressions.len() - 1, }; - expressions.push(ExpressionJson { + self.expressions.push(ExpressionJson { op: ExpressionOperation::Mul, lhs: last_node, rhs: base_node, @@ -342,14 +359,14 @@ impl<'a> ExpressionsHandler<'a> { if 2_usize.pow(cur_degree_of_2 - 1) <= diff { let last_node = NodeReference { node_type: NodeType::Expr, - index: expressions.len() - 1, + index: self.expressions.len() - 1, }; let fitting_degree_of_2_node = NodeReference { node_type: NodeType::Expr, // cur_degree_of_2 shows how many indexes we need to add to reach the largest fitting degree of 2 index: square_node_index + 
cur_degree_of_2 as usize - 2, }; - expressions.push(ExpressionJson { + self.expressions.push(ExpressionJson { op: ExpressionOperation::Mul, lhs: last_node, rhs: fitting_degree_of_2_node, diff --git a/codegen/gce/src/lib.rs b/codegen/gce/src/lib.rs index bc97f306..e3a6b826 100644 --- a/codegen/gce/src/lib.rs +++ b/codegen/gce/src/lib.rs @@ -1,13 +1,14 @@ -use ir::{ - constraints::{ConstantValue, Operation}, - AirIR, -}; - pub use air_script_core::{ Constant, ConstantType, Expression, Identifier, IndexedTraceAccess, MatrixAccess, NamedTraceAccess, TraceSegment, Variable, VariableType, VectorAccess, }; +use ir::{ + constraints::{ConstantValue, Operation}, + AirIR, +}; use std::fmt::Display; +use std::fs::File; +use std::io::Write; mod error; use error::ConstraintEvaluationError; @@ -15,13 +16,10 @@ use error::ConstraintEvaluationError; mod utils; mod expressions; -use expressions::ExpressionsHandler; - -use std::collections::BTreeMap; -use std::fs::File; -use std::io::Write; +use expressions::GceBuilder; -/// Holds data for JSON generation +/// CodeGenerator is used to generate a JSON file with generic constraint evaluation. The generated +/// file contains the data used for GPU acceleration. #[derive(Default, Debug)] pub struct CodeGenerator { num_polys: u16, @@ -33,18 +31,13 @@ pub struct CodeGenerator { impl CodeGenerator { pub fn new(ir: &AirIR, extension_degree: u8) -> Result { - // maps indexes in Node vector in AlgebraicGraph and in `expressions` JSON array - let mut expressions_map = BTreeMap::new(); - let num_polys = set_num_polys(ir, extension_degree); let num_variables = set_num_variables(ir); let constants = set_constants(ir); - let mut expressions_handler = ExpressionsHandler::new(ir, &constants, &mut expressions_map); - - let mut expressions = expressions_handler.get_expressions()?; - // vector of `expressions` indexes - let outputs = expressions_handler.get_outputs(&mut expressions)?; + let mut gce_builder = GceBuilder::new(); + gce_builder.build(ir, &constants)?; + let (expressions, outputs) = gce_builder.into_gce()?; Ok(CodeGenerator { num_polys, @@ -173,13 +166,18 @@ fn set_constants(ir: &AirIR) -> Vec { constants } -/// Stroes node type required in [NodeReference] struct +/// Stores the node type required by the [NodeReference] struct. #[derive(Debug, Clone)] pub enum NodeType { + // Refers to the value in the trace column at the specified `index` in the current row. Pol, + // Refers to the value in the trace column at the specified `index` in the next row. PolNext, + // Refers to a public input or a random value at the specified `index`. Var, + // Refers to a constant at the specified `index`. Const, + // Refers to a previously defined expression at the specified index. Expr, } @@ -212,7 +210,8 @@ impl Display for ExpressionOperation { } } -/// Stores data used in JSON generation +/// Stores the reference to the node using the type of the node and index in related array of +/// nodes. #[derive(Debug, Clone)] pub struct NodeReference { pub node_type: NodeType, @@ -229,7 +228,8 @@ impl Display for NodeReference { } } -/// Stores data used in JSON generation +/// Stores the expression node using the expression operation and references to the left and rigth +/// nodes. #[derive(Clone, Debug)] pub struct ExpressionJson { pub op: ExpressionOperation,
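// A worked example of the exponentiation lowering above (hypothetical node `x` raised to
// the 7th power; indices are relative, not taken from a real AIR). As far as I can tell
// from handle_exponentiation, the emitted multiplication chain is roughly:
//     e0 = x  * x      -> x^2   (initial squaring, square_node_index points here)
//     e1 = e0 * e0     -> x^4   (keep squaring while the next power of two still fits)
//     e2 = e1 * x      -> x^5   (fold the base back in for the odd remainder)
//     e3 = e2 * e0     -> x^7   (fold in the largest power of two that fits what is left)
// A single Exp node therefore expands into several Mul entries in the expressions array,
// and only the final one (e3 here) is recorded in expressions_map for the constraint that
// used the Exp node.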