Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/dsl #60

Merged
merged 12 commits into from
Jul 1, 2024
Prev Previous commit
Next Next commit
add docs and tests
  • Loading branch information
lonerapier committed May 15, 2024
commit c6c7bbdd0c3596e2f09785af965218ed24930a22
28 changes: 28 additions & 0 deletions src/compiler/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,30 @@
//! Contains a simple DSL for writing circuits that can compile to plonk's polynomials that are used
//! in PLONK.
//!
//! ## DSL
//!
//! let's take an example to compute `x^3 + x + 5`:
//! ### Example
//! ```ignore
//! x public
//! x2 <== x * x
//! out <== x2 * x + 5
//! ```
//!
//! Each line of DSL is a separate constraint. It's parsed and converted to corresponding
//! `WireValues`, i.e. variables and coefficients. Vanilla PLONK supports fan-in 2 arithmetic (add
//! and mul) gates, so each constraint can only support a maximum of 1 output and 2 input variables.
//!
//! Note: Read [`parser`] for DSL rules.
//!
//! ## Program
//!
//! Converts `WireValues` to required polynomials in PLONK, i.e.
//! - Preprocessed polynomials:
//! - selector polynomials: `[QM,QR,QM,QO,QC]`
//! - permutation helpers: `[S1,S2,S3]`
//! - public inputs
//! - witness: `[a,b,c]`
pub mod parser;
pub mod program;
mod utils;
181 changes: 87 additions & 94 deletions src/compiler/parser.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
//! Parses a simple DSL used to define circuits.

//! ## Rules:
//! - supports `<==` for assignment and `===` for arithmetic equality checks
//! - mark variables as public in the beginning.
//! - only supports quadratic constraints, i.e. don't support `y <== x * x * x`.
//! - each token should be separated by space.
//!
//! Outputs parsed output in form of [`WireCoeffs`] values and coefficients.
//!
//! ## Example
//! - `a public` => `(['a', None, None], {'$public': 1, 'a': -1,
//! '$output_coeffs': 0}`
//! - `b <== a * c` => `(['a', 'c', 'b'], {'a*c': 1})`
//! - `d <== a * c - 45 * a + 987` => `(['a', 'c', 'd'], {'a*c': 1, 'a': -45, '': 987})`
// TODOs:
// - incorrect use of &str and String
// - use iterators more
Expand All @@ -11,100 +23,91 @@ use std::{
iter,
};

use super::utils::{get_product_key, is_valid_var_name};
use crate::field::{gf_101::GF101, FiniteField};

/// Gate represents each new constraint in the computation
/// Fan-in 2 Gate representing a constraint in the computation.
/// Each constraint satisfies PLONK's arithmetic equation: `a(X)QL(X) + b(X)QR(X) + a(X)b(X)QM(X) +
/// o(X)QO(X) + QC(X) = 0`.
pub struct Gate {
/// left wire value
pub l: GF101,
/// right wire value
pub r: GF101,
/// output wire
/// output wire, represented as `$output_coeffs` in wire coefficients
pub o: GF101,
/// multiplication wire
pub m: GF101,
/// constant wire
/// constant wire, represented as `$constant` in coefficients
pub c: GF101,
}

/// Values of wires with coefficients of each wire name
#[derive(Debug, PartialEq)]
pub struct WireValues<'a> {
pub struct WireCoeffs<'a> {
/// variable used in each wire
pub wires: Vec<Option<&'a str>>,
/// coefficients of variables in wires
/// coefficients of variables in wires and [`Gate`]
pub coeffs: HashMap<String, i32>,
}

impl<'a> WireValues<'a> {
impl<'a> WireCoeffs<'a> {
fn l(&self) -> GF101 {
match self.coeffs.get(self.wires[0].unwrap()) {
Some(val) => -GF101::from(*val),
match self.wires[0] {
Some(wire) => match self.coeffs.get(wire) {
Some(val) => -GF101::from(*val),
None => GF101::ZERO,
},
None => GF101::ZERO,
}
}

fn r(&self) -> GF101 {
match self.coeffs.get(self.wires[1].unwrap()) {
Some(val) => -GF101::from(*val),
None => GF101::ZERO,
if self.wires[0].is_some() && self.wires[1].is_some() && self.wires[0] != self.wires[1] {
match self.coeffs.get(self.wires[1].unwrap()) {
Some(val) => -GF101::from(*val),
None => GF101::ZERO,
}
} else {
GF101::ZERO
}
}

fn o(&self) -> GF101 {
match self.coeffs.get("$output_coeffs") {
Some(val) => GF101::from(*val),
None => GF101::ZERO,
None => GF101::ONE,
}
}

fn c(&self) -> GF101 {
match self.coeffs.get("") {
match self.coeffs.get("$constant") {
Some(val) => -GF101::from(*val),
None => GF101::ZERO,
}
}

fn m(&self) -> GF101 {
match self.coeffs.get(&get_product_key(self.wires[0].unwrap(), self.wires[1].unwrap())) {
Some(val) => -GF101::from(*val),
None => GF101::ZERO,
match (self.wires[0], self.wires[1]) {
(Some(a), Some(b)) => match self.coeffs.get(&get_product_key(a, b)) {
Some(val) => -GF101::from(*val),
None => GF101::ZERO,
},
_ => GF101::ZERO,
}
}

/// sends
/// sends gate activation coefficients from each wires.
pub fn gate(&self) -> Gate {
// first two variables shouldn't be none for a gate
assert!(self.wires[0].is_some());
assert!(self.wires[1].is_some());

Gate { l: self.l(), r: self.r(), o: self.o(), m: self.m(), c: self.c() }
}
}

/// returns product key required for coefficient mapping in plonk's multiplication gate variable.
/// split `a` and `b` by `*`, sort and join by `*`.
pub fn get_product_key(a: &str, b: &str) -> String {
// TODO: might be a better alternative here
if b.is_empty() {
return a.to_string();
}
if a.is_empty() {
return b.to_string();
}

let mut a_star: Vec<&str> = a.split('*').collect();
a_star.append(&mut b.split('*').collect());

a_star.sort();
a_star.join("*")
}

/// Converts an arithmetic expression containing numbers, variables and {+, -, *}
/// Converts an arithmetic expression containing numbers, variables and `{+, -, *}`
/// into a mapping of term to coefficient
///
/// For example:
/// ['a', '+', 'b', '*', 'c', '*', '5'] becomes {'a': 1, 'b*c': 5}
/// `['a', '+', 'b', '*', 'c', '*', '5']` becomes `{'a': 1, 'b*c': 5}`
///
/// Note that this is a recursive algo, so the input can be a mix of tokens and
/// mapping expressions
Expand Down Expand Up @@ -172,7 +175,7 @@ fn evaluate(exprs: &[&str], first_is_neg: bool) -> HashMap<String, i32> {
} else if exprs[0].trim().parse::<i32>().is_ok() {
let num = exprs[0].trim().parse::<i32>().unwrap_or(0);
let sign = if first_is_neg { -1 } else { 1 };
return HashMap::from([("".to_string(), num * sign)]);
return HashMap::from([("$constant".to_string(), num * sign)]);
} else if is_valid_var_name(exprs[0]) {
let sign = if first_is_neg { -1 } else { 1 };
return HashMap::from([(exprs[0].to_string(), sign)]);
Expand All @@ -181,30 +184,22 @@ fn evaluate(exprs: &[&str], first_is_neg: bool) -> HashMap<String, i32> {
}
}

/// Checks whether a variable name is valid.
/// - len > 0
/// - chars are alphanumeric
/// - 1st element is not a number
fn is_valid_var_name(name: &str) -> bool {
!name.is_empty()
&& name.chars().all(char::is_alphanumeric)
&& !(48u8..=57u8).contains(&name.as_bytes()[0])
}

/// Parse constraints into [`WireValues`] containing wires and corresponding coefficients.
/// Parse constraints into [`WireCoeffs`] containing wires and corresponding coefficients.
///
/// ## Example
///
/// valid equations, and output:
/// - `a === 9` => `([None, None, 'a'], {'': 9})`
/// - `a public` => `(['a', None, None], {'$public': 1, 'a': -1,
/// '$output_coeffs': 0}`
/// - `b <== a * c` => `(['a', 'c', 'b'], {'a*c': 1})`
/// - `d <== a * c - 45 * a + 987` => `(['a', 'c', 'd'], {'a*c': 1, 'a': -45, '': 987})`
///
/// invalid equations:
/// - `7 === 7` => # Can't assign to non-variable
/// - `a <== b * * c` => # Two times signs in a row
/// - `e <== a + b * c * d` => # Multiplicative degree > 2
pub fn parse_constraints(constraint: &str) -> WireValues {
pub fn parse_constraints(constraint: &str) -> WireCoeffs {
let tokens: Vec<&str> = constraint.trim().trim_end_matches('\n').split(' ').collect();
if tokens[1] == "<==" || tokens[1] == "===" {
let mut out = tokens[0];
Expand All @@ -226,7 +221,7 @@ pub fn parse_constraints(constraint: &str) -> WireValues {

let mut allowed_coeffs_set: HashSet<String> =
HashSet::from_iter(variables.iter().map(|var| var.to_string()));
allowed_coeffs_set.extend(["$output_coeffs".to_string(), "".to_string()]);
allowed_coeffs_set.extend(["$output_coeffs".to_string(), "$constant".to_string()]);

match variables.len() {
0 => {},
Expand All @@ -252,53 +247,32 @@ pub fn parse_constraints(constraint: &str) -> WireValues {
variables.into_iter().map(Some).chain(iter::repeat(None).take(2 - variables_len)).collect();
wires.push(Some(out));

WireValues { wires, coeffs }
WireCoeffs { wires, coeffs }
} else if tokens[1] == "public" {
let coeffs = HashMap::from([
(tokens[0].to_string(), -1),
(String::from("$output_coeffs"), 0),
(String::from("$output"), 1),
(String::from("$public"), 1),
]);

return WireValues { wires: vec![Some(tokens[0]), None, None], coeffs };
return WireCoeffs { wires: vec![Some(tokens[0]), None, None], coeffs };
} else {
panic!("unsupported value: {}", constraint);
}
}

#[cfg(test)]
mod tests {
use rstest::rstest;

use super::*;

#[rstest]
#[case("a", "b", "a*b")]
#[case("a*b", "c", "a*b*c")]
#[case("a*c", "d*b", "a*b*c*d")]
#[case("", "", "")]
#[case("", "a", "a")]
#[case("a", "", "a")]
fn product_key(#[case] a: &str, #[case] b: &str, #[case] expected: &str) {
assert_eq!(get_product_key(a, b), expected);
}

#[test]
fn valid_var_name() {
assert!(is_valid_var_name("a"));
assert!(!is_valid_var_name(""));
assert!(is_valid_var_name("abcd"));
assert!(!is_valid_var_name("1"));
}

#[test]
fn wire_values() {
let wire_values = WireValues {
let wire_values = WireCoeffs {
wires: vec![Some("a"), Some("b"), Some("c")],
coeffs: HashMap::from([
(String::from("$output_coeffs"), 2),
(String::from("a"), -1),
(String::from(""), 9),
(String::from("$constant"), 9),
]),
};
let gate = wire_values.gate();
Expand All @@ -308,16 +282,31 @@ mod tests {
assert_eq!(gate.o, GF101::from(2));
assert_eq!(gate.c, -GF101::from(9));

let wire_values = WireValues {
let wire_values = WireCoeffs {
wires: vec![Some("a"), Some("b"), Some("c")],
coeffs: HashMap::from([(String::from("b"), -1), (String::from("a*b"), -9)]),
};
let gate = wire_values.gate();
assert_eq!(gate.l, -GF101::ZERO);
assert_eq!(gate.r, -GF101::from(-1));
assert_eq!(gate.m, -GF101::from(-9));
assert_eq!(gate.o, GF101::ZERO);
assert_eq!(gate.o, GF101::ONE);
assert_eq!(gate.c, -GF101::ZERO);

let wire_values = WireCoeffs {
wires: vec![Some("a"), None, None],
coeffs: HashMap::from([
(String::from("$output"), 1),
(String::from("a"), -1),
(String::from("$output_coeffs"), 0),
]),
};
let gate = wire_values.gate();
assert_eq!(gate.l, GF101::ONE);
assert_eq!(gate.r, GF101::ZERO);
assert_eq!(gate.m, GF101::ZERO);
assert_eq!(gate.o, GF101::ZERO);
assert_eq!(gate.c, GF101::ZERO);
}

#[test]
Expand Down Expand Up @@ -347,46 +336,50 @@ mod tests {

let expr = ["-10", "+", "c", "*", "-8", "-", "11"];
let res = evaluate(&expr, false);
assert_eq!(res, HashMap::from([("c".to_string(), -8), ("".to_string(), -21)]));
assert_eq!(res, HashMap::from([("c".to_string(), -8), ("$constant".to_string(), -21)]));

let expr = ["-2", "*", "b", "-", "a", "*", "b"];
let res = evaluate(&expr, false);
assert_eq!(res, HashMap::from([("a*b".to_string(), -1), ("b".to_string(), -2)]));
}

#[test]
fn circuit_parse_constraints() {
let wire_values = parse_constraints("a <== b * c");
assert_eq!(wire_values, WireValues {
assert_eq!(wire_values, WireCoeffs {
wires: vec![Some("b"), Some("c"), Some("a")],
coeffs: HashMap::from([(String::from("b*c"), 1)]),
});

let wire_values = parse_constraints("a public");
assert_eq!(wire_values, WireValues {
assert_eq!(wire_values, WireCoeffs {
wires: vec![Some("a"), None, None],
coeffs: HashMap::from([
(String::from("$output_coeffs"), 0),
(String::from("$output"), 1),
(String::from("$public"), 1),
(String::from("a"), -1)
]),
});

let wire_values = parse_constraints("a === 9");
assert_eq!(wire_values, WireValues {
assert_eq!(wire_values, WireCoeffs {
wires: vec![None, None, Some("a")],
coeffs: HashMap::from([(String::from(""), 9)]),
coeffs: HashMap::from([(String::from("$constant"), 9)]),
});

let wire_values = parse_constraints("b <== a + 9 * 10");
assert_eq!(wire_values, WireValues {
assert_eq!(wire_values, WireCoeffs {
wires: vec![Some("a"), Some("a"), Some("b")],
coeffs: HashMap::from([(String::from("a"), 1), (String::from(""), 90)]),
coeffs: HashMap::from([(String::from("a"), 1), (String::from("$constant"), 90)]),
});

let wire_values = parse_constraints("-a <== b * -c * -9 - 10");
assert_eq!(wire_values, WireValues {
assert_eq!(wire_values, WireCoeffs {
wires: vec![Some("b"), Some("c"), Some("a")],
coeffs: HashMap::from([
(String::from("$output_coeffs"), -1),
(String::from("b*c"), 9),
(String::from(""), -10)
(String::from("$constant"), -10)
]),
});
}
Expand Down
Loading