diff --git a/mp2-common/src/utils.rs b/mp2-common/src/utils.rs index 9529f9968..5a5f7eafc 100644 --- a/mp2-common/src/utils.rs +++ b/mp2-common/src/utils.rs @@ -335,6 +335,56 @@ impl, const D: usize> SelectHashBuilder for Circuit } } +pub trait SelectCurveBuilder { + /// Select `first_curve` or `second_curve` as output depending on the Boolean `cond` + fn select_curve( + &mut self, + cond: BoolTarget, + first_curve: &CurveTarget, + second_curve: &CurveTarget, + ) -> CurveTarget; +} + +impl, const D: usize> SelectCurveBuilder for CircuitBuilder { + fn select_curve( + &mut self, + cond: BoolTarget, + first_curve: &CurveTarget, + second_curve: &CurveTarget, + ) -> CurveTarget { + let CurveTarget((first_ext, first_bool)) = first_curve; + let CurveTarget((second_ext, second_bool)) = second_curve; + + let selected_ext = [ + QuinticExtensionTarget::new( + first_ext[0] + .to_target_array() + .iter() + .zip(second_ext[0].to_target_array().iter()) + .map(|(first, second)| self.select(cond, *first, *second)) + .collect::>() + .try_into() + .unwrap(), + ), + QuinticExtensionTarget::new( + first_ext[1] + .to_target_array() + .iter() + .zip(second_ext[1].to_target_array().iter()) + .map(|(first, second)| self.select(cond, *first, *second)) + .collect::>() + .try_into() + .unwrap(), + ), + ]; + + let selected_bool_target = self.select(cond, first_bool.target, second_bool.target); + let selected_bool = BoolTarget::new_unsafe(selected_bool_target); + + CurveTarget((selected_ext, selected_bool)) + } +} + pub trait ToFields { fn to_fields(&self) -> Vec; } diff --git a/verifiable-db/src/results_tree/extraction/child_included_single_path_node.rs b/verifiable-db/src/results_tree/extraction/child_included_single_path_node.rs new file mode 100644 index 000000000..65a2ac926 --- /dev/null +++ b/verifiable-db/src/results_tree/extraction/child_included_single_path_node.rs @@ -0,0 +1,496 @@ +use crate::results_tree::extraction::PublicInputs; +use anyhow::Result; +use mp2_common::{ + hash::hash_maybe_first, + poseidon::{empty_poseidon_hash, H}, + public_inputs::PublicInputCommon, + serialization::{deserialize, serialize}, + types::CBuilder, + u256::CircuitBuilderU256, + utils::{greater_than, less_than, SelectHashBuilder, ToTargets}, + D, F, +}; +use plonky2::{ + iop::{ + target::{BoolTarget, Target}, + witness::{PartialWitness, WitnessWrite}, + }, + plonk::proof::ProofWithPublicInputsTarget, +}; +use recursion_framework::circuit_builder::CircuitLogicWires; +use serde::{Deserialize, Serialize}; +use std::iter; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ChildIncludedSinglePathNodeWires { + #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] + is_left_child: BoolTarget, + #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] + sibling_exists: BoolTarget, + #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] + is_rows_tree: BoolTarget, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ChildIncludedSinglePathNodeCircuit { + /// Boolean flag specifying whether the included child is the left child or not + pub(crate) is_left_child: bool, + /// Boolean flag specifying whether the included child has a sibling or not + pub(crate) sibling_exists: bool, + /// Boolean flag specifying whether the current node is a node + /// of a rows tree or of the index tree + pub(crate) is_rows_tree: bool, +} + +impl ChildIncludedSinglePathNodeCircuit { + pub fn build( + b: &mut CBuilder, + subtree_proof: &PublicInputs, + included_chid_proof: &PublicInputs, + sibling_proof: &PublicInputs, + ) -> ChildIncludedSinglePathNodeWires { + let empty_hash = b.constant_hash(*empty_poseidon_hash()); + let one = b.one(); + + let [is_left_child, sibling_exists, is_rows_tree] = + [0; 3].map(|_| b.add_virtual_bool_target_safe()); + + let column_id = b.select( + is_rows_tree, + subtree_proof.index_ids_target()[1], + subtree_proof.index_ids_target()[0], + ); + let node_value = b.select_u256( + is_rows_tree, + &subtree_proof.min_value_target(), + &subtree_proof.primary_index_value_target(), + ); + let sibling_min = b.select_u256( + sibling_exists, + &sibling_proof.min_value_target(), + &node_value, + ); + let sibling_max = b.select_u256( + sibling_exists, + &sibling_proof.max_value_target(), + &node_value, + ); + let node_min = b.select_u256( + is_left_child, + &included_chid_proof.min_value_target(), + &sibling_min, + ); + let node_max = b.select_u256( + is_left_child, + &sibling_max, + &included_chid_proof.max_value_target(), + ); + let sibling_hash = b.select_hash( + sibling_exists, + &sibling_proof.tree_hash_target(), + &empty_hash, + ); + + // Compute the node hash: + // H(left_hash||right_hash||node_min||node_max||column_id||node_value||pR.H) + let rest: Vec<_> = node_min + .to_targets() + .into_iter() + .chain(node_max.to_targets()) + .chain(iter::once(column_id)) + .chain(node_value.to_targets()) + .chain(subtree_proof.tree_hash_target().to_targets()) + .collect(); + + let node_hash = hash_maybe_first( + b, + is_left_child, + sibling_hash.elements, + included_chid_proof.tree_hash_target().elements, + &rest, + ); + + // Enforce consistency of counters + let min_minus_one = b.sub(subtree_proof.min_counter_target(), one); + let sibling_max_counter = b.select( + sibling_exists, + sibling_proof.max_counter_target(), + min_minus_one, + ); + let max_left = b.select( + is_left_child, + included_chid_proof.max_counter_target(), + sibling_max_counter, + ); + let max_plus_one = b.add(subtree_proof.max_counter_target(), one); + let sibling_min_counter = b.select( + sibling_exists, + sibling_proof.min_counter_target(), + max_plus_one, + ); + let min_right = b.select( + is_left_child, + sibling_min_counter, + included_chid_proof.min_counter_target(), + ); + // assert max_left + 1 == pR.min_counter + let left_plus_one = b.add(max_left, one); + b.connect(left_plus_one, subtree_proof.min_counter_target()); + // assert pR.max_counter + 1 == min_right + b.connect(max_plus_one, min_right); + + // Ensure that the record/rows tree stored in the current node contains + // only records with counters outside of [query_min; query_max] range + // left == (left AND (pR.min_counter > pI.offset_range_max)) + let is_greater = greater_than( + b, + subtree_proof.min_counter_target(), + included_chid_proof.offset_range_max_target(), + 32, + ); + let is_greater = b.and(is_greater, is_left_child); + b.connect(is_greater.target, is_left_child.target); + // NOT(left) == (NOT(left) AND( pR.max_counter < pI.offset_range_min)) + let is_right_child = b.not(is_left_child); + let is_less = less_than( + b, + subtree_proof.max_counter_target(), + included_chid_proof.offset_range_min_target(), + 32, + ); + let is_less = b.and(is_less, is_right_child); + b.connect(is_less.target, is_right_child.target); + + // Compute `min_counter` and `max_counter` for current node + let sibling_min_counter = b.select( + sibling_exists, + sibling_proof.min_counter_target(), + subtree_proof.min_counter_target(), + ); + let min_counter = b.select( + is_left_child, + included_chid_proof.min_counter_target(), + sibling_min_counter, + ); + let sibling_max_counter = b.select( + sibling_exists, + sibling_proof.max_counter_target(), + subtree_proof.max_counter_target(), + ); + let max_counter = b.select( + is_left_child, + sibling_max_counter, + included_chid_proof.max_counter_target(), + ); + + // Register the public inputs. + PublicInputs::new( + &node_hash.to_targets(), + &node_min.to_targets(), + &node_max.to_targets(), + subtree_proof.to_primary_index_value_raw(), + subtree_proof.to_index_ids_raw(), + &[min_counter], + &[max_counter], + &[*subtree_proof.to_offset_range_min_raw()], + &[*subtree_proof.to_offset_range_max_raw()], + subtree_proof.to_accumulator_raw(), + ) + .register(b); + + ChildIncludedSinglePathNodeWires { + is_left_child, + sibling_exists, + is_rows_tree, + } + } + + fn assign(&self, pw: &mut PartialWitness, wires: &ChildIncludedSinglePathNodeWires) { + pw.set_bool_target(wires.is_left_child, self.is_left_child); + pw.set_bool_target(wires.sibling_exists, self.sibling_exists); + pw.set_bool_target(wires.is_rows_tree, self.is_rows_tree); + } +} + +/// Subtree proof number = 1, child proof number = 2 +pub(crate) const NUM_VERIFIED_PROOFS: usize = 3; + +impl CircuitLogicWires for ChildIncludedSinglePathNodeWires { + type CircuitBuilderParams = (); + type Inputs = ChildIncludedSinglePathNodeCircuit; + + const NUM_PUBLIC_INPUTS: usize = PublicInputs::::total_len(); + + fn circuit_logic( + builder: &mut CBuilder, + verified_proofs: [&ProofWithPublicInputsTarget; NUM_VERIFIED_PROOFS], + _builder_parameters: Self::CircuitBuilderParams, + ) -> Self { + // The first one is the subtree proof, and the remainings are child proofs. + let [subtree_proof, included_child_proof, sibling_proof] = + verified_proofs.map(|p| PublicInputs::from_slice(&p.public_inputs)); + + Self::Inputs::build( + builder, + &subtree_proof, + &included_child_proof, + &sibling_proof, + ) + } + + fn assign_input(&self, inputs: Self::Inputs, pw: &mut PartialWitness) -> Result<()> { + inputs.assign(pw, self); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::results_tree::extraction::{ + tests::{random_results_extraction_public_inputs, unify_child_proof, unify_subtree_proof}, + PI_LEN, + }; + use mp2_common::{utils::ToFields, C}; + use mp2_test::circuit::{run_circuit, UserCircuit}; + use plonky2::plonk::config::Hasher; + use std::array; + + #[derive(Clone, Debug)] + struct TestChildIncludedSinglePathNodeCircuit<'a> { + c: ChildIncludedSinglePathNodeCircuit, + subtree_proof: &'a [F], + included_child_proof: &'a [F], + sibling_proof: &'a [F], + } + + impl<'a> UserCircuit for TestChildIncludedSinglePathNodeCircuit<'a> { + // Circuit wires + subtree proof + included child proof + sibling proof + type Wires = ( + ChildIncludedSinglePathNodeWires, + Vec, + Vec, + Vec, + ); + + fn build(b: &mut CBuilder) -> Self::Wires { + let proofs = array::from_fn(|_| b.add_virtual_target_arr::<{ PI_LEN }>().to_vec()); + + let [subtree_pi, included_child_pi, sibling_pi] = + array::from_fn(|i| PublicInputs::::from_slice(&proofs[i])); + + let wires = ChildIncludedSinglePathNodeCircuit::build( + b, + &subtree_pi, + &included_child_pi, + &sibling_pi, + ); + + let [subtree_proof, included_child_proof, sibling_proof] = proofs; + + (wires, subtree_proof, included_child_proof, sibling_proof) + } + + fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { + self.c.assign(pw, &wires.0); + pw.set_target_arr(&wires.1, self.subtree_proof); + pw.set_target_arr(&wires.2, self.included_child_proof); + pw.set_target_arr(&wires.3, self.sibling_proof); + } + } + + fn test_child_included_single_path_node_circuit( + is_rows_tree: bool, + is_left_child: bool, + sibling_exists: bool, + ) { + let [mut subtree_proof, mut included_child_proof, mut sibling_proof] = + random_results_extraction_public_inputs::<3>(); + unify_subtree_proof(&mut subtree_proof); + let subtree_pi = PublicInputs::from_slice(&subtree_proof); + if sibling_exists { + [ + (&mut included_child_proof, is_left_child), + (&mut sibling_proof, !is_left_child), + ] + .iter_mut() + .for_each(|(p, is_left_child)| { + unify_child_proof(p, is_rows_tree, *is_left_child, &subtree_pi) + }); + } else { + unify_child_proof( + &mut included_child_proof, + is_rows_tree, + is_left_child, + &subtree_pi, + ); + sibling_proof = subtree_proof.clone(); + } + let included_child_pi = PublicInputs::from_slice(&included_child_proof); + let sibling_pi = PublicInputs::from_slice(&sibling_proof); + + let empty_hash = empty_poseidon_hash(); + + // Construct the expected public input values. + let index_ids = subtree_pi.index_ids(); + let primary_index_value = subtree_pi.primary_index_value(); + let node_value = if is_rows_tree { + subtree_pi.min_value() + } else { + primary_index_value + }; + let sibling_min = if sibling_exists { + sibling_pi.min_value() + } else { + node_value + }; + let sibling_max = if sibling_exists { + sibling_pi.max_value() + } else { + node_value + }; + let node_min = if is_left_child { + included_child_pi.min_value() + } else { + sibling_min + }; + let node_max = if is_left_child { + sibling_max + } else { + included_child_pi.max_value() + }; + let sibling_min_counter = if sibling_exists { + sibling_pi.min_counter() + } else { + subtree_pi.min_counter() + }; + let sibling_max_counter = if sibling_exists { + sibling_pi.max_counter() + } else { + subtree_pi.max_counter() + }; + let min_counter = if is_left_child { + included_child_pi.min_counter() + } else { + sibling_min_counter + }; + let max_counter = if is_left_child { + sibling_max_counter + } else { + included_child_pi.max_counter() + }; + + // Construct the test circuit. + let test_circuit = TestChildIncludedSinglePathNodeCircuit { + c: ChildIncludedSinglePathNodeCircuit { + is_left_child, + sibling_exists, + is_rows_tree, + }, + subtree_proof: &subtree_proof, + included_child_proof: &included_child_proof, + sibling_proof: &sibling_proof, + }; + + // Prove for the test circuit. + let proof = run_circuit::(test_circuit); + let pi = PublicInputs::from_slice(&proof.public_inputs); + + // Check the public inputs. + // Tree hash + { + let column_id = if is_rows_tree { + index_ids[1] + } else { + index_ids[0] + }; + let sibling_hash = if sibling_exists { + sibling_pi.tree_hash() + } else { + *empty_hash + }; + let left_hash = if is_left_child { + included_child_pi.tree_hash() + } else { + sibling_hash + }; + let right_hash = if is_left_child { + sibling_hash + } else { + included_child_pi.tree_hash() + }; + let hash_inputs: Vec<_> = left_hash + .to_fields() + .into_iter() + .chain(right_hash.to_fields()) + .chain(node_min.to_fields()) + .chain(node_max.to_fields()) + .chain(iter::once(column_id)) + .chain(node_value.to_fields()) + .chain(subtree_pi.tree_hash().to_fields()) + .collect(); + let exp_hash = H::hash_no_pad(&hash_inputs); + assert_eq!(pi.tree_hash(), exp_hash) + } + + // Minimum value + assert_eq!(pi.min_value(), node_min); + + // Maximum value + assert_eq!(pi.max_value(), node_max); + + // Primary index value + assert_eq!(pi.primary_index_value(), subtree_pi.primary_index_value()); + + // Index IDs + assert_eq!(pi.index_ids(), index_ids); + + // Minimum counter + assert_eq!(pi.min_counter(), min_counter); + + // Maximum counter + assert_eq!(pi.max_counter(), max_counter); + + // Offset range min + assert_eq!(pi.offset_range_min(), subtree_pi.offset_range_min()); + + // Offset range max + assert_eq!(pi.offset_range_max(), subtree_pi.offset_range_max()); + + // Accumulator + assert_eq!(pi.accumulator(), subtree_pi.accumulator()); + } + + #[test] + fn test_child_included_single_path_for_row_node_with_only_right_child() { + test_child_included_single_path_node_circuit(true, false, false); + } + #[test] + fn test_child_included_single_path_for_row_node_with_only_left_child() { + test_child_included_single_path_node_circuit(true, true, false); + } + #[test] + fn test_child_included_single_path_for_row_node_with_right_child_included() { + test_child_included_single_path_node_circuit(true, false, true); + } + #[test] + fn test_child_included_single_path_for_row_node_with_left_child_included() { + test_child_included_single_path_node_circuit(true, true, true); + } + #[test] + fn test_child_included_single_path_for_index_node_with_only_right_child() { + test_child_included_single_path_node_circuit(false, false, false); + } + #[test] + fn test_child_included_single_path_for_index_node_with_only_left_child() { + test_child_included_single_path_node_circuit(false, true, false); + } + #[test] + fn test_child_included_single_path_for_index_node_with_right_child_included() { + test_child_included_single_path_node_circuit(false, false, true); + } + #[test] + fn test_child_included_single_path_for_index_node_with_left_child_included() { + test_child_included_single_path_node_circuit(false, true, true); + } +} diff --git a/verifiable-db/src/results_tree/extraction/full_node.rs b/verifiable-db/src/results_tree/extraction/full_node.rs new file mode 100644 index 000000000..89a2cbc21 --- /dev/null +++ b/verifiable-db/src/results_tree/extraction/full_node.rs @@ -0,0 +1,433 @@ +use crate::results_tree::extraction::PublicInputs; +use anyhow::Result; +use mp2_common::{ + group_hashing::CircuitBuilderGroupHashing, + poseidon::{empty_poseidon_hash, H}, + public_inputs::PublicInputCommon, + serialization::{deserialize, serialize}, + types::CBuilder, + u256::CircuitBuilderU256, + utils::{SelectCurveBuilder, SelectHashBuilder, ToTargets}, + D, F, +}; +use plonky2::{ + iop::{ + target::{BoolTarget, Target}, + witness::{PartialWitness, WitnessWrite}, + }, + plonk::proof::ProofWithPublicInputsTarget, +}; +use plonky2_ecgfp5::gadgets::curve::CircuitBuilderEcGFp5; +use recursion_framework::circuit_builder::CircuitLogicWires; +use serde::{Deserialize, Serialize}; +use std::iter; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct FullNodeWires { + #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] + left_child_exists: BoolTarget, + #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] + right_child_exists: BoolTarget, + #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] + is_rows_tree: BoolTarget, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct FullNodeCircuit { + /// Boolean flag specifying whether the node has a left child + pub(crate) left_child_exists: bool, + /// Boolean flag specifying whether the node has a right child + pub(crate) right_child_exists: bool, + /// Boolean flag specifying whether this node is a node of rows tree or of the index tree + pub(crate) is_rows_tree: bool, +} + +impl FullNodeCircuit { + pub fn build( + b: &mut CBuilder, + subtree_proof: &PublicInputs, + child_proofs: &[PublicInputs; 2], + ) -> FullNodeWires { + let empty_hash = b.constant_hash(*empty_poseidon_hash()); + let curve_zero = b.curve_zero(); + let one = b.one(); + + let [child_proof1, child_proof2] = child_proofs; + let [left_child_exists, right_child_exists, is_rows_tree] = + [0; 3].map(|_| b.add_virtual_bool_target_safe()); + let index_value = subtree_proof.primary_index_value_target(); + + let left_hash = b.select_hash( + left_child_exists, + &child_proof1.tree_hash_target(), + &empty_hash, + ); + let right_hash = b.select_hash( + right_child_exists, + &child_proof2.tree_hash_target(), + &empty_hash, + ); + let column_id = b.select( + is_rows_tree, + subtree_proof.index_ids_target()[1], + subtree_proof.index_ids_target()[0], + ); + let node_value = b.select_u256( + is_rows_tree, + &subtree_proof.min_value_target(), + &index_value, + ); + let node_min = b.select_u256( + left_child_exists, + &child_proof1.min_value_target(), + &node_value, + ); + let node_max = b.select_u256( + right_child_exists, + &child_proof2.max_value_target(), + &node_value, + ); + + // H(left_hash || right_hash || node_min || node_max || column_id || node_value || p.H) + let hash_inputs = left_hash + .to_targets() + .into_iter() + .chain(right_hash.to_targets()) + .chain(node_min.to_targets()) + .chain(node_max.to_targets()) + .chain(iter::once(column_id)) + .chain(node_value.to_targets()) + .chain(subtree_proof.tree_hash_target().to_targets()) + .collect(); + let node_hash = b.hash_n_to_hash_no_pad::(hash_inputs); + + // Ensure the proofs in the same rows tree are employing the same value + // of the primary indexed column: + // is_rows_tree == (is_rows_tree AND (p.I == p1.I AND p.I == p2.I)) + let [is_equal1, is_equal2] = [child_proof1, child_proof2] + .map(|p| b.is_equal_u256(&index_value, &p.primary_index_value_target())); + let is_equal = b.and(is_equal1, is_equal2); + let is_equal = b.and(is_equal, is_rows_tree); + b.connect(is_equal.target, is_rows_tree.target); + + // Enforce consistency of counters + let min_minus_one = b.sub(subtree_proof.min_counter_target(), one); + let max_plus_one = b.add(subtree_proof.max_counter_target(), one); + let max_left = b.select( + left_child_exists, + child_proof1.max_counter_target(), + min_minus_one, + ); + let min_right = b.select( + right_child_exists, + child_proof2.min_counter_target(), + max_plus_one, + ); + // assert max_left + 1 == p.min_counter + let left_plus_one = b.add(max_left, one); + b.connect(left_plus_one, subtree_proof.min_counter_target()); + // assert p.max_counter + 1 == min_right + b.connect(max_plus_one, min_right); + + // aggregate accumulators + let left_acc = b.select_curve( + left_child_exists, + &child_proof1.accumulator_target(), + &curve_zero, + ); + let right_acc = b.select_curve( + right_child_exists, + &child_proof2.accumulator_target(), + &curve_zero, + ); + let accumulator = + b.add_curve_point(&[left_acc, right_acc, subtree_proof.accumulator_target()]); + + let min_counter = b.select( + left_child_exists, + child_proof1.min_counter_target(), + subtree_proof.min_counter_target(), + ); + let max_counter = b.select( + right_child_exists, + child_proof2.max_counter_target(), + subtree_proof.max_counter_target(), + ); + + // Register the public inputs. + PublicInputs::<_>::new( + &node_hash.to_targets(), + &node_min.to_targets(), + &node_max.to_targets(), + subtree_proof.to_primary_index_value_raw(), + subtree_proof.to_index_ids_raw(), + &[min_counter], + &[max_counter], + &[*subtree_proof.to_offset_range_min_raw()], + &[*subtree_proof.to_offset_range_max_raw()], + &accumulator.to_targets(), + ) + .register(b); + + FullNodeWires { + left_child_exists, + right_child_exists, + is_rows_tree, + } + } + + fn assign(&self, pw: &mut PartialWitness, wires: &FullNodeWires) { + pw.set_bool_target(wires.left_child_exists, self.left_child_exists); + pw.set_bool_target(wires.right_child_exists, self.right_child_exists); + pw.set_bool_target(wires.is_rows_tree, self.is_rows_tree); + } +} + +/// Subtree proof number = 1, child proof number = 2 +pub(crate) const NUM_VERIFIED_PROOFS: usize = 3; + +impl CircuitLogicWires for FullNodeWires { + type CircuitBuilderParams = (); + type Inputs = FullNodeCircuit; + + const NUM_PUBLIC_INPUTS: usize = PublicInputs::::total_len(); + + fn circuit_logic( + builder: &mut CBuilder, + verified_proofs: [&ProofWithPublicInputsTarget; NUM_VERIFIED_PROOFS], + _builder_parameters: Self::CircuitBuilderParams, + ) -> Self { + // The first one is the subtree proof, and the remainings are child proofs. + let [subtree_proof, child_proof1, child_proof2] = + verified_proofs.map(|p| PublicInputs::from_slice(&p.public_inputs)); + + Self::Inputs::build(builder, &subtree_proof, &[child_proof1, child_proof2]) + } + + fn assign_input(&self, inputs: Self::Inputs, pw: &mut PartialWitness) -> Result<()> { + inputs.assign(pw, self); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::results_tree::extraction::{ + tests::{random_results_extraction_public_inputs, unify_child_proof, unify_subtree_proof}, + PI_LEN, + }; + use mp2_common::{group_hashing::add_weierstrass_point, utils::ToFields, C}; + use mp2_test::circuit::{run_circuit, UserCircuit}; + use plonky2::plonk::config::Hasher; + use plonky2_ecgfp5::curve::curve::WeierstrassPoint; + use std::array; + + #[derive(Clone, Debug)] + struct TestFullNodeCircuit<'a> { + c: FullNodeCircuit, + subtree_proof: &'a [F], + left_child_proof: &'a [F], + right_child_proof: &'a [F], + } + + impl<'a> UserCircuit for TestFullNodeCircuit<'a> { + // Circuit wires + subtree proof + left child proof + right child proof + type Wires = (FullNodeWires, Vec, Vec, Vec); + + fn build(b: &mut CBuilder) -> Self::Wires { + let proofs = array::from_fn(|_| b.add_virtual_target_arr::<{ PI_LEN }>().to_vec()); + + let [subtree_pi, left_child_pi, right_child_pi] = + array::from_fn(|i| PublicInputs::::from_slice(&proofs[i])); + + let wires = FullNodeCircuit::build(b, &subtree_pi, &[left_child_pi, right_child_pi]); + + let [subtree_proof, left_child_proof, right_child_proof] = proofs; + + (wires, subtree_proof, left_child_proof, right_child_proof) + } + + fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { + self.c.assign(pw, &wires.0); + pw.set_target_arr(&wires.1, self.subtree_proof); + pw.set_target_arr(&wires.2, self.left_child_proof); + pw.set_target_arr(&wires.3, self.right_child_proof); + } + } + + fn test_full_node_circuit( + is_rows_tree: bool, + left_child_exists: bool, + right_child_exists: bool, + ) { + // Generate the input proofs. + let [mut subtree_proof, mut left_child_proof, mut right_child_proof] = + random_results_extraction_public_inputs::<3>(); + unify_subtree_proof(&mut subtree_proof); + let subtree_pi = PublicInputs::from_slice(&subtree_proof); + [ + (&mut left_child_proof, true), + (&mut right_child_proof, false), + ] + .iter_mut() + .for_each(|(p, is_left_child)| { + unify_child_proof(p, is_rows_tree, *is_left_child, &subtree_pi) + }); + let left_child_pi = PublicInputs::from_slice(&left_child_proof); + let right_child_pi = PublicInputs::from_slice(&right_child_proof); + + // Construct the expected public input values. + let index_ids = subtree_pi.index_ids(); + let primary_index_value = subtree_pi.primary_index_value(); + let node_value = if is_rows_tree { + subtree_pi.min_value() + } else { + primary_index_value + }; + let node_min = if left_child_exists { + left_child_pi.min_value() + } else { + node_value + }; + let node_max = if right_child_exists { + right_child_pi.max_value() + } else { + node_value + }; + let min_counter = if left_child_exists { + left_child_pi.min_counter() + } else { + subtree_pi.min_counter() + }; + let max_counter = if right_child_exists { + right_child_pi.max_counter() + } else { + subtree_pi.max_counter() + }; + + // Construct the test circuit. + let test_circuit = TestFullNodeCircuit { + c: FullNodeCircuit { + left_child_exists, + right_child_exists, + is_rows_tree, + }, + subtree_proof: &subtree_proof, + left_child_proof: &left_child_proof, + right_child_proof: &right_child_proof, + }; + + // Prove for the test circuit. + let proof = run_circuit::(test_circuit); + let pi = PublicInputs::<_>::from_slice(&proof.public_inputs); + + // Check the public inputs. + // Tree hash + { + let column_id = if is_rows_tree { + index_ids[1] + } else { + index_ids[0] + }; + let empty_hash = empty_poseidon_hash(); + let left_hash = if left_child_exists { + left_child_pi.tree_hash() + } else { + *empty_hash + }; + let right_hash = if right_child_exists { + right_child_pi.tree_hash() + } else { + *empty_hash + }; + let hash_inputs: Vec<_> = left_hash + .to_fields() + .into_iter() + .chain(right_hash.to_fields()) + .chain(node_min.to_fields()) + .chain(node_max.to_fields()) + .chain(iter::once(column_id)) + .chain(node_value.to_fields()) + .chain(subtree_pi.tree_hash().to_fields()) + .collect(); + let exp_hash = H::hash_no_pad(&hash_inputs); + assert_eq!(pi.tree_hash(), exp_hash); + } + + // Minimum value + assert_eq!(pi.min_value(), node_min); + + // Maximum value + assert_eq!(pi.max_value(), node_max); + + // Primary index value + assert_eq!(pi.primary_index_value(), subtree_pi.primary_index_value()); + + // Index IDs + assert_eq!(pi.index_ids(), index_ids); + + // Minimum counter + assert_eq!(pi.min_counter(), min_counter); + + // Maximum counter + assert_eq!(pi.max_counter(), max_counter); + + // Offset range min + assert_eq!(pi.offset_range_min(), subtree_pi.offset_range_min()); + + // Offset range max + assert_eq!(pi.offset_range_max(), subtree_pi.offset_range_max()); + + // Accumulator + { + let left_acc = if left_child_exists { + left_child_pi.accumulator() + } else { + WeierstrassPoint::NEUTRAL + }; + let right_acc = if right_child_exists { + right_child_pi.accumulator() + } else { + WeierstrassPoint::NEUTRAL + }; + let exp_accumulator = + add_weierstrass_point(&[left_acc, right_acc, subtree_pi.accumulator()]); + + assert_eq!(pi.accumulator(), exp_accumulator); + } + } + + #[test] + fn test_full_node_circuit_for_row_node_with_no_child() { + test_full_node_circuit(true, false, false); + } + #[test] + fn test_full_node_circuit_for_row_node_with_left_child() { + test_full_node_circuit(true, true, false); + } + #[test] + fn test_full_node_circuit_for_row_node_with_right_child() { + test_full_node_circuit(true, false, true); + } + #[test] + fn test_full_node_circuit_for_row_node_with_both_children() { + test_full_node_circuit(true, true, true); + } + #[test] + fn test_full_node_circuit_for_index_node_with_no_child() { + test_full_node_circuit(false, false, false); + } + #[test] + fn test_full_node_circuit_for_index_node_with_left_child() { + test_full_node_circuit(false, true, false); + } + #[test] + fn test_full_node_circuit_for_index_node_with_right_child() { + test_full_node_circuit(false, false, true); + } + #[test] + fn test_full_node_circuit_for_index_node_with_both_children() { + test_full_node_circuit(false, true, true); + } +} diff --git a/verifiable-db/src/results_tree/extraction/mod.rs b/verifiable-db/src/results_tree/extraction/mod.rs index 0ea2eb331..d4e5e8e6c 100644 --- a/verifiable-db/src/results_tree/extraction/mod.rs +++ b/verifiable-db/src/results_tree/extraction/mod.rs @@ -1,4 +1,114 @@ +use mp2_common::F; +use public_inputs::PublicInputs; + +pub(crate) mod child_included_single_path_node; +pub(crate) mod full_node; +pub(crate) mod no_child_included_single_path_node; +pub(crate) mod no_results_in_chunk; +pub(crate) mod partial_node; pub(crate) mod public_inputs; pub(crate) mod record; -use public_inputs::PublicInputs; +// Without this skipping config, the generic parameter was deleted when `cargo fmt`. +#[rustfmt::skip] +pub(crate) const PI_LEN: usize = PublicInputs::::total_len(); + +#[cfg(test)] +pub(crate) mod tests { + use super::*; + use mp2_common::utils::ToFields; + use mp2_test::utils::random_vector; + use plonky2::field::types::{Field, Sample}; + use plonky2_ecgfp5::curve::curve::Point; + use public_inputs::{PublicInputs, ResultsExtractionPublicInputs}; + use rand::{thread_rng, Rng}; + use std::array; + + /// Generate N number of proof public input slices. The each returned proof public inputs + /// could be constructed by `PublicInputs::from_slice` function. + pub(crate) fn random_results_extraction_public_inputs() -> [Vec; N] { + let mut rng = thread_rng(); + + let index_ids: [F; 2] = F::rand_array(); + + let [idx_ids_range, acc_range] = [ + ResultsExtractionPublicInputs::IndexIds, + ResultsExtractionPublicInputs::Accumulator, + ] + .map(PublicInputs::::to_range); + + array::from_fn(|_| { + let mut pi = random_vector::(PI_LEN).to_fields(); + + // Set the Index IDs. + pi[idx_ids_range.clone()].copy_from_slice(&index_ids); + + // Set a random point to Accumulator. + let acc = Point::sample(&mut rng).to_weierstrass().to_fields(); + pi[acc_range.clone()].copy_from_slice(&acc); + + pi + }) + } + + /// Assign the subtree proof to make consistent. + pub(crate) fn unify_subtree_proof(proof: &mut [F]) { + // offset_range_min <= min_counter <= max_counter <= offset_range_max + let mut rng = thread_rng(); + let min_counter = F::from_canonical_u32(rng.gen()); + let max_counter = min_counter + F::from_canonical_u32(100); + let [min_cnt_range, max_cnt_range, offset_rng_min_range, offset_rng_max_range] = [ + ResultsExtractionPublicInputs::MinCounter, + ResultsExtractionPublicInputs::MaxCounter, + ResultsExtractionPublicInputs::OffsetRangeMin, + ResultsExtractionPublicInputs::OffsetRangeMax, + ] + .map(PublicInputs::::to_range); + + // Set the Min/Max counters. + proof[min_cnt_range].copy_from_slice(&[min_counter]); + proof[max_cnt_range].copy_from_slice(&[max_counter]); + proof[offset_rng_min_range].copy_from_slice(&[min_counter]); + proof[offset_rng_max_range].copy_from_slice(&[max_counter]); + } + + /// Assign the child proof to make consistent. + pub(crate) fn unify_child_proof( + proof: &mut [F], + is_rows_tree: bool, + is_left_child: bool, + subtree_pi: &PublicInputs, + ) { + let [pri_idx_val_range, min_cnt_range, max_cnt_range, offset_rng_min_range, offset_rng_max_range] = + [ + ResultsExtractionPublicInputs::PrimaryIndexValue, + ResultsExtractionPublicInputs::MinCounter, + ResultsExtractionPublicInputs::MaxCounter, + ResultsExtractionPublicInputs::OffsetRangeMin, + ResultsExtractionPublicInputs::OffsetRangeMax, + ] + .map(PublicInputs::::to_range); + + if is_rows_tree { + // pC.I == pR.I + proof[pri_idx_val_range].copy_from_slice(subtree_pi.to_primary_index_value_raw()); + } + + if is_left_child { + let left_min_counter = subtree_pi.min_counter() - F::ONE; + + // pC.max_counter = pR.min_counter - 1 + proof[max_cnt_range].copy_from_slice(&[left_min_counter]); + + // pC.offset_range_max < pR.min_counter + proof[offset_rng_max_range].copy_from_slice(&[left_min_counter]); + } else { + let right_max_counter = subtree_pi.max_counter() + F::ONE; + // pC.min_counter = pR.max_counter + 1 + proof[min_cnt_range].copy_from_slice(&[right_max_counter]); + + // pC.offset_range_min > pR.max_counter + proof[offset_rng_min_range].copy_from_slice(&[right_max_counter]); + } + } +} diff --git a/verifiable-db/src/results_tree/extraction/no_child_included_single_path_node.rs b/verifiable-db/src/results_tree/extraction/no_child_included_single_path_node.rs new file mode 100644 index 000000000..d89352616 --- /dev/null +++ b/verifiable-db/src/results_tree/extraction/no_child_included_single_path_node.rs @@ -0,0 +1,416 @@ +use crate::results_tree::extraction::PublicInputs; +use anyhow::Result; +use mp2_common::{ + poseidon::{empty_poseidon_hash, H}, + public_inputs::PublicInputCommon, + serialization::{deserialize, serialize}, + types::CBuilder, + u256::CircuitBuilderU256, + utils::{greater_than, less_than, SelectHashBuilder, ToTargets}, + D, F, +}; +use plonky2::{ + iop::{ + target::{BoolTarget, Target}, + witness::{PartialWitness, WitnessWrite}, + }, + plonk::proof::ProofWithPublicInputsTarget, +}; +use recursion_framework::circuit_builder::CircuitLogicWires; +use serde::{Deserialize, Serialize}; +use std::iter; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct NoChildIncludedSinglePathNodeWires { + #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] + left_child_exists: BoolTarget, + #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] + right_child_exists: BoolTarget, + #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] + is_rows_tree: BoolTarget, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct NoChildIncludedSinglePathNodeCircuit { + /// Boolean flag specifying whether the node has a left child or not + pub(crate) left_child_exists: bool, + /// Boolean flag specifying whether the node has a right child or not + pub(crate) right_child_exists: bool, + /// Boolean flag specifying whether the current node is a node of + /// a rows tree or of the index tree + pub(crate) is_rows_tree: bool, +} + +impl NoChildIncludedSinglePathNodeCircuit { + pub fn build( + b: &mut CBuilder, + subtree_proof: &PublicInputs, + child_proofs: &[PublicInputs; 2], + ) -> NoChildIncludedSinglePathNodeWires { + let empty_hash = b.constant_hash(*empty_poseidon_hash()); + let one = b.one(); + let ttrue = b._true(); + + let [child_proof1, child_proof2] = child_proofs; + let [left_child_exists, right_child_exists, is_rows_tree] = + [0; 3].map(|_| b.add_virtual_bool_target_safe()); + let index_value = subtree_proof.primary_index_value_target(); + + let left_hash = b.select_hash( + left_child_exists, + &child_proof1.tree_hash_target(), + &empty_hash, + ); + let right_hash = b.select_hash( + right_child_exists, + &child_proof2.tree_hash_target(), + &empty_hash, + ); + let column_id = b.select( + is_rows_tree, + subtree_proof.index_ids_target()[1], + subtree_proof.index_ids_target()[0], + ); + let node_value = b.select_u256( + is_rows_tree, + &subtree_proof.min_value_target(), + &index_value, + ); + let node_min = b.select_u256( + left_child_exists, + &child_proof1.min_value_target(), + &node_value, + ); + let node_max = b.select_u256( + right_child_exists, + &child_proof2.max_value_target(), + &node_value, + ); + + // H(left_hash || right_hash || node_min || node_max || column_id || node_value || p.H) + let hash_inputs = left_hash + .to_targets() + .into_iter() + .chain(right_hash.to_targets()) + .chain(node_min.to_targets()) + .chain(node_max.to_targets()) + .chain(iter::once(column_id)) + .chain(node_value.to_targets()) + .chain(subtree_proof.tree_hash_target().to_targets()) + .collect(); + let node_hash = b.hash_n_to_hash_no_pad::(hash_inputs); + + // Enforce consistency of counters + let min_minus_one = b.sub(subtree_proof.min_counter_target(), one); + let max_plus_one = b.add(subtree_proof.max_counter_target(), one); + let max_left = b.select( + left_child_exists, + child_proof1.max_counter_target(), + min_minus_one, + ); + let min_right = b.select( + right_child_exists, + child_proof2.min_counter_target(), + max_plus_one, + ); + // assert max_left + 1 == p.min_counter + let left_plus_one = b.add(max_left, one); + b.connect(left_plus_one, subtree_proof.min_counter_target()); + // assert p.max_counter + 1 == min_right + b.connect(max_plus_one, min_right); + + // Ensure that all the records in the subtree rooted in the left child, + // if there is a left child, are associated to counters outside of the + // range specified by the query + // max_left < p.offset_range_max + let is_less = less_than(b, max_left, subtree_proof.offset_range_min_target(), 32); + b.connect(is_less.target, ttrue.target); + + // Enforce that all the records in the subtree rooted in the right child, + // if there is a right child, are associated to counters outside of the + // range specified by the query + // min_right > p.offset_range_min + let is_greater = greater_than(b, min_right, subtree_proof.offset_range_max_target(), 32); + b.connect(is_greater.target, ttrue.target); + + let min_counter = b.select( + left_child_exists, + child_proof1.min_counter_target(), + subtree_proof.min_counter_target(), + ); + let max_counter = b.select( + right_child_exists, + child_proof2.max_counter_target(), + subtree_proof.max_counter_target(), + ); + + // Register the public inputs. + PublicInputs::new( + &node_hash.to_targets(), + &node_min.to_targets(), + &node_max.to_targets(), + subtree_proof.to_primary_index_value_raw(), + subtree_proof.to_index_ids_raw(), + &[min_counter], + &[max_counter], + &[*subtree_proof.to_offset_range_min_raw()], + &[*subtree_proof.to_offset_range_max_raw()], + subtree_proof.to_accumulator_raw(), + ) + .register(b); + + NoChildIncludedSinglePathNodeWires { + left_child_exists, + right_child_exists, + is_rows_tree, + } + } + + fn assign(&self, pw: &mut PartialWitness, wires: &NoChildIncludedSinglePathNodeWires) { + pw.set_bool_target(wires.left_child_exists, self.left_child_exists); + pw.set_bool_target(wires.right_child_exists, self.right_child_exists); + pw.set_bool_target(wires.is_rows_tree, self.is_rows_tree); + } +} + +/// Subtree proof number = 1, child proof number = 2 +pub(crate) const NUM_VERIFIED_PROOFS: usize = 3; + +impl CircuitLogicWires for NoChildIncludedSinglePathNodeWires { + type CircuitBuilderParams = (); + type Inputs = NoChildIncludedSinglePathNodeCircuit; + + const NUM_PUBLIC_INPUTS: usize = PublicInputs::::total_len(); + + fn circuit_logic( + builder: &mut CBuilder, + verified_proofs: [&ProofWithPublicInputsTarget; NUM_VERIFIED_PROOFS], + _builder_parameters: Self::CircuitBuilderParams, + ) -> Self { + // The first one is the subtree proof, and the remainings are child proofs. + let [subtree_proof, child_proof1, child_proof2] = + verified_proofs.map(|p| PublicInputs::from_slice(&p.public_inputs)); + + Self::Inputs::build(builder, &subtree_proof, &[child_proof1, child_proof2]) + } + + fn assign_input(&self, inputs: Self::Inputs, pw: &mut PartialWitness) -> Result<()> { + inputs.assign(pw, self); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::results_tree::extraction::{ + tests::{random_results_extraction_public_inputs, unify_child_proof, unify_subtree_proof}, + PI_LEN, + }; + use mp2_common::{utils::ToFields, C}; + use mp2_test::circuit::{run_circuit, UserCircuit}; + use plonky2::plonk::config::Hasher; + use std::array; + + #[derive(Clone, Debug)] + struct TestNoChildIncludedSinglePathNodeCircuit<'a> { + c: NoChildIncludedSinglePathNodeCircuit, + subtree_proof: &'a [F], + left_child_proof: &'a [F], + right_child_proof: &'a [F], + } + + impl<'a> UserCircuit for TestNoChildIncludedSinglePathNodeCircuit<'a> { + // Circuit wires + subtree proof + left child proof + right child proof + type Wires = ( + NoChildIncludedSinglePathNodeWires, + Vec, + Vec, + Vec, + ); + + fn build(b: &mut CBuilder) -> Self::Wires { + let proofs = array::from_fn(|_| b.add_virtual_target_arr::<{ PI_LEN }>().to_vec()); + + let [subtree_pi, left_child_pi, right_child_pi] = + array::from_fn(|i| PublicInputs::::from_slice(&proofs[i])); + + let wires = NoChildIncludedSinglePathNodeCircuit::build( + b, + &subtree_pi, + &[left_child_pi, right_child_pi], + ); + + let [subtree_proof, left_child_proof, right_child_proof] = proofs; + + (wires, subtree_proof, left_child_proof, right_child_proof) + } + + fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { + self.c.assign(pw, &wires.0); + pw.set_target_arr(&wires.1, self.subtree_proof); + pw.set_target_arr(&wires.2, self.left_child_proof); + pw.set_target_arr(&wires.3, self.right_child_proof); + } + } + + fn test_no_child_included_single_path_node_circuit( + is_rows_tree: bool, + left_child_exists: bool, + right_child_exists: bool, + ) { + // Generate the input proofs. + let [mut subtree_proof, mut left_child_proof, mut right_child_proof] = + random_results_extraction_public_inputs::<3>(); + unify_subtree_proof(&mut subtree_proof); + let subtree_pi = PublicInputs::from_slice(&subtree_proof); + [ + (&mut left_child_proof, true), + (&mut right_child_proof, false), + ] + .iter_mut() + .for_each(|(p, is_left_child)| { + unify_child_proof(p, is_rows_tree, *is_left_child, &subtree_pi) + }); + let left_child_pi = PublicInputs::from_slice(&left_child_proof); + let right_child_pi = PublicInputs::from_slice(&right_child_proof); + + // Construct the expected public input values. + let index_ids = subtree_pi.index_ids(); + let primary_index_value = subtree_pi.primary_index_value(); + let node_value = if is_rows_tree { + subtree_pi.min_value() + } else { + primary_index_value + }; + let node_min = if left_child_exists { + left_child_pi.min_value() + } else { + node_value + }; + let node_max = if right_child_exists { + right_child_pi.max_value() + } else { + node_value + }; + let min_counter = if left_child_exists { + left_child_pi.min_counter() + } else { + subtree_pi.min_counter() + }; + let max_counter = if right_child_exists { + right_child_pi.max_counter() + } else { + subtree_pi.max_counter() + }; + + // Construct the test circuit. + let test_circuit = TestNoChildIncludedSinglePathNodeCircuit { + c: NoChildIncludedSinglePathNodeCircuit { + left_child_exists, + right_child_exists, + is_rows_tree, + }, + subtree_proof: &subtree_proof, + left_child_proof: &left_child_proof, + right_child_proof: &right_child_proof, + }; + + // Prove for the test circuit. + let proof = run_circuit::(test_circuit); + let pi = PublicInputs::<_>::from_slice(&proof.public_inputs); + + // Check the public inputs. + // Tree hash + { + let column_id = if is_rows_tree { + index_ids[1] + } else { + index_ids[0] + }; + let empty_hash = empty_poseidon_hash(); + let left_hash = if left_child_exists { + left_child_pi.tree_hash() + } else { + *empty_hash + }; + let right_hash = if right_child_exists { + right_child_pi.tree_hash() + } else { + *empty_hash + }; + let hash_inputs: Vec<_> = left_hash + .to_fields() + .into_iter() + .chain(right_hash.to_fields()) + .chain(node_min.to_fields()) + .chain(node_max.to_fields()) + .chain(iter::once(column_id)) + .chain(node_value.to_fields()) + .chain(subtree_pi.tree_hash().to_fields()) + .collect(); + let exp_hash = H::hash_no_pad(&hash_inputs); + assert_eq!(pi.tree_hash(), exp_hash); + } + + // Minimum value + assert_eq!(pi.min_value(), node_min); + + // Maximum value + assert_eq!(pi.max_value(), node_max); + + // Primary index value + assert_eq!(pi.primary_index_value(), subtree_pi.primary_index_value()); + + // Index IDs + assert_eq!(pi.index_ids(), index_ids); + + // Minimum counter + assert_eq!(pi.min_counter(), min_counter); + + // Maximum counter + assert_eq!(pi.max_counter(), max_counter); + + // Offset range min + assert_eq!(pi.offset_range_min(), subtree_pi.offset_range_min()); + + // Offset range max + assert_eq!(pi.offset_range_max(), subtree_pi.offset_range_max()); + + // Accumulator + assert_eq!(pi.accumulator(), subtree_pi.accumulator()); + } + + #[test] + fn test_no_child_included_for_row_node_with_no_child() { + test_no_child_included_single_path_node_circuit(true, false, false); + } + #[test] + fn test_no_child_included_for_row_node_with_left_child() { + test_no_child_included_single_path_node_circuit(true, true, false); + } + #[test] + fn test_no_child_included_for_row_node_with_right_child() { + test_no_child_included_single_path_node_circuit(true, false, true); + } + #[test] + fn test_no_child_included_for_row_node_with_both_children() { + test_no_child_included_single_path_node_circuit(true, true, true); + } + #[test] + fn test_no_child_included_for_index_node_with_no_child() { + test_no_child_included_single_path_node_circuit(false, false, false); + } + #[test] + fn test_no_child_included_for_index_node_with_left_child() { + test_no_child_included_single_path_node_circuit(false, true, false); + } + #[test] + fn test_no_child_included_for_index_node_with_right_child() { + test_no_child_included_single_path_node_circuit(false, false, true); + } + #[test] + fn test_no_child_included_for_index_node_with_both_children() { + test_no_child_included_single_path_node_circuit(false, true, true); + } +} diff --git a/verifiable-db/src/results_tree/extraction/no_results_in_chunk.rs b/verifiable-db/src/results_tree/extraction/no_results_in_chunk.rs new file mode 100644 index 000000000..027e822bf --- /dev/null +++ b/verifiable-db/src/results_tree/extraction/no_results_in_chunk.rs @@ -0,0 +1,191 @@ +use crate::results_tree::extraction::PublicInputs; +use alloy::primitives::U256; +use mp2_common::{ + public_inputs::PublicInputCommon, + serialization::{deserialize, serialize}, + types::CBuilder, + u256::CircuitBuilderU256, + utils::{greater_than, ToTargets}, + D, F, +}; +use plonky2::{ + hash::hash_types::{HashOut, HashOutTarget}, + iop::{ + target::Target, + witness::{PartialWitness, WitnessWrite}, + }, + plonk::proof::ProofWithPublicInputsTarget, +}; +use plonky2_ecgfp5::gadgets::curve::CircuitBuilderEcGFp5; +use recursion_framework::circuit_builder::CircuitLogicWires; +use serde::{Deserialize, Serialize}; +use std::iter; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct NoResultsInChunkWires { + num_records: Target, + #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] + results_tree_hash: HashOutTarget, + offset_range_min: Target, + offset_range_max: Target, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct NoResultsInChunkCircuit { + /// Number of records in the results tree + pub(crate) num_records: F, + /// Hash of the results tree + pub(crate) results_tree_hash: HashOut, + /// Minimum offset range bound + pub(crate) offset_range_min: F, + /// Maximum offset range bound + pub(crate) offset_range_max: F, +} + +impl NoResultsInChunkCircuit { + pub fn build(b: &mut CBuilder) -> NoResultsInChunkWires { + let zero_u256 = b.zero_u256(); + let curve_zero = b.curve_zero(); + let one = b.one(); + let ttrue = b._true(); + + let num_records = b.add_virtual_target(); + let results_tree_hash = b.add_virtual_hash(); + let [offset_range_min, offset_range_max] = b.add_virtual_target_arr(); + + // Ensure that the query is asking to retrieve results with an offset + // being greater than the overall number of results + let is_greater = greater_than(b, offset_range_min, num_records, 32); + b.connect(is_greater.target, ttrue.target); + + // Register the public inputs. + PublicInputs::new( + &results_tree_hash.to_targets(), + &zero_u256.to_targets(), + &zero_u256.to_targets(), + &zero_u256.to_targets(), + &[one; 2], + &[one], + &[num_records], + &[offset_range_min], + &[offset_range_max], + &curve_zero.to_targets(), + ) + .register(b); + + NoResultsInChunkWires { + num_records, + results_tree_hash, + offset_range_min, + offset_range_max, + } + } + + fn assign(&self, pw: &mut PartialWitness, wires: &NoResultsInChunkWires) { + pw.set_target(wires.num_records, self.num_records); + pw.set_hash_target(wires.results_tree_hash, self.results_tree_hash); + pw.set_target(wires.offset_range_min, self.offset_range_min); + pw.set_target(wires.offset_range_max, self.offset_range_max); + } +} + +/// Verified proof number = 0 +pub(crate) const NUM_VERIFIED_PROOFS: usize = 0; + +impl CircuitLogicWires for NoResultsInChunkWires { + type CircuitBuilderParams = (); + type Inputs = NoResultsInChunkCircuit; + + const NUM_PUBLIC_INPUTS: usize = PublicInputs::::total_len(); + + fn circuit_logic( + builder: &mut CBuilder, + _verified_proofs: [&ProofWithPublicInputsTarget; NUM_VERIFIED_PROOFS], + _builder_parameters: Self::CircuitBuilderParams, + ) -> Self { + Self::Inputs::build(builder) + } + + fn assign_input(&self, inputs: Self::Inputs, pw: &mut PartialWitness) -> anyhow::Result<()> { + inputs.assign(pw, self); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use mp2_common::C; + use mp2_test::{ + circuit::{run_circuit, UserCircuit}, + utils::gen_random_field_hash, + }; + use plonky2::field::types::Field; + use plonky2_ecgfp5::curve::curve::WeierstrassPoint; + use rand::{thread_rng, Rng}; + + impl UserCircuit for NoResultsInChunkCircuit { + type Wires = NoResultsInChunkWires; + + fn build(b: &mut CBuilder) -> Self::Wires { + NoResultsInChunkCircuit::build(b) + } + + fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { + self.assign(pw, wires); + } + } + + #[test] + fn test_no_results_in_chunk_circuit() { + // Construct the witness. + let mut rng = thread_rng(); + let num_records = F::from_canonical_u32(rng.gen()); + let results_tree_hash = gen_random_field_hash(); + let offset_range_min = num_records + F::ONE; + let offset_range_max = offset_range_min + F::from_canonical_u32(rng.gen()); + + // Construct the circuit. + let test_circuit = NoResultsInChunkCircuit { + num_records, + results_tree_hash, + offset_range_min, + offset_range_max, + }; + + // Proof for the test circuit. + let proof = run_circuit::(test_circuit); + let pi = PublicInputs::from_slice(&proof.public_inputs); + + // Check the public inputs. + // Tree hash + assert_eq!(pi.tree_hash(), results_tree_hash); + + // Min value + assert_eq!(pi.min_value(), U256::ZERO); + + // Max value + assert_eq!(pi.max_value(), U256::ZERO); + + // Primary index value + assert_eq!(pi.primary_index_value(), U256::ZERO); + + // Index ids + assert_eq!(pi.index_ids(), [F::ONE; 2]); + + // Min counter + assert_eq!(pi.min_counter(), F::ONE); + + // Max counter + assert_eq!(pi.max_counter(), num_records); + + // Offset range min + assert_eq!(pi.offset_range_min(), offset_range_min); + + // Offset range max + assert_eq!(pi.offset_range_max(), offset_range_max); + + // Accumulator + assert_eq!(pi.accumulator(), WeierstrassPoint::NEUTRAL); + } +} diff --git a/verifiable-db/src/results_tree/extraction/partial_node.rs b/verifiable-db/src/results_tree/extraction/partial_node.rs new file mode 100644 index 000000000..31d963189 --- /dev/null +++ b/verifiable-db/src/results_tree/extraction/partial_node.rs @@ -0,0 +1,415 @@ +use crate::results_tree::extraction::PublicInputs; +use anyhow::Result; +use mp2_common::{ + group_hashing::CircuitBuilderGroupHashing, + hash::hash_maybe_first, + poseidon::H, + public_inputs::PublicInputCommon, + serialization::{deserialize, serialize}, + types::CBuilder, + u256::CircuitBuilderU256, + utils::{greater_than, less_than, ToTargets}, + D, F, +}; +use plonky2::{ + iop::{ + target::{BoolTarget, Target}, + witness::{PartialWitness, WitnessWrite}, + }, + plonk::proof::ProofWithPublicInputsTarget, +}; +use recursion_framework::circuit_builder::CircuitLogicWires; +use serde::{Deserialize, Serialize}; +use std::iter; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct PartialNodeWires { + #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] + is_left_child: BoolTarget, + #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] + is_rows_tree: BoolTarget, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct PartialNodeCircuit { + /// Boolean flag specifying whether the included child is the left child or not + pub(crate) is_left_child: bool, + /// Boolean flag specifying whether the current node is a node + /// of a rows tree or of the index tree + pub(crate) is_rows_tree: bool, +} + +impl PartialNodeCircuit { + pub fn build( + b: &mut CBuilder, + subtree_proof: &PublicInputs, + included_chid_proof: &PublicInputs, + excluded_child_proof: &PublicInputs, + ) -> PartialNodeWires { + let one = b.one(); + + let [is_left_child, is_rows_tree] = [0; 2].map(|_| b.add_virtual_bool_target_safe()); + + let column_id = b.select( + is_rows_tree, + subtree_proof.index_ids_target()[1], + subtree_proof.index_ids_target()[0], + ); + let node_value = b.select_u256( + is_rows_tree, + &subtree_proof.min_value_target(), + &subtree_proof.primary_index_value_target(), + ); + let node_min = b.select_u256( + is_left_child, + &included_chid_proof.min_value_target(), + &excluded_child_proof.min_value_target(), + ); + let node_max = b.select_u256( + is_left_child, + &excluded_child_proof.max_value_target(), + &included_chid_proof.max_value_target(), + ); + + // Compute the node hash: + // H(left_hash||right_hash||node_min||node_max||column_id||node_value||pR.H) + let rest: Vec<_> = node_min + .to_targets() + .into_iter() + .chain(node_max.to_targets()) + .chain(iter::once(column_id)) + .chain(node_value.to_targets()) + .chain(subtree_proof.tree_hash_target().elements) + .collect(); + + let node_hash = hash_maybe_first( + b, + is_left_child, + excluded_child_proof.tree_hash_target().elements, + included_chid_proof.tree_hash_target().elements, + &rest, + ); + + // Ensure the proofs in the same record subtree are employing the same value + // of the indexed item + // is_rows_tree == (is_rows_tree AND (pR.I == pI.I)) + let is_equal = b.is_equal_u256( + &subtree_proof.primary_index_value_target(), + &included_chid_proof.primary_index_value_target(), + ); + let condition = b.and(is_equal, is_rows_tree); + b.connect(condition.target, is_rows_tree.target); + + // Enforce consistency of counters + let max_left = b.select( + is_left_child, + included_chid_proof.max_counter_target(), + excluded_child_proof.max_counter_target(), + ); + let min_right = b.select( + is_left_child, + excluded_child_proof.min_counter_target(), + included_chid_proof.min_counter_target(), + ); + // Verifying proof guarantees: + // If the excluded child has N rows in its subtree, + // then pC.max_counter - pC.min_counter == N + // assert max_left + 1 == pR.min_counter + let left_plus_one = b.add(max_left, one); + b.connect(left_plus_one, subtree_proof.min_counter_target()); + // assert pR.max_counter + 1 == min_right + let max_cnt_plus_one = b.add(subtree_proof.max_counter_target(), one); + b.connect(max_cnt_plus_one, min_right); + + // Ensure that the subtree rooted in the sibling of the included child + // contains only records outside of [query_min; query_max] range + // left == (left AND (pC.min_counter > offset_range_max)) + let is_greater = greater_than( + b, + excluded_child_proof.min_counter_target(), + subtree_proof.offset_range_max_target(), + 32, + ); + let is_greater = b.and(is_greater, is_left_child); + b.connect(is_greater.target, is_left_child.target); + // NOT(left) == (NOT(left) AND( pC.max_counter < offset_range_min)) + let is_right_child = b.not(is_left_child); + let is_less = less_than( + b, + excluded_child_proof.max_counter_target(), + subtree_proof.offset_range_min_target(), + 32, + ); + let is_less = b.and(is_less, is_right_child); + b.connect(is_less.target, is_right_child.target); + + // Compute `min_counter` and `max_counter` for current node + let min_counter = b.select( + is_left_child, + included_chid_proof.min_counter_target(), + excluded_child_proof.min_counter_target(), + ); + let max_counter = b.select( + is_left_child, + excluded_child_proof.max_counter_target(), + included_chid_proof.max_counter_target(), + ); + + // pR.D + pI.D + let accumulator = b.add_curve_point(&[ + subtree_proof.accumulator_target(), + included_chid_proof.accumulator_target(), + ]); + + // Register the public inputs. + PublicInputs::new( + &node_hash.to_targets(), + &node_min.to_targets(), + &node_max.to_targets(), + subtree_proof.to_primary_index_value_raw(), + subtree_proof.to_index_ids_raw(), + &[min_counter], + &[max_counter], + &[*subtree_proof.to_offset_range_min_raw()], + &[*subtree_proof.to_offset_range_max_raw()], + &accumulator.to_targets(), + ) + .register(b); + + PartialNodeWires { + is_left_child, + is_rows_tree, + } + } + + fn assign(&self, pw: &mut PartialWitness, wires: &PartialNodeWires) { + pw.set_bool_target(wires.is_left_child, self.is_left_child); + pw.set_bool_target(wires.is_rows_tree, self.is_rows_tree); + } +} + +/// Subtree proof number = 1, child proof number = 2 +pub(crate) const NUM_VERIFIED_PROOFS: usize = 3; + +impl CircuitLogicWires for PartialNodeWires { + type CircuitBuilderParams = (); + type Inputs = PartialNodeCircuit; + + const NUM_PUBLIC_INPUTS: usize = PublicInputs::::total_len(); + + fn circuit_logic( + builder: &mut CBuilder, + verified_proofs: [&ProofWithPublicInputsTarget; NUM_VERIFIED_PROOFS], + _builder_parameters: Self::CircuitBuilderParams, + ) -> Self { + // The first one is the subtree proof, and the remainings are child proofs. + let [subtree_proof, included_child_proof, excluded_child_proof] = + verified_proofs.map(|p| PublicInputs::from_slice(&p.public_inputs)); + + Self::Inputs::build( + builder, + &subtree_proof, + &included_child_proof, + &excluded_child_proof, + ) + } + + fn assign_input(&self, inputs: Self::Inputs, pw: &mut PartialWitness) -> Result<()> { + inputs.assign(pw, self); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::results_tree::extraction::{ + tests::{random_results_extraction_public_inputs, unify_child_proof, unify_subtree_proof}, + PI_LEN, + }; + use mp2_common::{group_hashing::add_weierstrass_point, utils::ToFields, C}; + use mp2_test::circuit::{run_circuit, UserCircuit}; + use plonky2::plonk::config::Hasher; + use std::array; + + #[derive(Clone, Debug)] + struct TestPartialNodeCircuit<'a> { + c: PartialNodeCircuit, + subtree_proof: &'a [F], + included_child_proof: &'a [F], + excluded_child_proof: &'a [F], + } + + impl<'a> UserCircuit for TestPartialNodeCircuit<'a> { + // Circuit wires + subtree proof + included child proof + excluded child proof + type Wires = (PartialNodeWires, Vec, Vec, Vec); + + fn build(b: &mut CBuilder) -> Self::Wires { + let proofs = array::from_fn(|_| b.add_virtual_target_arr::<{ PI_LEN }>().to_vec()); + + let [subtree_pi, included_child_pi, excluded_child_pi] = + array::from_fn(|i| PublicInputs::::from_slice(&proofs[i])); + + let wires = + PartialNodeCircuit::build(b, &subtree_pi, &included_child_pi, &excluded_child_pi); + + let [subtree_proof, included_child_proof, excluded_child_proof] = proofs; + + ( + wires, + subtree_proof, + included_child_proof, + excluded_child_proof, + ) + } + + fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { + self.c.assign(pw, &wires.0); + pw.set_target_arr(&wires.1, self.subtree_proof); + pw.set_target_arr(&wires.2, self.included_child_proof); + pw.set_target_arr(&wires.3, self.excluded_child_proof); + } + } + + fn test_partial_node_circuit(is_rows_tree: bool, is_left_child: bool) { + let [mut subtree_proof, mut included_child_proof, mut excluded_child_proof] = + random_results_extraction_public_inputs::<3>(); + unify_subtree_proof(&mut subtree_proof); + let subtree_pi = PublicInputs::from_slice(&subtree_proof); + [ + (&mut included_child_proof, is_left_child), + (&mut excluded_child_proof, !is_left_child), + ] + .iter_mut() + .for_each(|(p, is_left_child)| { + unify_child_proof(p, is_rows_tree, *is_left_child, &subtree_pi) + }); + let included_child_pi = PublicInputs::from_slice(&included_child_proof); + let excluded_child_pi = PublicInputs::from_slice(&excluded_child_proof); + + // Construct the expected public input values. + let index_ids = subtree_pi.index_ids(); + let primary_index_value = subtree_pi.primary_index_value(); + let node_value = if is_rows_tree { + subtree_pi.min_value() + } else { + primary_index_value + }; + let node_min = if is_left_child { + included_child_pi.min_value() + } else { + excluded_child_pi.min_value() + }; + let node_max = if is_left_child { + excluded_child_pi.max_value() + } else { + included_child_pi.max_value() + }; + let min_counter = if is_left_child { + included_child_pi.min_counter() + } else { + excluded_child_pi.min_counter() + }; + let max_counter = if is_left_child { + excluded_child_pi.max_counter() + } else { + included_child_pi.max_counter() + }; + + // Construct the test circuit. + let test_circuit = TestPartialNodeCircuit { + c: PartialNodeCircuit { + is_left_child, + is_rows_tree, + }, + subtree_proof: &subtree_proof, + included_child_proof: &included_child_proof, + excluded_child_proof: &excluded_child_proof, + }; + + // Prove for the test circuit. + let proof = run_circuit::(test_circuit); + let pi = PublicInputs::from_slice(&proof.public_inputs); + + // Check the public inputs. + // Tree hash + { + let column_id = if is_rows_tree { + index_ids[1] + } else { + index_ids[0] + }; + let left_hash = if is_left_child { + included_child_pi.tree_hash() + } else { + excluded_child_pi.tree_hash() + }; + let right_hash = if is_left_child { + excluded_child_pi.tree_hash() + } else { + included_child_pi.tree_hash() + }; + let hash_inputs: Vec<_> = left_hash + .to_fields() + .into_iter() + .chain(right_hash.to_fields()) + .chain(node_min.to_fields()) + .chain(node_max.to_fields()) + .chain(iter::once(column_id)) + .chain(node_value.to_fields()) + .chain(subtree_pi.tree_hash().to_fields()) + .collect(); + let exp_hash = H::hash_no_pad(&hash_inputs); + assert_eq!(pi.tree_hash(), exp_hash); + } + + // Minimum value + assert_eq!(pi.min_value(), node_min); + + // Maximum value + assert_eq!(pi.max_value(), node_max); + + // Primary index value + assert_eq!(pi.primary_index_value(), subtree_pi.primary_index_value()); + + // Index IDs + assert_eq!(pi.index_ids(), index_ids); + + // Minimum counter + assert_eq!(pi.min_counter(), min_counter); + + // Maximum counter + assert_eq!(pi.max_counter(), max_counter); + + // Offset range min + assert_eq!(pi.offset_range_min(), subtree_pi.offset_range_min()); + + // Offset range max + assert_eq!(pi.offset_range_max(), subtree_pi.offset_range_max()); + + // Accumulator + { + let exp_accumulator = + add_weierstrass_point(&[subtree_pi.accumulator(), included_child_pi.accumulator()]); + + assert_eq!(pi.accumulator(), exp_accumulator); + } + } + + #[test] + fn test_partial_node_for_row_node_with_left_child() { + test_partial_node_circuit(true, true); + } + #[test] + fn test_partial_node_for_row_node_with_right_child() { + test_partial_node_circuit(true, false); + } + #[test] + fn test_partial_node_for_index_node_with_left_child() { + test_partial_node_circuit(false, true); + } + #[test] + fn test_partial_node_for_index_node_with_right_child() { + test_partial_node_circuit(false, false); + } +}