Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/batching circuit gadgets #396

Merged
2 changes: 2 additions & 0 deletions mp2-common/src/group_hashing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ impl ToTargets for QuinticExtensionTarget {
}

impl FromTargets for CurveTarget {
const NUM_TARGETS: usize = CURVE_TARGET_LEN;

fn from_targets(t: &[Target]) -> Self {
nicholas-mainardi marked this conversation as resolved.
Show resolved Hide resolved
assert!(t.len() >= CURVE_TARGET_LEN);
let x = QuinticExtensionTarget(t[0..EXTENSION_DEGREE].try_into().unwrap());
Expand Down
2 changes: 2 additions & 0 deletions mp2-common/src/keccak.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ pub type OutputHash = Array<U32Target, PACKED_HASH_LEN>;
pub type OutputByteHash = Array<Target, HASH_LEN>;

impl FromTargets for OutputHash {
const NUM_TARGETS: usize = PACKED_HASH_LEN;

fn from_targets(t: &[Target]) -> Self {
OutputHash::from_array(array::from_fn(|i| U32Target(t[i])))
}
Expand Down
1 change: 1 addition & 0 deletions mp2-common/src/u256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,7 @@ impl ToTargets for UInt256Target {
}

impl FromTargets for UInt256Target {
const NUM_TARGETS: usize = NUM_LIMBS;
// Expects big endian limbs as the standard format for IO
fn from_targets(t: &[Target]) -> Self {
Self::new_from_be_target_limbs(&t[..NUM_LIMBS]).unwrap()
Expand Down
51 changes: 48 additions & 3 deletions mp2-common/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use anyhow::{anyhow, Result};
use itertools::Itertools;
use plonky2::field::extension::Extendable;
use plonky2::field::goldilocks_field::GoldilocksField;
use plonky2::hash::hash_types::{HashOut, HashOutTarget, RichField};
use plonky2::hash::hash_types::{HashOut, HashOutTarget, RichField, NUM_HASH_OUT_ELTS};
use plonky2::iop::target::{BoolTarget, Target};
use plonky2::iop::witness::{PartialWitness, WitnessWrite};
use plonky2::plonk::circuit_builder::CircuitBuilder;
Expand All @@ -19,12 +19,26 @@ use sha3::Keccak256;

use crate::array::Targetable;
use crate::poseidon::{HashableField, H};
use crate::serialization::circuit_data_serialization::SerializableRichField;
use crate::{group_hashing::EXTENSION_DEGREE, types::HashOutput, ProofTuple};

const TWO_POWER_8: usize = 256;
const TWO_POWER_16: usize = 65536;
const TWO_POWER_24: usize = 16777216;

/// Check that the closure `$f` actually panics, printing `$msg` as the error
/// message if it did not. This macro is employed in tests in place of
/// `#[should_panic]` to ensure that the panic occurred in the expected
/// function rather than in other parts of the test.
#[macro_export]
macro_rules! check_panic {
    ($f: expr, $msg: expr) => {{
        // `AssertUnwindSafe` lets us catch the unwind without requiring the
        // closure to be `UnwindSafe`; we only inspect whether it panicked.
        let outcome = ::std::panic::catch_unwind(::std::panic::AssertUnwindSafe($f));
        assert!(outcome.is_err(), $msg);
    }};
}

pub use check_panic;

pub fn verify_proof_tuple<
F: RichField + Extendable<D>,
C: GenericConfig<D, F = F>,
Expand Down Expand Up @@ -307,17 +321,20 @@ pub fn pack_and_compute_poseidon_target<F: HashableField + Extendable<D>, const
b.hash_n_to_hash_no_pad::<H>(packed)
}

pub trait SelectHashBuilder {
/// Extension trait adding in-circuit gadgets over `HashOutTarget` values
/// (selection and equality) to a circuit builder.
pub trait HashBuilder {
    /// Select `first_hash` or `second_hash` as output depending on the Boolean `cond`
    fn select_hash(
        &mut self,
        cond: BoolTarget,
        first_hash: &HashOutTarget,
        second_hash: &HashOutTarget,
    ) -> HashOutTarget;

    /// Determine whether `first_hash == second_hash`
    fn hash_eq(&mut self, first_hash: &HashOutTarget, second_hash: &HashOutTarget) -> BoolTarget;
}

impl<F: RichField + Extendable<D>, const D: usize> SelectHashBuilder for CircuitBuilder<F, D> {
impl<F: RichField + Extendable<D>, const D: usize> HashBuilder for CircuitBuilder<F, D> {
fn select_hash(
&mut self,
cond: BoolTarget,
Expand All @@ -333,6 +350,28 @@ impl<F: RichField + Extendable<D>, const D: usize> SelectHashBuilder for Circuit
.collect_vec(),
)
}

/// Determine whether `first_hash == second_hash` by comparing the two hashes
/// element-wise and AND-ing all the per-element equality bits together.
fn hash_eq(&mut self, first_hash: &HashOutTarget, second_hash: &HashOutTarget) -> BoolTarget {
    // Start from `true`: the conjunction over zero comparisons is vacuously true.
    let mut all_equal = self._true();
    for (first, second) in first_hash.elements.iter().zip(second_hash.elements.iter()) {
        let elements_equal = self.is_equal(*first, *second);
        all_equal = self.and(all_equal, elements_equal);
    }
    all_equal
}
}

/// Trait for target types that support an in-circuit conditional selection
/// between two instances of the type.
pub trait SelectTarget {
    /// Return `first` if `cond` is true, `second` otherwise
    fn select<F: SerializableRichField<D>, const D: usize>(
        b: &mut CircuitBuilder<F, D>,
        cond: &BoolTarget,
        first: &Self,
        second: &Self,
    ) -> Self;
}

pub trait ToFields<F: RichField> {
Expand Down Expand Up @@ -395,10 +434,16 @@ impl<F: RichField> Fieldable<F> for u64 {
}

pub trait FromTargets {
/// Number of targets necessary to instantiate `Self`
const NUM_TARGETS: usize;
nicholas-mainardi marked this conversation as resolved.
Show resolved Hide resolved

/// Number of targets in `t` must be at least `Self::NUM_TARGETS`
fn from_targets(t: &[Target]) -> Self;
}

impl FromTargets for HashOutTarget {
const NUM_TARGETS: usize = NUM_HASH_OUT_ELTS;

fn from_targets(t: &[Target]) -> Self {
HashOutTarget {
elements: create_array(|i| t[i]),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use mp2_common::{
serialization::{deserialize, deserialize_array, serialize, serialize_array},
types::CBuilder,
u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256},
utils::{SelectHashBuilder, ToTargets},
utils::{HashBuilder, ToTargets},
D, F,
};
use plonky2::{
Expand Down
75 changes: 55 additions & 20 deletions verifiable-db/src/query/aggregation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ use itertools::Itertools;
use mp2_common::{
poseidon::{empty_poseidon_hash, HashPermutation},
proof::ProofWithVK,
serialization::{deserialize_long_array, serialize_long_array},
serialization::{
deserialize, deserialize_array, deserialize_long_array, serialize, serialize_array,
serialize_long_array,
},
types::{CBuilder, HashOutput},
u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256},
utils::{Fieldable, ToFields, ToTargets},
Expand Down Expand Up @@ -225,13 +228,18 @@ impl NodeInfo {
}
}

#[derive(Clone, Debug)]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub(crate) struct NodeInfoTarget {
/// The hash of the embedded tree at this node. It can be the hash of the row tree if this node is a node in
/// the index tree, or it can be a hash of the cells tree if this node is a node in a rows tree
#[serde(serialize_with = "serialize", deserialize_with = "deserialize")]
pub(crate) embedded_tree_hash: HashOutTarget,
/// Hashes of the children of the current node, first left child and then right child hash. The hash of left/right child
/// is the empty hash (i.e., H("")) if there is no corresponding left/right child for the current node
#[serde(
serialize_with = "serialize_array",
deserialize_with = "deserialize_array"
)]
pub(crate) child_hashes: [HashOutTarget; 2],
/// value stored in the node. It can be a primary index value if the node is a node in the index tree,
/// a secondary index value if the node is a node in a rows tree
Expand Down Expand Up @@ -532,22 +540,29 @@ pub(crate) mod tests {
use crate::query::{
computational_hash_ids::{AggregationOperation, Identifiers},
public_inputs::PublicInputs,
universal_circuit::universal_query_gadget::{CurveOrU256, OutputValues},
};
use alloy::primitives::U256;
use mp2_common::{array::ToField, group_hashing::add_curve_point, utils::ToFields, F};
use itertools::Itertools;
use mp2_common::{
array::ToField,
group_hashing::add_curve_point,
utils::{FromFields, ToFields},
F,
};
use plonky2_ecgfp5::curve::curve::Point;

/// Compute the output values and the overflow number at the specified index by
/// the proofs. It's the test function corresponding to `compute_output_item`.
pub(crate) fn compute_output_item_value<const S: usize>(
/// Aggregate the i-th output values found in `outputs` according to the aggregation operation
/// with identifier `op`. It's the test function corresponding to `OutputValuesTarget::aggregate_outputs`
pub(crate) fn aggregate_output_values<const S: usize>(
i: usize,
proofs: &[&PublicInputs<F, S>],
outputs: &[OutputValues<S>],
op: F,
) -> (Vec<F>, u32)
where
[(); S - 1]:,
{
let proof0 = &proofs[0];
let op = proof0.operation_ids()[i];
let out0 = &outputs[0];

let [op_id, op_min, op_max, op_sum, op_avg] = [
AggregationOperation::IdOp,
Expand All @@ -564,22 +579,17 @@ pub(crate) mod tests {
let is_op_sum = op == op_sum;
let is_op_avg = op == op_avg;

// Check that all the proofs are employing the same aggregation operation.
proofs[1..]
.iter()
.for_each(|p| assert_eq!(p.operation_ids()[i], op));

// Compute the SUM, MIN or MAX value.
let mut sum_overflow = 0;
let mut output = proof0.value_at_index(i);
let mut output = out0.value_at_index(i);
if i == 0 && is_op_id {
// If it's the first proof and the operation is ID,
// the value is a curve point not a Uint256.
output = U256::ZERO;
}
for p in proofs[1..].iter() {
for out in outputs[1..].iter() {
// Get the current proof value.
let mut value = p.value_at_index(i);
let mut value = out.value_at_index(i);
if i == 0 && is_op_id {
// If it's the first proof and the operation is ID,
// the value is a curve point not a Uint256.
Expand All @@ -605,14 +615,14 @@ pub(crate) mod tests {
if i == 0 {
// We always accumulate order-agnostic digest of the proofs for the first item.
output = if is_op_id {
let points: Vec<_> = proofs
let points: Vec<_> = outputs
.iter()
.map(|p| Point::decode(p.first_value_as_curve_point().encode()).unwrap())
.map(|out| Point::decode(out.first_value_as_curve_point().encode()).unwrap())
.collect();
add_curve_point(&points).to_fields()
} else {
// Pad the current output to ``CURVE_TARGET_LEN` for the first item.
PublicInputs::<_, S>::pad_slice_to_curve_len(&output)
CurveOrU256::from_slice(&output).to_vec()
};
}

Expand All @@ -626,4 +636,29 @@ pub(crate) mod tests {

(output, overflow)
}

/// Compute the output values and the overflow number at the specified index by
/// the proofs. It's the test function corresponding to `compute_output_item`.
/// Compute the output values and the overflow number at the specified index by
/// the proofs. It's the test function corresponding to `compute_output_item`.
pub(crate) fn compute_output_item_value<const S: usize>(
    i: usize,
    proofs: &[&PublicInputs<F, S>],
) -> (Vec<F>, u32)
where
    [(); S - 1]:,
{
    // All proofs must employ the same aggregation operation at index `i`:
    // take it from the first proof and check that the remaining ones match.
    let op = proofs[0].operation_ids()[i];
    for p in &proofs[1..] {
        assert_eq!(p.operation_ids()[i], op);
    }

    // Rebuild the output values of each proof from its raw public inputs,
    // then delegate the actual aggregation to `aggregate_output_values`.
    let outputs: Vec<_> = proofs
        .iter()
        .map(|p| OutputValues::from_fields(p.to_values_raw()))
        .collect();

    aggregate_output_values(i, &outputs, op)
}
}
2 changes: 1 addition & 1 deletion verifiable-db/src/query/aggregation/non_existence_inter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use mp2_common::{
},
types::CBuilder,
u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256},
utils::{SelectHashBuilder, ToTargets},
utils::{HashBuilder, ToTargets},
D, F,
};
use plonky2::{
Expand Down
Loading