Skip to content

Commit

Permalink
copied over test_shape_cs from nova
Browse files Browse the repository at this point in the history
  • Loading branch information
arasuarun committed Sep 29, 2023
1 parent bdda386 commit aca9967
Show file tree
Hide file tree
Showing 3 changed files with 329 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/bellpepper/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
pub mod r1cs;
pub mod shape_cs;
pub mod solver;
pub mod test_shape_cs;

#[cfg(test)]
mod tests {
Expand Down
3 changes: 2 additions & 1 deletion src/bellpepper/r1cs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#![allow(non_snake_case)]

use super::{shape_cs::ShapeCS, solver::SatisfyingAssignment};
use super::{shape_cs::ShapeCS, solver::SatisfyingAssignment, test_shape_cs::TestShapeCS};
use crate::{
errors::SpartanError,
r1cs::{R1CSInstance, R1CSShape, R1CSWitness, R1CS},
Expand Down Expand Up @@ -93,6 +93,7 @@ macro_rules! impl_spartan_shape {
}

impl_spartan_shape!(ShapeCS);
impl_spartan_shape!(TestShapeCS);

fn add_constraint<S: PrimeField>(
X: &mut (
Expand Down
326 changes: 326 additions & 0 deletions src/bellpepper/test_shape_cs.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,326 @@
//! Support for generating R1CS shape using bellpepper.
//! `TestShapeCS` implements a superset of `ShapeCS`, adding non-trivial namespace support for use in testing.
use std::{
cmp::Ordering,
collections::{BTreeMap, HashMap},
};

use crate::traits::Group;
use bellpepper_core::{ConstraintSystem, Index, LinearCombination, SynthesisError, Variable};
use core::fmt::Write;
use ff::{Field, PrimeField};

#[derive(Clone, Copy)]
struct OrderedVariable(Variable);

#[derive(Debug)]
enum NamedObject {
Constraint(usize),
Var(Variable),
Namespace,
}

impl Eq for OrderedVariable {}
impl PartialEq for OrderedVariable {
fn eq(&self, other: &OrderedVariable) -> bool {
match (self.0.get_unchecked(), other.0.get_unchecked()) {
(Index::Input(ref a), Index::Input(ref b)) | (Index::Aux(ref a), Index::Aux(ref b)) => a == b,
_ => false,
}
}
}
impl PartialOrd for OrderedVariable {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for OrderedVariable {
fn cmp(&self, other: &Self) -> Ordering {
match (self.0.get_unchecked(), other.0.get_unchecked()) {
(Index::Input(ref a), Index::Input(ref b)) | (Index::Aux(ref a), Index::Aux(ref b)) => {
a.cmp(b)
}
(Index::Input(_), Index::Aux(_)) => Ordering::Less,
(Index::Aux(_), Index::Input(_)) => Ordering::Greater,
}
}
}

#[allow(clippy::upper_case_acronyms)]
/// `TestShapeCS` is a `ConstraintSystem` for creating `R1CSShape`s for a circuit.
pub struct TestShapeCS<G: Group>
where
G::Scalar: PrimeField + Field,
{
named_objects: HashMap<String, NamedObject>,
current_namespace: Vec<String>,
#[allow(clippy::type_complexity)]
/// All constraints added to the `TestShapeCS`.
pub constraints: Vec<(
LinearCombination<G::Scalar>,
LinearCombination<G::Scalar>,
LinearCombination<G::Scalar>,
String,
)>,
inputs: Vec<String>,
aux: Vec<String>,
}

fn proc_lc<Scalar: PrimeField>(
terms: &LinearCombination<Scalar>,
) -> BTreeMap<OrderedVariable, Scalar> {
let mut map = BTreeMap::new();
for (var, &coeff) in terms.iter() {
map
.entry(OrderedVariable(var))
.or_insert_with(|| Scalar::ZERO)
.add_assign(&coeff);
}

// Remove terms that have a zero coefficient to normalize
let mut to_remove = vec![];
for (var, coeff) in map.iter() {
if coeff.is_zero().into() {
to_remove.push(*var)
}
}

for var in to_remove {
map.remove(&var);
}

map
}

impl<G: Group> TestShapeCS<G>
where
G::Scalar: PrimeField,
{
#[allow(unused)]
/// Create a new, default `TestShapeCS`,
pub fn new() -> Self {
TestShapeCS::default()
}

/// Returns the number of constraints defined for this `TestShapeCS`.
pub fn num_constraints(&self) -> usize {
self.constraints.len()
}

/// Returns the number of inputs defined for this `TestShapeCS`.
pub fn num_inputs(&self) -> usize {
self.inputs.len()
}

/// Returns the number of aux inputs defined for this `TestShapeCS`.
pub fn num_aux(&self) -> usize {
self.aux.len()
}

/// Print all public inputs, aux inputs, and constraint names.
#[allow(dead_code)]
pub fn pretty_print_list(&self) -> Vec<String> {
let mut result = Vec::new();

for input in &self.inputs {
result.push(format!("INPUT {input}"));
}
for aux in &self.aux {
result.push(format!("AUX {aux}"));
}

for (_a, _b, _c, name) in &self.constraints {
result.push(name.to_string());
}

result
}

/// Print all iputs and a detailed representation of each constraint.
#[allow(dead_code)]
pub fn pretty_print(&self) -> String {
let mut s = String::new();

for input in &self.inputs {
writeln!(s, "INPUT {}", &input).unwrap()
}

let negone = -<G::Scalar>::ONE;

let powers_of_two = (0..G::Scalar::NUM_BITS)
.map(|i| G::Scalar::from(2u64).pow_vartime([u64::from(i)]))
.collect::<Vec<_>>();

let pp = |s: &mut String, lc: &LinearCombination<G::Scalar>| {
s.push('(');
let mut is_first = true;
for (var, coeff) in proc_lc::<G::Scalar>(lc) {
if coeff == negone {
s.push_str(" - ")
} else if !is_first {
s.push_str(" + ")
}
is_first = false;

if coeff != <G::Scalar>::ONE && coeff != negone {
for (i, x) in powers_of_two.iter().enumerate() {
if x == &coeff {
write!(s, "2^{i} . ").unwrap();
break;
}
}

write!(s, "{coeff:?} . ").unwrap()
}

match var.0.get_unchecked() {
Index::Input(i) => {
write!(s, "`I{}`", &self.inputs[i]).unwrap();
}
Index::Aux(i) => {
write!(s, "`A{}`", &self.aux[i]).unwrap();
}
}
}
if is_first {
// Nothing was visited, print 0.
s.push('0');
}
s.push(')');
};

for (a, b, c, name) in &self.constraints {
s.push('\n');

write!(s, "{name}: ").unwrap();
pp(&mut s, a);
write!(s, " * ").unwrap();
pp(&mut s, b);
s.push_str(" = ");
pp(&mut s, c);
}

s.push('\n');

s
}

/// Associate `NamedObject` with `path`.
/// `path` must not already have an associated object.
fn set_named_obj(&mut self, path: String, to: NamedObject) {
assert!(
!self.named_objects.contains_key(&path),
"tried to create object at existing path: {path}"
);

self.named_objects.insert(path, to);
}
}

impl<G: Group> Default for TestShapeCS<G>
where
G::Scalar: PrimeField,
{
fn default() -> Self {
let mut map = HashMap::new();
map.insert("ONE".into(), NamedObject::Var(TestShapeCS::<G>::one()));
TestShapeCS {
named_objects: map,
current_namespace: vec![],
constraints: vec![],
inputs: vec![String::from("ONE")],
aux: vec![],
}
}
}

impl<G: Group> ConstraintSystem<G::Scalar> for TestShapeCS<G>
where
G::Scalar: PrimeField,
{
type Root = Self;

fn alloc<F, A, AR>(&mut self, annotation: A, _f: F) -> Result<Variable, SynthesisError>
where
F: FnOnce() -> Result<G::Scalar, SynthesisError>,
A: FnOnce() -> AR,
AR: Into<String>,
{
let path = compute_path(&self.current_namespace, &annotation().into());
self.aux.push(path);

Ok(Variable::new_unchecked(Index::Aux(self.aux.len() - 1)))
}

fn alloc_input<F, A, AR>(&mut self, annotation: A, _f: F) -> Result<Variable, SynthesisError>
where
F: FnOnce() -> Result<G::Scalar, SynthesisError>,
A: FnOnce() -> AR,
AR: Into<String>,
{
let path = compute_path(&self.current_namespace, &annotation().into());
self.inputs.push(path);

Ok(Variable::new_unchecked(Index::Input(self.inputs.len() - 1)))
}

fn enforce<A, AR, LA, LB, LC>(&mut self, annotation: A, a: LA, b: LB, c: LC)
where
A: FnOnce() -> AR,
AR: Into<String>,
LA: FnOnce(LinearCombination<G::Scalar>) -> LinearCombination<G::Scalar>,
LB: FnOnce(LinearCombination<G::Scalar>) -> LinearCombination<G::Scalar>,
LC: FnOnce(LinearCombination<G::Scalar>) -> LinearCombination<G::Scalar>,
{
let path = compute_path(&self.current_namespace, &annotation().into());
let index = self.constraints.len();
self.set_named_obj(path.clone(), NamedObject::Constraint(index));

let a = a(LinearCombination::zero());
let b = b(LinearCombination::zero());
let c = c(LinearCombination::zero());

self.constraints.push((a, b, c, path));
}

fn push_namespace<NR, N>(&mut self, name_fn: N)
where
NR: Into<String>,
N: FnOnce() -> NR,
{
let name = name_fn().into();
let path = compute_path(&self.current_namespace, &name);
self.set_named_obj(path, NamedObject::Namespace);
self.current_namespace.push(name);
}

fn pop_namespace(&mut self) {
assert!(self.current_namespace.pop().is_some());
}

fn get_root(&mut self) -> &mut Self::Root {
self
}
}

fn compute_path(ns: &[String], this: &str) -> String {
assert!(
!this.chars().any(|a| a == '/'),
"'/' is not allowed in names"
);

let mut name = String::new();

let mut needs_separation = false;
for ns in ns.iter().chain(Some(this.to_string()).iter()) {
if needs_separation {
name += "/";
}

name += ns;
needs_separation = true;
}

name
}

0 comments on commit aca9967

Please sign in to comment.