diff --git a/Cargo.lock b/Cargo.lock index e3c0e1dc4..ade5da76e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4574,13 +4574,13 @@ dependencies = [ "serde", "serde_json", "serde_plain", - "serde_with 3.9.0", + "serde_with 3.11.0", "sha2", "sha256", "starkyx", "tokio", "tracing", - "uuid 1.10.0", + "uuid 1.11.0", ] [[package]] @@ -5693,9 +5693,9 @@ dependencies = [ [[package]] name = "serde_with" -version = "3.9.0" +version = "3.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69cecfa94848272156ea67b2b1a53f20fc7bc638c4a46d2f8abde08f05f4b857" +checksum = "8e28bdad6db2b8340e449f7108f020b3b092e8583a9e3fb82713e1d4e71fe817" dependencies = [ "base64 0.22.1", "chrono", @@ -5705,7 +5705,7 @@ dependencies = [ "serde", "serde_derive", "serde_json", - "serde_with_macros 3.9.0", + "serde_with_macros 3.11.0", "time", ] @@ -5723,9 +5723,9 @@ dependencies = [ [[package]] name = "serde_with_macros" -version = "3.9.0" +version = "3.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8fee4991ef4f274617a51ad4af30519438dacb2f56ac773b08a1922ff743350" +checksum = "9d846214a9854ef724f3da161b426242d8de7c1fc7de2f89bb1efcb154dca79d" dependencies = [ "darling", "proc-macro2", @@ -6797,9 +6797,9 @@ dependencies = [ [[package]] name = "uuid" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314" +checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" dependencies = [ "serde", ] @@ -6830,6 +6830,7 @@ dependencies = [ "log", "mp2_common", "mp2_test", + "num", "plonky2", "plonky2_crypto", "plonky2_ecdsa", diff --git a/mp2-common/src/poseidon.rs b/mp2-common/src/poseidon.rs index b64755fd6..cf2a84eda 100644 --- a/mp2-common/src/poseidon.rs +++ b/mp2-common/src/poseidon.rs @@ -35,6 +35,9 @@ pub type H = >::Hasher; pub type P = >::AlgebraicPermutation; pub type HashPermutation = >::Permutation; +/// The result of hash to integer has 4 Uint32 (128 bits). +pub const HASH_TO_INT_LEN: usize = 4; + /// The flattened length of Poseidon hash, each original field is splitted from an /// Uint64 into two Uint32. pub const FLATTEN_POSEIDON_LEN: usize = NUM_HASH_OUT_ELTS * 2; diff --git a/mp2-common/src/utils.rs b/mp2-common/src/utils.rs index a09f60a90..76b3d6ec0 100644 --- a/mp2-common/src/utils.rs +++ b/mp2-common/src/utils.rs @@ -12,6 +12,7 @@ use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2::plonk::circuit_data::VerifierCircuitData; use plonky2::plonk::config::{GenericConfig, GenericHashOut, Hasher}; use plonky2_crypto::u32::arithmetic_u32::U32Target; +use plonky2_ecdsa::gadgets::biguint::BigUintTarget; use plonky2_ecgfp5::gadgets::{base_field::QuinticExtensionTarget, curve::CurveTarget}; use sha3::Digest; @@ -439,6 +440,12 @@ impl ToTargets for &[Target] { } } +impl ToTargets for BigUintTarget { + fn to_targets(&self) -> Vec { + self.limbs.iter().map(|u| u.0).collect() + } +} + pub trait TargetsConnector { fn connect_targets(&mut self, e1: T, e2: T); fn is_equal_targets(&mut self, e1: T, e2: T) -> BoolTarget; diff --git a/verifiable-db/Cargo.toml b/verifiable-db/Cargo.toml index 28b41c55e..6e27d0299 100644 --- a/verifiable-db/Cargo.toml +++ b/verifiable-db/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [dependencies] mp2_common = { path = "../mp2-common" } +num.workspace = true plonky2_crypto.workspace = true recursion_framework = { path = "../recursion-framework" } ryhope = { path = "../ryhope" } @@ -29,4 +30,5 @@ serial_test.workspace = true tokio.workspace = true [features] -original_poseidon = ["mp2_common/original_poseidon"] \ No newline at end of file +original_poseidon = ["mp2_common/original_poseidon"] + diff --git a/verifiable-db/src/block_tree/api.rs b/verifiable-db/src/block_tree/api.rs index 33f5c6e54..023494840 100644 --- a/verifiable-db/src/block_tree/api.rs +++ b/verifiable-db/src/block_tree/api.rs @@ -294,7 +294,7 @@ mod tests { use std::iter; const EXTRACTION_IO_LEN: usize = extraction::test::PublicInputs::::TOTAL_LEN; - const ROWS_TREE_IO_LEN: usize = row_tree::PublicInputs::::TOTAL_LEN; + const ROWS_TREE_IO_LEN: usize = row_tree::PublicInputs::::total_len(); struct TestBuilder where diff --git a/verifiable-db/src/block_tree/leaf.rs b/verifiable-db/src/block_tree/leaf.rs index d2e1e055c..b6966047a 100644 --- a/verifiable-db/src/block_tree/leaf.rs +++ b/verifiable-db/src/block_tree/leaf.rs @@ -2,7 +2,7 @@ //! an existing node (or if there is no existing node, which happens for the //! first block number). -use super::{compute_index_digest, public_inputs::PublicInputs}; +use super::{compute_final_digest, compute_index_digest, public_inputs::PublicInputs}; use crate::{ extraction::{ExtractionPI, ExtractionPIWrap}, row_tree, @@ -10,7 +10,6 @@ use crate::{ use anyhow::Result; use mp2_common::{ default_config, - group_hashing::CircuitBuilderGroupHashing, poseidon::{empty_poseidon_hash, H}, proof::ProofWithVK, public_inputs::PublicInputCommon, @@ -55,15 +54,12 @@ impl LeafCircuit { let extraction_pi = E::PI::from_slice(extraction_pi); let rows_tree_pi = row_tree::PublicInputs::::from_slice(rows_tree_pi); + let final_digest = compute_final_digest::(b, &extraction_pi, &rows_tree_pi); // in our case, the extraction proofs extracts from the blockchain and sets // the block number as the primary index let index_value = extraction_pi.primary_index_value(); - // Enforce that the data extracted from the blockchain is the same as the data - // employed to build the rows tree for this node. - b.connect_curve_points(extraction_pi.value_set_digest(), rows_tree_pi.rows_digest()); - // Compute the hash of table metadata, to be exposed as public input to prove to // the verifier that we extracted the correct storage slots and we place the data // in the expected columns of the constructed tree; we add also the identifier @@ -82,7 +78,7 @@ impl LeafCircuit { let inputs = iter::once(index_identifier) .chain(index_value.iter().cloned()) .collect(); - let node_digest = compute_index_digest(b, inputs, rows_tree_pi.rows_digest()); + let node_digest = compute_index_digest(b, inputs, final_digest); // Compute hash of the inserted node // node_min = block_number @@ -103,7 +99,7 @@ impl LeafCircuit { // check that the rows tree built is for a merged table iff we extract data from MPT for a merged table b.connect( - rows_tree_pi.is_merge_case().target, + rows_tree_pi.merge_flag_target().target, extraction_pi.is_merge_case().target, ); @@ -170,7 +166,7 @@ where _verified_proofs: [&ProofWithPublicInputsTarget; 0], builder_parameters: Self::CircuitBuilderParams, ) -> Self { - const ROWS_TREE_IO: usize = row_tree::PublicInputs::::TOTAL_LEN; + const ROWS_TREE_IO: usize = row_tree::PublicInputs::::total_len(); let extraction_verifier = RecursiveCircuitsVerifierGagdet::::new( @@ -262,7 +258,7 @@ pub mod tests { let hash = H::hash_no_pad(&inputs); let int = hash_to_int_value(hash); let scalar = Scalar::from_noncanonical_biguint(int); - let point = rows_tree_pi.rows_digest_field(); + let point = rows_tree_pi.individual_digest_point(); let point = weierstrass_to_point(&point); point * scalar } @@ -279,7 +275,7 @@ pub mod tests { fn build(b: &mut CBuilder) -> Self::Wires { let extraction_pi = b.add_virtual_targets(TestPITargets::TOTAL_LEN); - let rows_tree_pi = b.add_virtual_targets(row_tree::PublicInputs::::TOTAL_LEN); + let rows_tree_pi = b.add_virtual_targets(row_tree::PublicInputs::::total_len()); let leaf_wires = LeafCircuit::build::(b, &extraction_pi, &rows_tree_pi); @@ -292,7 +288,7 @@ pub mod tests { assert_eq!(wires.1.len(), TestPITargets::TOTAL_LEN); pw.set_target_arr(&wires.1, self.extraction_pi); - assert_eq!(wires.2.len(), row_tree::PublicInputs::::TOTAL_LEN); + assert_eq!(wires.2.len(), row_tree::PublicInputs::::total_len()); pw.set_target_arr(&wires.2, self.rows_tree_pi); } } diff --git a/verifiable-db/src/block_tree/mod.rs b/verifiable-db/src/block_tree/mod.rs index 34f172404..6a65418fa 100644 --- a/verifiable-db/src/block_tree/mod.rs +++ b/verifiable-db/src/block_tree/mod.rs @@ -4,9 +4,18 @@ mod membership; mod parent; mod public_inputs; +use crate::{ + extraction::{ExtractionPI, ExtractionPIWrap}, + row_tree, +}; pub use api::{CircuitInput, PublicParameters}; -use mp2_common::{poseidon::hash_to_int_target, CHasher, D, F}; -use plonky2::{iop::target::Target, plonk::circuit_builder::CircuitBuilder}; +use mp2_common::{ + group_hashing::{circuit_hashed_scalar_mul, CircuitBuilderGroupHashing}, + poseidon::hash_to_int_target, + types::CBuilder, + CHasher, D, F, +}; +use plonky2::{field::types::Field, iop::target::Target, plonk::circuit_builder::CircuitBuilder}; use plonky2_ecdsa::gadgets::nonnative::CircuitBuilderNonNative; use plonky2_ecgfp5::gadgets::curve::{CircuitBuilderEcGFp5, CurveTarget}; @@ -25,10 +34,62 @@ pub(crate) fn compute_index_digest( b.curve_scalar_mul(base, &scalar) } +/// Compute the final digest. +pub(crate) fn compute_final_digest<'a, E>( + b: &mut CBuilder, + extraction_pi: &E::PI<'a>, + rows_tree_pi: &row_tree::PublicInputs, +) -> CurveTarget +where + E: ExtractionPIWrap, +{ + // Compute the final row digest from rows_tree_proof for merge case: + // multiplier_digest = rows_tree_proof.row_id_multiplier * rows_tree_proof.multiplier_vd + let multiplier_vd = rows_tree_pi.multiplier_digest_target(); + let row_id_multiplier = b.biguint_to_nonnative(&rows_tree_pi.row_id_multiplier_target()); + let multiplier_digest = b.curve_scalar_mul(multiplier_vd, &row_id_multiplier); + // rows_digest_merge = multiplier_digest * rows_tree_proof.DR + let individual_digest = rows_tree_pi.individual_digest_target(); + let rows_digest_merge = circuit_hashed_scalar_mul(b, multiplier_digest, individual_digest); + + // Choose the final row digest depending on whether we are in merge case or not: + // final_digest = extraction_proof.is_merge ? rows_digest_merge : rows_tree_proof.DR + let final_digest = b.curve_select( + extraction_pi.is_merge_case(), + rows_digest_merge, + individual_digest, + ); + + // Enforce that the data extracted from the blockchain is the same as the data + // employed to build the rows tree for this node: + // assert final_digest == extraction_proof.DV + b.connect_curve_points(final_digest, extraction_pi.value_set_digest()); + + // Enforce that if we aren't in merge case, then no cells were accumulated in + // multiplier digest: + // assert extraction_proof.is_merge or rows_tree_proof.multiplier_vd != 0 + // => (1 - is_merge) * is_multiplier_vd_zero == false + let ffalse = b._false(); + let curve_zero = b.curve_zero(); + let is_multiplier_vd_zero = b + .curve_eq(rows_tree_pi.multiplier_digest_target(), curve_zero) + .target; + let should_be_false = b.arithmetic( + F::NEG_ONE, + F::ONE, + extraction_pi.is_merge_case().target, + is_multiplier_vd_zero, + is_multiplier_vd_zero, + ); + b.connect(should_be_false, ffalse.target); + + final_digest +} + #[cfg(test)] pub(crate) mod tests { use alloy::primitives::U256; - use mp2_common::{keccak::PACKED_HASH_LEN, utils::ToFields, F}; + use mp2_common::{keccak::PACKED_HASH_LEN, poseidon::HASH_TO_INT_LEN, utils::ToFields, F}; use mp2_test::utils::random_vector; use plonky2::{ field::types::{Field, Sample}, @@ -79,7 +140,19 @@ pub(crate) mod tests { let h = random_vector::(NUM_HASH_OUT_ELTS).to_fields(); let [min, max] = [0; 2].map(|_| U256::from_limbs(rng.gen::<[u64; 4]>()).to_fields()); let is_merge = [F::from_canonical_usize(is_merge_case as usize)]; - row_tree::PublicInputs::new(&h, row_digest, &min, &max, &is_merge).to_vec() + let multiplier_digest = Point::sample(rng).to_weierstrass().to_fields(); + let row_id_multiplier = random_vector::(HASH_TO_INT_LEN).to_fields(); + + row_tree::PublicInputs::new( + &h, + row_digest, + &min, + &max, + &is_merge, + &multiplier_digest, + &row_id_multiplier, + ) + .to_vec() } /// Generate a random extraction public inputs. diff --git a/verifiable-db/src/block_tree/parent.rs b/verifiable-db/src/block_tree/parent.rs index ca7b8af66..68988e87f 100644 --- a/verifiable-db/src/block_tree/parent.rs +++ b/verifiable-db/src/block_tree/parent.rs @@ -1,7 +1,7 @@ //! This circuit is employed when the new node is inserted as parent of an existing node, //! referred to as old node. -use super::{compute_index_digest, public_inputs::PublicInputs}; +use super::{compute_final_digest, compute_index_digest, public_inputs::PublicInputs}; use crate::{ extraction::{ExtractionPI, ExtractionPIWrap}, row_tree, @@ -10,7 +10,6 @@ use alloy::primitives::U256; use anyhow::Result; use mp2_common::{ default_config, - group_hashing::CircuitBuilderGroupHashing, poseidon::{empty_poseidon_hash, H}, proof::ProofWithVK, public_inputs::PublicInputCommon, @@ -84,13 +83,10 @@ impl ParentCircuit { let extraction_pi = E::PI::from_slice(extraction_pi); let rows_tree_pi = row_tree::PublicInputs::::from_slice(rows_tree_pi); + let final_digest = compute_final_digest::(b, &extraction_pi, &rows_tree_pi); let block_number = extraction_pi.primary_index_value(); - // Enforce that the data extracted from the blockchain is the same as the data - // employed to build the rows tree for this node. - b.connect_curve_points(extraction_pi.value_set_digest(), rows_tree_pi.rows_digest()); - // Compute the hash of table metadata, to be exposed as public input to prove to // the verifier that we extracted the correct storage slots and we place the data // in the expected columns of the constructed tree; we add also the identifier @@ -110,7 +106,7 @@ impl ParentCircuit { let inputs = iter::once(index_identifier) .chain(block_number.iter().cloned()) .collect(); - let node_digest = compute_index_digest(b, inputs, rows_tree_pi.rows_digest()); + let node_digest = compute_index_digest(b, inputs, final_digest); // We recompute the hash of the old node to bind the `old_min` and `old_max` // values to the hash of the old tree. @@ -154,7 +150,7 @@ impl ParentCircuit { // check that the rows tree built is for a merged table iff we extract data from MPT for a merged table b.connect( - rows_tree_pi.is_merge_case().target, + rows_tree_pi.merge_flag_target().target, extraction_pi.is_merge_case().target, ); @@ -236,7 +232,7 @@ where _verified_proofs: [&ProofWithPublicInputsTarget; 0], builder_parameters: Self::CircuitBuilderParams, ) -> Self { - const ROWS_TREE_IO: usize = row_tree::PublicInputs::::TOTAL_LEN; + const ROWS_TREE_IO: usize = row_tree::PublicInputs::::total_len(); let extraction_verifier = RecursiveCircuitsVerifierGagdet::::new( @@ -315,7 +311,7 @@ mod tests { fn build(b: &mut CBuilder) -> Self::Wires { let extraction_pi = b.add_virtual_targets(TestPITargets::TOTAL_LEN); - let rows_tree_pi = b.add_virtual_targets(row_tree::PublicInputs::::TOTAL_LEN); + let rows_tree_pi = b.add_virtual_targets(row_tree::PublicInputs::::total_len()); let parent_wires = ParentCircuit::build::(b, &extraction_pi, &rows_tree_pi); @@ -329,7 +325,7 @@ mod tests { assert_eq!(wires.1.len(), TestPITargets::TOTAL_LEN); pw.set_target_arr(&wires.1, self.extraction_pi); - assert_eq!(wires.2.len(), row_tree::PublicInputs::::TOTAL_LEN); + assert_eq!(wires.2.len(), row_tree::PublicInputs::::total_len()); pw.set_target_arr(&wires.2, self.rows_tree_pi); } } diff --git a/verifiable-db/src/cells_tree/api.rs b/verifiable-db/src/cells_tree/api.rs index 1a6487fa6..8b7a84740 100644 --- a/verifiable-db/src/cells_tree/api.rs +++ b/verifiable-db/src/cells_tree/api.rs @@ -2,9 +2,9 @@ use super::{ empty_node::{EmptyNodeCircuit, EmptyNodeWires}, - full_node::{FullNodeCircuit, FullNodeWires}, + full_node::FullNodeWires, leaf::{LeafCircuit, LeafWires}, - partial_node::{PartialNodeCircuit, PartialNodeWires}, + partial_node::PartialNodeWires, public_inputs::PublicInputs, Cell, }; @@ -39,12 +39,13 @@ impl CircuitInput { /// Create a circuit input for proving a leaf node. /// It is not considered a multiplier column. Please use `leaf_multiplier` for registering a /// multiplier column. - pub fn leaf(identifier: u64, value: U256) -> Self { + pub fn leaf(identifier: u64, value: U256, mpt_metadata: HashOut) -> Self { CircuitInput::Leaf( Cell { identifier: F::from_canonical_u64(identifier), value, is_multiplier: false, + mpt_metadata, } .into(), ) @@ -52,12 +53,18 @@ impl CircuitInput { /// Create a circuit input for proving a leaf node whose value is considered as a multiplier /// depending on the boolean value. /// i.e. it means it's one of the repeated value amongst all the rows - pub fn leaf_multiplier(identifier: u64, value: U256, is_multiplier: bool) -> Self { + pub fn leaf_multiplier( + identifier: u64, + value: U256, + is_multiplier: bool, + mpt_metadata: HashOut, + ) -> Self { CircuitInput::Leaf( Cell { identifier: F::from_canonical_u64(identifier), value, is_multiplier, + mpt_metadata, } .into(), ) @@ -66,11 +73,17 @@ impl CircuitInput { /// Create a circuit input for proving a full node of 2 children. /// It is not considered a multiplier column. Please use `full_multiplier` for registering a /// multiplier column. - pub fn full(identifier: u64, value: U256, child_proofs: [Vec; 2]) -> Self { + pub fn full( + identifier: u64, + value: U256, + mpt_metadata: HashOut, + child_proofs: [Vec; 2], + ) -> Self { CircuitInput::FullNode(new_child_input( F::from_canonical_u64(identifier), value, false, + mpt_metadata, child_proofs.to_vec(), )) } @@ -80,23 +93,31 @@ impl CircuitInput { identifier: u64, value: U256, is_multiplier: bool, + mpt_metadata: HashOut, child_proofs: [Vec; 2], ) -> Self { CircuitInput::FullNode(new_child_input( F::from_canonical_u64(identifier), value, is_multiplier, + mpt_metadata, child_proofs.to_vec(), )) } /// Create a circuit input for proving a partial node of 1 child. /// It is not considered a multiplier column. Please use `partial_multiplier` for registering a /// multiplier column. - pub fn partial(identifier: u64, value: U256, child_proof: Vec) -> Self { + pub fn partial( + identifier: u64, + value: U256, + mpt_metadata: HashOut, + child_proof: Vec, + ) -> Self { CircuitInput::PartialNode(new_child_input( F::from_canonical_u64(identifier), value, false, + mpt_metadata, vec![child_proof], )) } @@ -104,12 +125,14 @@ impl CircuitInput { identifier: u64, value: U256, is_multiplier: bool, + mpt_metadata: HashOut, child_proof: Vec, ) -> Self { CircuitInput::PartialNode(new_child_input( F::from_canonical_u64(identifier), value, is_multiplier, + mpt_metadata, vec![child_proof], )) } @@ -120,6 +143,7 @@ fn new_child_input( identifier: F, value: U256, is_multiplier: bool, + mpt_metadata: HashOut, serialized_child_proofs: Vec>, ) -> ChildInput { ChildInput { @@ -127,6 +151,7 @@ fn new_child_input( identifier, value, is_multiplier, + mpt_metadata, }, serialized_child_proofs, } @@ -148,7 +173,7 @@ pub fn build_circuits_params() -> PublicParameters { PublicParameters::build() } -const NUM_IO: usize = PublicInputs::::TOTAL_LEN; +const NUM_IO: usize = PublicInputs::::total_len(); /// Number of circuits in the set /// 1 leaf + 1 full node + 1 partial node + 1 empty node @@ -246,8 +271,10 @@ impl PublicParameters { pub fn extract_hash_from_proof(proof: &[u8]) -> Result> { let p = ProofWithVK::deserialize(proof)?; - Ok(PublicInputs::from_slice(&p.proof.public_inputs).root_hash_hashout()) + Ok(PublicInputs::from_slice(&p.proof.public_inputs).node_hash()) } + +/* #[cfg(test)] mod tests { use super::*; @@ -452,3 +479,4 @@ mod tests { proof } } +*/ diff --git a/verifiable-db/src/cells_tree/empty_node.rs b/verifiable-db/src/cells_tree/empty_node.rs index d0d770b1f..b86013f3f 100644 --- a/verifiable-db/src/cells_tree/empty_node.rs +++ b/verifiable-db/src/cells_tree/empty_node.rs @@ -23,11 +23,11 @@ impl EmptyNodeCircuit { let empty_hash = empty_poseidon_hash(); let h = b.constant_hash(*empty_hash).elements; - // dc = CURVE_ZERO - let dc = b.curve_zero().to_targets(); + // CURVE_ZERO + let curve_zero = b.curve_zero().to_targets(); // Register the public inputs. - PublicInputs::new(&h, &dc, &dc).register(b); + PublicInputs::new(&h, &curve_zero, &curve_zero, &curve_zero, &curve_zero).register(b); EmptyNodeWires } @@ -39,7 +39,7 @@ impl CircuitLogicWires for EmptyNodeWires { type Inputs = EmptyNodeCircuit; - const NUM_PUBLIC_INPUTS: usize = PublicInputs::::TOTAL_LEN; + const NUM_PUBLIC_INPUTS: usize = PublicInputs::::total_len(); fn circuit_logic( builder: &mut CBuilder, @@ -54,6 +54,7 @@ impl CircuitLogicWires for EmptyNodeWires { } } +/* #[cfg(test)] mod tests { use super::*; @@ -87,3 +88,4 @@ mod tests { } } } +*/ diff --git a/verifiable-db/src/cells_tree/full_node.rs b/verifiable-db/src/cells_tree/full_node.rs index 3a5bb4f3f..983ec071f 100644 --- a/verifiable-db/src/cells_tree/full_node.rs +++ b/verifiable-db/src/cells_tree/full_node.rs @@ -4,8 +4,7 @@ use super::{public_inputs::PublicInputs, Cell, CellWire}; use anyhow::Result; use derive_more::{From, Into}; use mp2_common::{ - group_hashing::CircuitBuilderGroupHashing, public_inputs::PublicInputCommon, types::CBuilder, - u256::CircuitBuilderU256, utils::ToTargets, CHasher, D, F, + poseidon::H, public_inputs::PublicInputCommon, types::CBuilder, utils::ToTargets, D, F, }; use plonky2::{ iop::{target::Target, witness::PartialWitness}, @@ -13,7 +12,7 @@ use plonky2::{ }; use recursion_framework::circuit_builder::CircuitLogicWires; use serde::{Deserialize, Serialize}; -use std::{array, iter}; +use std::{array, iter::once}; #[derive(Clone, Debug, Serialize, Deserialize, Into, From)] pub struct FullNodeWires(CellWire); @@ -23,30 +22,35 @@ pub struct FullNodeCircuit(Cell); impl FullNodeCircuit { pub fn build(b: &mut CBuilder, child_proofs: [PublicInputs; 2]) -> FullNodeWires { + let [p1, p2] = child_proofs; + let cell = CellWire::new(b); + let metadata_digests = cell.split_metadata_digest(b); + let values_digests = cell.split_values_digest(b); + + let metadata_digests = metadata_digests.accumulate(b, &p1.split_metadata_digest_target()); + let metadata_digests = metadata_digests.accumulate(b, &p2.split_metadata_digest_target()); - // h = Poseidon(p1.H || p2.H || identifier || value) - let [p1_hash, p2_hash] = [0, 1].map(|i| child_proofs[i].node_hash()); - let inputs: Vec<_> = p1_hash - .elements - .iter() - .cloned() - .chain(p2_hash.elements) - .chain(iter::once(cell.identifier)) + let values_digests = values_digests.accumulate(b, &p1.split_values_digest_target()); + let values_digests = values_digests.accumulate(b, &p2.split_values_digest_target()); + + // H(p1.H || p2.H || identifier || value) + let inputs = p1 + .node_hash_target() + .into_iter() + .chain(p2.node_hash_target()) + .chain(once(cell.identifier)) .chain(cell.value.to_targets()) .collect(); - let h = b.hash_n_to_hash_no_pad::(inputs).elements; - - // digest_cell = p1.digest_cell + p2.digest_cell + D(identifier || value) - let split_digest = cell.split_digest(b); - let split_digest = split_digest.accumulate(b, &child_proofs[0].split_digest_target()); - let split_digest = split_digest.accumulate(b, &child_proofs[1].split_digest_target()); + let h = b.hash_n_to_hash_no_pad::(inputs); // Register the public inputs. PublicInputs::new( - &h, - &split_digest.individual.to_targets(), - &split_digest.multiplier.to_targets(), + &h.to_targets(), + &values_digests.individual.to_targets(), + &values_digests.multiplier.to_targets(), + &metadata_digests.individual.to_targets(), + &metadata_digests.multiplier.to_targets(), ) .register(b); @@ -65,7 +69,7 @@ impl CircuitLogicWires for FullNodeWires { type Inputs = FullNodeCircuit; - const NUM_PUBLIC_INPUTS: usize = PublicInputs::::TOTAL_LEN; + const NUM_PUBLIC_INPUTS: usize = PublicInputs::::total_len(); fn circuit_logic( builder: &mut CBuilder, @@ -83,6 +87,7 @@ impl CircuitLogicWires for FullNodeWires { } } +/* #[cfg(test)] mod tests { use super::*; @@ -195,3 +200,4 @@ mod tests { } } } +*/ diff --git a/verifiable-db/src/cells_tree/leaf.rs b/verifiable-db/src/cells_tree/leaf.rs index 72fefca14..180643fc1 100644 --- a/verifiable-db/src/cells_tree/leaf.rs +++ b/verifiable-db/src/cells_tree/leaf.rs @@ -3,8 +3,11 @@ use super::{public_inputs::PublicInputs, Cell, CellWire}; use derive_more::{From, Into}; use mp2_common::{ - poseidon::empty_poseidon_hash, public_inputs::PublicInputCommon, types::CBuilder, - utils::ToTargets, CHasher, D, F, + poseidon::{empty_poseidon_hash, H}, + public_inputs::PublicInputCommon, + types::CBuilder, + utils::ToTargets, + D, F, }; use plonky2::{ iop::witness::PartialWitness, @@ -12,7 +15,7 @@ use plonky2::{ }; use recursion_framework::circuit_builder::CircuitLogicWires; use serde::{Deserialize, Serialize}; -use std::iter; +use std::iter::once; #[derive(Clone, Debug, Serialize, Deserialize, From, Into)] pub struct LeafWires(CellWire); @@ -23,28 +26,27 @@ pub struct LeafCircuit(Cell); impl LeafCircuit { fn build(b: &mut CBuilder) -> LeafWires { let cell = CellWire::new(b); - - // h = Poseidon(Poseidon("") || Poseidon("") || identifier || value) - let empty_hash = empty_poseidon_hash(); - let empty_hash = b.constant_hash(*empty_hash); - let inputs: Vec<_> = empty_hash - .elements - .iter() - .cloned() - .chain(empty_hash.elements) - .chain(iter::once(cell.identifier)) + let metadata_digests = cell.split_metadata_digest(b); + let values_digests = cell.split_values_digest(b); + + // H(H("") || H("") || identifier || pack_u32(value)) + let empty_hash = b.constant_hash(*empty_poseidon_hash()).to_targets(); + let inputs = empty_hash + .clone() + .into_iter() + .chain(empty_hash) + .chain(once(cell.identifier)) .chain(cell.value.to_targets()) .collect(); - let h = b.hash_n_to_hash_no_pad::(inputs).elements; - - // digest_cell = D(identifier || value) - let split_digest = cell.split_digest(b); + let h = b.hash_n_to_hash_no_pad::(inputs); // Register the public inputs. PublicInputs::new( - &h, - &split_digest.individual.to_targets(), - &split_digest.multiplier.to_targets(), + &h.to_targets(), + &values_digests.individual.to_targets(), + &values_digests.multiplier.to_targets(), + &metadata_digests.individual.to_targets(), + &metadata_digests.multiplier.to_targets(), ) .register(b); @@ -63,7 +65,7 @@ impl CircuitLogicWires for LeafWires { type Inputs = LeafCircuit; - const NUM_PUBLIC_INPUTS: usize = PublicInputs::::TOTAL_LEN; + const NUM_PUBLIC_INPUTS: usize = PublicInputs::::total_len(); fn circuit_logic( builder: &mut CircuitBuilder, @@ -79,6 +81,7 @@ impl CircuitLogicWires for LeafWires { } } +/* #[cfg(test)] mod tests { use super::*; @@ -153,3 +156,4 @@ mod tests { } } } +*/ diff --git a/verifiable-db/src/cells_tree/mod.rs b/verifiable-db/src/cells_tree/mod.rs index af0e85846..a5ba1d0dc 100644 --- a/verifiable-db/src/cells_tree/mod.rs +++ b/verifiable-db/src/cells_tree/mod.rs @@ -5,11 +5,10 @@ mod leaf; mod partial_node; mod public_inputs; -use serde::{Deserialize, Serialize}; - use alloy::primitives::U256; pub use api::{build_circuits_params, extract_hash_from_proof, CircuitInput, PublicParameters}; use derive_more::Constructor; +use itertools::Itertools; use mp2_common::{ digest::{Digest, SplitDigestPoint, SplitDigestTarget}, group_hashing::{map_to_curve_point, CircuitBuilderGroupHashing}, @@ -17,15 +16,17 @@ use mp2_common::{ types::CBuilder, u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256}, utils::{ToFields, ToTargets}, - D, F, + F, }; +use serde::{Deserialize, Serialize}; +use std::iter::once; use plonky2::{ + hash::hash_types::{HashOut, HashOutTarget}, iop::{ target::{BoolTarget, Target}, witness::{PartialWitness, WitnessWrite}, }, - plonk::circuit_builder::CircuitBuilder, }; use plonky2_ecgfp5::gadgets::curve::CurveTarget; pub use public_inputs::PublicInputs; @@ -40,6 +41,8 @@ pub(crate) struct Cell { pub(crate) value: U256, /// is the secondary value should be included in multiplier digest or not pub(crate) is_multiplier: bool, + /// Hash of the metadata associated to this cell, as computed in MPT extraction circuits + pub(crate) mpt_metadata: HashOut, } impl Cell { @@ -47,29 +50,48 @@ impl Cell { pw.set_u256_target(&wires.value, self.value); pw.set_target(wires.identifier, self.identifier); pw.set_bool_target(wires.is_multiplier, self.is_multiplier); + pw.set_hash_target(wires.mpt_metadata, self.mpt_metadata); } - pub(crate) fn digest(&self) -> Digest { - map_to_curve_point(&self.to_fields()) + pub(crate) fn split_metadata_digest(&self) -> SplitDigestPoint { + let digest = self.metadata_digest(); + SplitDigestPoint::from_single_digest_point(digest, self.is_multiplier) } - pub(crate) fn split_digest(&self) -> SplitDigestPoint { - let digest = self.digest(); + pub(crate) fn split_values_digest(&self) -> SplitDigestPoint { + let digest = self.values_digest(); SplitDigestPoint::from_single_digest_point(digest, self.is_multiplier) } - pub(crate) fn split_and_accumulate_digest( + pub(crate) fn split_and_accumulate_metadata_digest( &self, child_digest: SplitDigestPoint, ) -> SplitDigestPoint { - let sd = self.split_digest(); - sd.accumulate(&child_digest) + let split_digest = self.split_metadata_digest(); + split_digest.accumulate(&child_digest) } -} - -impl ToFields for Cell { - fn to_fields(&self) -> Vec { - [self.identifier] + pub(crate) fn split_and_accumulate_values_digest( + &self, + child_digest: SplitDigestPoint, + ) -> SplitDigestPoint { + let split_digest = self.split_values_digest(); + split_digest.accumulate(&child_digest) + } + fn metadata_digest(&self) -> Digest { + // D(mpt_metadata || identifier) + let inputs = self + .mpt_metadata + .to_fields() .into_iter() + .chain(once(self.identifier)) + .collect_vec(); + + map_to_curve_point(&inputs) + } + fn values_digest(&self) -> Digest { + // D(identifier || pack_u32(value)) + let inputs = once(self.identifier) .chain(self.value.to_fields()) - .collect() + .collect_vec(); + + map_to_curve_point(&inputs) } } @@ -80,44 +102,60 @@ pub(crate) struct CellWire { pub(crate) identifier: Target, #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] pub(crate) is_multiplier: BoolTarget, + #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] + pub(crate) mpt_metadata: HashOutTarget, } impl CellWire { - pub(crate) fn new(b: &mut CircuitBuilder) -> Self { + pub(crate) fn new(b: &mut CBuilder) -> Self { Self { value: b.add_virtual_u256(), identifier: b.add_virtual_target(), is_multiplier: b.add_virtual_bool_target_safe(), + mpt_metadata: b.add_virtual_hash(), } } - /// Returns the digest of the cell - pub(crate) fn digest(&self, b: &mut CircuitBuilder) -> CurveTarget { - b.map_to_curve_point(&self.to_targets()) + pub(crate) fn split_metadata_digest(&self, b: &mut CBuilder) -> SplitDigestTarget { + let digest = self.metadata_digest(b); + SplitDigestTarget::from_single_digest_target(b, digest, self.is_multiplier) } - /// Returns the different digest, multiplier or individual - pub(crate) fn split_digest(&self, c: &mut CBuilder) -> SplitDigestTarget { - let d = self.digest(c); - SplitDigestTarget::from_single_digest_target(c, d, self.is_multiplier) + pub(crate) fn split_values_digest(&self, b: &mut CBuilder) -> SplitDigestTarget { + let digest = self.values_digest(b); + SplitDigestTarget::from_single_digest_target(b, digest, self.is_multiplier) } - /// Returns the split digest from this cell added with the one from the proof. - /// NOTE: it calls agains split_digest, so call that first if you need the individual - /// SplitDigestTarget - pub(crate) fn split_and_accumulate_digest( + pub(crate) fn split_and_accumulate_metadata_digest( &self, - c: &mut CBuilder, + b: &mut CBuilder, child_digest: SplitDigestTarget, ) -> SplitDigestTarget { - let sd = self.split_digest(c); - sd.accumulate(c, &child_digest) + let split_digest = self.split_metadata_digest(b); + split_digest.accumulate(b, &child_digest) } -} - -impl ToTargets for CellWire { - fn to_targets(&self) -> Vec { - self.identifier + pub(crate) fn split_and_accumulate_values_digest( + &self, + b: &mut CBuilder, + child_digest: SplitDigestTarget, + ) -> SplitDigestTarget { + let split_digest = self.split_values_digest(b); + split_digest.accumulate(b, &child_digest) + } + fn metadata_digest(&self, b: &mut CBuilder) -> CurveTarget { + // D(mpt_metadata || identifier) + let inputs = self + .mpt_metadata .to_targets() .into_iter() + .chain(once(self.identifier)) + .collect_vec(); + + b.map_to_curve_point(&inputs) + } + fn values_digest(&self, b: &mut CBuilder) -> CurveTarget { + // D(identifier || pack_u32(value)) + let inputs = once(self.identifier) .chain(self.value.to_targets()) - .collect::>() + .collect_vec(); + + b.map_to_curve_point(&inputs) } } diff --git a/verifiable-db/src/cells_tree/partial_node.rs b/verifiable-db/src/cells_tree/partial_node.rs index d9b5bf45b..2e5363bc3 100644 --- a/verifiable-db/src/cells_tree/partial_node.rs +++ b/verifiable-db/src/cells_tree/partial_node.rs @@ -1,29 +1,22 @@ //! Module handling the intermediate node with 1 child inside a cells tree use super::{public_inputs::PublicInputs, Cell, CellWire}; -use alloy::primitives::U256; use anyhow::Result; use derive_more::{From, Into}; use mp2_common::{ - group_hashing::CircuitBuilderGroupHashing, - poseidon::empty_poseidon_hash, + poseidon::{empty_poseidon_hash, H}, public_inputs::PublicInputCommon, types::CBuilder, - u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256}, utils::ToTargets, - CHasher, D, F, + D, F, }; use plonky2::{ - iop::{ - target::{BoolTarget, Target}, - witness::{PartialWitness, WitnessWrite}, - }, + iop::{target::Target, witness::PartialWitness}, plonk::proof::ProofWithPublicInputsTarget, }; -use plonky2_ecgfp5::gadgets::curve::CircuitBuilderEcGFp5; use recursion_framework::circuit_builder::CircuitLogicWires; use serde::{Deserialize, Serialize}; -use std::iter; +use std::iter::once; #[derive(Clone, Debug, Serialize, Deserialize, From, Into)] pub struct PartialNodeWires(CellWire); @@ -32,32 +25,38 @@ pub struct PartialNodeWires(CellWire); pub struct PartialNodeCircuit(Cell); impl PartialNodeCircuit { - pub fn build(b: &mut CBuilder, child_proof: PublicInputs) -> PartialNodeWires { + pub fn build(b: &mut CBuilder, p: PublicInputs) -> PartialNodeWires { let cell = CellWire::new(b); - - // h = Poseidon(p.H || Poseidon("") || identifier || value) - let child_hash = child_proof.node_hash(); - let empty_hash = empty_poseidon_hash(); - let empty_hash = b.constant_hash(*empty_hash); - let inputs: Vec<_> = child_hash - .elements - .iter() - .cloned() - .chain(empty_hash.elements) - .chain(iter::once(cell.identifier)) + let metadata_digests = cell.split_metadata_digest(b); + let values_digests = cell.split_values_digest(b); + + let metadata_digests = metadata_digests.accumulate(b, &p.split_metadata_digest_target()); + let values_digests = values_digests.accumulate(b, &p.split_values_digest_target()); + + /* + # since there is no sorting constraint among the nodes of this tree, to simplify + # the circuits, when we build a node with only one child, we can always place + # it as the left child + # NOTE: this is true only if we the "block" tree + h = H(p.H || H("") || identifier || value) + */ + let empty_hash = b.constant_hash(*empty_poseidon_hash()).to_targets(); + let inputs = p + .node_hash_target() + .into_iter() + .chain(empty_hash) + .chain(once(cell.identifier)) .chain(cell.value.to_targets()) .collect(); - let h = b.hash_n_to_hash_no_pad::(inputs).elements; - - // aggregate the digest of the child proof in the right digest - // digest_cell = p.digest_cell + D(identifier || value) - let split_digest = cell.split_and_accumulate_digest(b, child_proof.split_digest_target()); + let h = b.hash_n_to_hash_no_pad::(inputs); // Register the public inputs. PublicInputs::new( - &h, - &split_digest.individual.to_targets(), - &split_digest.multiplier.to_targets(), + &h.to_targets(), + &values_digests.individual.to_targets(), + &values_digests.multiplier.to_targets(), + &metadata_digests.individual.to_targets(), + &metadata_digests.multiplier.to_targets(), ) .register(b); @@ -76,7 +75,7 @@ impl CircuitLogicWires for PartialNodeWires { type Inputs = PartialNodeCircuit; - const NUM_PUBLIC_INPUTS: usize = PublicInputs::::TOTAL_LEN; + const NUM_PUBLIC_INPUTS: usize = PublicInputs::::total_len(); fn circuit_logic( builder: &mut CBuilder, @@ -93,6 +92,7 @@ impl CircuitLogicWires for PartialNodeWires { } } +/* #[cfg(test)] mod tests { use super::*; @@ -191,3 +191,4 @@ mod tests { } } } +*/ diff --git a/verifiable-db/src/cells_tree/public_inputs.rs b/verifiable-db/src/cells_tree/public_inputs.rs index a8ccfdafe..e2c2f5b3c 100644 --- a/verifiable-db/src/cells_tree/public_inputs.rs +++ b/verifiable-db/src/cells_tree/public_inputs.rs @@ -1,122 +1,225 @@ //! Public inputs for Cells Tree Construction circuits + use mp2_common::{ digest::{SplitDigestPoint, SplitDigestTarget}, group_hashing::weierstrass_to_point, public_inputs::{PublicInputCommon, PublicInputRange}, - types::{CBuilder, GFp, CURVE_TARGET_LEN}, + types::{CBuilder, CURVE_TARGET_LEN}, utils::{FromFields, FromTargets}, F, }; use plonky2::{ - hash::hash_types::{HashOut, HashOutTarget, NUM_HASH_OUT_ELTS}, + hash::hash_types::{HashOut, NUM_HASH_OUT_ELTS}, iop::target::Target, }; use plonky2_ecgfp5::{curve::curve::WeierstrassPoint, gadgets::curve::CurveTarget}; -use std::{array, fmt::Debug}; -// Cells Tree Construction public inputs: -// - `H : [4]F` : Poseidon hash of the subtree at this node -// - `DI : Digest[F]` : Cells digests accumulated up so far for INDIVIDUAL digest -// - `DM: Digest[F]` : Cells digests accumulated up so far for MULTIPLIER digest -const H_RANGE: PublicInputRange = 0..NUM_HASH_OUT_ELTS; -const DI_RANGE: PublicInputRange = H_RANGE.end..H_RANGE.end + CURVE_TARGET_LEN; -const DM_RANGE: PublicInputRange = DI_RANGE.end..DI_RANGE.end + CURVE_TARGET_LEN; +pub enum CellsTreePublicInputs { + // `H : F[4]` - Poseidon hash of the subtree at this node + NodeHash, + // - `individual_vd : Digest` - Cumulative digest of values of cells accumulated as individual + IndividualValuesDigest, + // - `multiplier_vd : Digest` - Cumulative digest of values of cells accumulated as multiplier + MultiplierValuesDigest, + // - `individual_md : Digest` - Cumulative digest of metadata of cells accumulated as individual + IndividualMetadataDigest, + // - `multiplier_md : Digest` - Cumulative digest of metadata of cells accumulated as multiplier + MultiplierMetadataDigest, +} /// Public inputs for Cells Tree Construction #[derive(Clone, Debug)] pub struct PublicInputs<'a, T> { pub(crate) h: &'a [T], - pub(crate) ind: &'a [T], - pub(crate) mul: &'a [T], + pub(crate) individual_vd: &'a [T], + pub(crate) multiplier_vd: &'a [T], + pub(crate) individual_md: &'a [T], + pub(crate) multiplier_md: &'a [T], } -impl<'a> PublicInputCommon for PublicInputs<'a, Target> { - const RANGES: &'static [PublicInputRange] = &[H_RANGE, DI_RANGE, DM_RANGE]; +const NUM_PUBLIC_INPUTS: usize = CellsTreePublicInputs::MultiplierMetadataDigest as usize + 1; - fn register_args(&self, cb: &mut CBuilder) { - cb.register_public_inputs(self.h); - cb.register_public_inputs(self.ind); - cb.register_public_inputs(self.mul); - } -} +impl<'a, T: Clone> PublicInputs<'a, T> { + const PI_RANGES: [PublicInputRange; NUM_PUBLIC_INPUTS] = [ + Self::to_range(CellsTreePublicInputs::NodeHash), + Self::to_range(CellsTreePublicInputs::IndividualValuesDigest), + Self::to_range(CellsTreePublicInputs::MultiplierValuesDigest), + Self::to_range(CellsTreePublicInputs::IndividualMetadataDigest), + Self::to_range(CellsTreePublicInputs::MultiplierMetadataDigest), + ]; -impl<'a> PublicInputs<'a, GFp> { - /// Get the cells digest point. - pub fn individual_digest_point(&self) -> WeierstrassPoint { - WeierstrassPoint::from_fields(self.ind) + const SIZES: [usize; NUM_PUBLIC_INPUTS] = [ + // Poseidon hash of the subtree at this node + NUM_HASH_OUT_ELTS, + // Cumulative digest of values of cells accumulated as individual + CURVE_TARGET_LEN, + // Cumulative digest of values of cells accumulated as multiplier + CURVE_TARGET_LEN, + // Cumulative digest of metadata of cells accumulated as individual + CURVE_TARGET_LEN, + // Cumulative digest of metadata of cells accumulated as multiplier + CURVE_TARGET_LEN, + ]; + + pub(crate) const fn to_range(pi: CellsTreePublicInputs) -> PublicInputRange { + let mut i = 0; + let mut offset = 0; + let pi_pos = pi as usize; + while i < pi_pos { + offset += Self::SIZES[i]; + i += 1; + } + offset..offset + Self::SIZES[pi_pos] } - pub fn multiplier_digest_point(&self) -> WeierstrassPoint { - WeierstrassPoint::from_fields(self.mul) + + pub(crate) const fn total_len() -> usize { + Self::to_range(CellsTreePublicInputs::MultiplierMetadataDigest).end } - pub fn split_digest_point(&self) -> SplitDigestPoint { - SplitDigestPoint { - individual: weierstrass_to_point(&self.individual_digest_point()), - multiplier: weierstrass_to_point(&self.multiplier_digest_point()), - } + + pub(crate) fn to_node_hash_raw(&self) -> &[T] { + self.h } -} -impl<'a> PublicInputs<'a, Target> { - /// Get the Poseidon hash of the subtree at this node. - pub fn node_hash(&self) -> HashOutTarget { - self.h.try_into().unwrap() + pub(crate) fn to_individual_values_digest_raw(&self) -> &[T] { + self.individual_vd } - /// Get the individual digest target. - pub fn individual_digest_target(&self) -> CurveTarget { - CurveTarget::from_targets(self.ind) + pub(crate) fn to_multiplier_values_digest_raw(&self) -> &[T] { + self.multiplier_vd } - /// Get the cells multiplier digest - pub fn multiplier_digest_target(&self) -> CurveTarget { - CurveTarget::from_targets(self.mul) + pub(crate) fn to_individual_metadata_digest_raw(&self) -> &[T] { + self.individual_md } - pub fn split_digest_target(&self) -> SplitDigestTarget { - SplitDigestTarget { - individual: self.individual_digest_target(), - multiplier: self.multiplier_digest_target(), - } + + pub(crate) fn to_multiplier_metadata_digest_raw(&self) -> &[T] { + self.multiplier_md } -} -impl<'a, T: Copy> PublicInputs<'a, T> { - /// Total length of the public inputs - pub(crate) const TOTAL_LEN: usize = DM_RANGE.end; + pub fn from_slice(input: &'a [T]) -> Self { + assert!( + input.len() >= Self::total_len(), + "Input slice too short to build cells tree public inputs, must be at least {} elements", + Self::total_len(), + ); - /// Create a new public inputs. - pub fn new(h: &'a [T], ind: &'a [T], mul: &'a [T]) -> Self { - Self { h, ind, mul } + Self { + h: &input[Self::PI_RANGES[0].clone()], + individual_vd: &input[Self::PI_RANGES[1].clone()], + multiplier_vd: &input[Self::PI_RANGES[2].clone()], + individual_md: &input[Self::PI_RANGES[3].clone()], + multiplier_md: &input[Self::PI_RANGES[4].clone()], + } } - /// Create from a slice. - pub fn from_slice(pi: &'a [T]) -> Self { - assert!(pi.len() >= Self::TOTAL_LEN); + pub fn new( + h: &'a [T], + individual_vd: &'a [T], + multiplier_vd: &'a [T], + individual_md: &'a [T], + multiplier_md: &'a [T], + ) -> Self { Self { - h: &pi[H_RANGE], - ind: &pi[DI_RANGE], - mul: &pi[DM_RANGE], + h, + individual_vd, + multiplier_vd, + individual_md, + multiplier_md, } } - /// Combine to a vector. pub fn to_vec(&self) -> Vec { self.h .iter() - .chain(self.ind) - .chain(self.mul) + .chain(self.individual_vd) + .chain(self.multiplier_vd) + .chain(self.individual_md) + .chain(self.multiplier_md) .cloned() .collect() } +} - pub fn h_raw(&self) -> &'a [T] { - self.h +impl<'a> PublicInputCommon for PublicInputs<'a, Target> { + const RANGES: &'static [PublicInputRange] = &Self::PI_RANGES; + + fn register_args(&self, cb: &mut CBuilder) { + cb.register_public_inputs(self.h); + cb.register_public_inputs(self.individual_vd); + cb.register_public_inputs(self.multiplier_vd); + cb.register_public_inputs(self.individual_md); + cb.register_public_inputs(self.multiplier_md); + } +} + +impl<'a> PublicInputs<'a, Target> { + pub fn node_hash_target(&self) -> [Target; NUM_HASH_OUT_ELTS] { + self.to_node_hash_raw().try_into().unwrap() + } + + pub fn individual_values_digest_target(&self) -> CurveTarget { + CurveTarget::from_targets(self.individual_vd) + } + + pub fn multiplier_values_digest_target(&self) -> CurveTarget { + CurveTarget::from_targets(self.multiplier_vd) + } + + pub fn individual_metadata_digest_target(&self) -> CurveTarget { + CurveTarget::from_targets(self.individual_md) + } + + pub fn multiplier_metadata_digest_target(&self) -> CurveTarget { + CurveTarget::from_targets(self.multiplier_md) + } + + pub fn split_values_digest_target(&self) -> SplitDigestTarget { + SplitDigestTarget { + individual: self.individual_values_digest_target(), + multiplier: self.multiplier_values_digest_target(), + } + } + + pub fn split_metadata_digest_target(&self) -> SplitDigestTarget { + SplitDigestTarget { + individual: self.individual_metadata_digest_target(), + multiplier: self.multiplier_metadata_digest_target(), + } } } impl<'a> PublicInputs<'a, F> { - pub fn root_hash_hashout(&self) -> HashOut { - HashOut { - elements: array::from_fn(|i| self.h[i]), + pub fn node_hash(&self) -> HashOut { + HashOut::from_partial(self.to_node_hash_raw()) + } + + pub fn individual_values_digest_point(&self) -> WeierstrassPoint { + WeierstrassPoint::from_fields(self.individual_vd) + } + + pub fn multiplier_values_digest_point(&self) -> WeierstrassPoint { + WeierstrassPoint::from_fields(self.multiplier_vd) + } + + pub fn individual_metadata_digest_point(&self) -> WeierstrassPoint { + WeierstrassPoint::from_fields(self.individual_md) + } + + pub fn multiplier_metadata_digest_point(&self) -> WeierstrassPoint { + WeierstrassPoint::from_fields(self.multiplier_md) + } + + pub fn split_values_digest_point(&self) -> SplitDigestPoint { + SplitDigestPoint { + individual: weierstrass_to_point(&self.individual_values_digest_point()), + multiplier: weierstrass_to_point(&self.multiplier_values_digest_point()), + } + } + + pub fn split_metadata_digest_point(&self) -> SplitDigestPoint { + SplitDigestPoint { + individual: weierstrass_to_point(&&self.individual_metadata_digest_point()), + multiplier: weierstrass_to_point(&self.multiplier_metadata_digest_point()), } } } @@ -138,20 +241,21 @@ mod tests { }; use plonky2_ecgfp5::curve::curve::Point; use rand::thread_rng; + use std::array; #[derive(Clone, Debug)] - struct TestPICircuit<'a> { + struct TestPublicInputs<'a> { exp_pi: &'a [F], } - impl<'a> UserCircuit for TestPICircuit<'a> { + impl<'a> UserCircuit for TestPublicInputs<'a> { type Wires = Vec; fn build(b: &mut CBuilder) -> Self::Wires { - let pi = b.add_virtual_targets(PublicInputs::::TOTAL_LEN); - PublicInputs::from_slice(&pi).register(b); + let exp_pi = b.add_virtual_targets(PublicInputs::::total_len()); + PublicInputs::from_slice(&exp_pi).register(b); - pi + exp_pi } fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { @@ -161,21 +265,46 @@ mod tests { #[test] fn test_cells_tree_public_inputs() { - let mut rng = thread_rng(); + let rng = &mut thread_rng(); // Prepare the public inputs. - let h = &random_vector::(NUM_HASH_OUT_ELTS).to_fields(); - let dc = &Point::sample(&mut rng).to_weierstrass().to_fields(); - let exp_pi = PublicInputs { - h, - ind: dc, - mul: dc, - }; + let h = random_vector::(NUM_HASH_OUT_ELTS).to_fields(); + let [individual_vd, multiplier_vd, individual_md, multiplier_md] = + array::from_fn(|_| Point::sample(rng).to_weierstrass().to_fields()); + let exp_pi = PublicInputs::new( + &h, + &individual_vd, + &multiplier_vd, + &individual_md, + &multiplier_md, + ); let exp_pi = &exp_pi.to_vec(); - let test_circuit = TestPICircuit { exp_pi }; + let test_circuit = TestPublicInputs { exp_pi }; let proof = run_circuit::(test_circuit); - assert_eq!(&proof.public_inputs, exp_pi); + + // Check if the public inputs are constructed correctly. + let pi = PublicInputs::from_slice(&proof.public_inputs); + assert_eq!( + &exp_pi[PublicInputs::::to_range(CellsTreePublicInputs::NodeHash)], + pi.to_node_hash_raw(), + ); + assert_eq!( + &exp_pi[PublicInputs::::to_range(CellsTreePublicInputs::IndividualValuesDigest)], + pi.to_individual_values_digest_raw(), + ); + assert_eq!( + &exp_pi[PublicInputs::::to_range(CellsTreePublicInputs::MultiplierValuesDigest)], + pi.to_multiplier_values_digest_raw(), + ); + assert_eq!( + &exp_pi[PublicInputs::::to_range(CellsTreePublicInputs::IndividualMetadataDigest)], + pi.to_individual_metadata_digest_raw(), + ); + assert_eq!( + &exp_pi[PublicInputs::::to_range(CellsTreePublicInputs::MultiplierMetadataDigest)], + pi.to_multiplier_metadata_digest_raw(), + ); } } diff --git a/verifiable-db/src/revelation/api.rs b/verifiable-db/src/revelation/api.rs index d0581135d..664bfb8d4 100644 --- a/verifiable-db/src/revelation/api.rs +++ b/verifiable-db/src/revelation/api.rs @@ -213,7 +213,7 @@ pub enum CircuitInput< [(); ROW_TREE_MAX_DEPTH - 1]:, [(); INDEX_TREE_MAX_DEPTH - 1]:, [(); MAX_NUM_ITEMS_PER_OUTPUT * MAX_NUM_OUTPUTS]:, - [(); { 2 * (MAX_NUM_PREDICATE_OPS + MAX_NUM_RESULT_OPS) }]:, + [(); 2 * (MAX_NUM_PREDICATE_OPS + MAX_NUM_RESULT_OPS)]:, { NoResultsTree { query_proof: ProofWithVK, diff --git a/verifiable-db/src/row_tree/api.rs b/verifiable-db/src/row_tree/api.rs index 7a17398b5..a56c24fe3 100644 --- a/verifiable-db/src/row_tree/api.rs +++ b/verifiable-db/src/row_tree/api.rs @@ -14,6 +14,7 @@ use super::{ full_node::{self, FullNodeCircuit}, leaf::{self, LeafCircuit}, partial_node::{self, PartialNodeCircuit}, + row::Row, PublicInputs, }; @@ -38,7 +39,7 @@ pub struct PublicParameters { row_set: RecursiveCircuits, } -const ROW_IO_LEN: usize = super::public_inputs::TOTAL_LEN; +const ROW_IO_LEN: usize = super::PublicInputs::::total_len(); impl PublicParameters { pub fn build(cells_set: &RecursiveCircuits) -> Self { @@ -180,18 +181,39 @@ pub enum CircuitInput { } impl CircuitInput { - pub fn leaf(identifier: u64, value: U256, cells_proof: Vec) -> Result { - Self::leaf_multiplier(identifier, value, false, cells_proof) + pub fn leaf( + identifier: u64, + value: U256, + mpt_metadata: HashOut, + row_unique_data: HashOut, + cells_proof: Vec, + ) -> Result { + Self::leaf_multiplier( + identifier, + value, + false, + mpt_metadata, + row_unique_data, + cells_proof, + ) } pub fn leaf_multiplier( identifier: u64, value: U256, is_multiplier: bool, + mpt_metadata: HashOut, + row_unique_data: HashOut, cells_proof: Vec, ) -> Result { - let circuit = Cell::new(F::from_canonical_u64(identifier), value, is_multiplier); + let cell = Cell::new( + F::from_canonical_u64(identifier), + value, + is_multiplier, + mpt_metadata, + ); + let row = Row::new(cell, row_unique_data); Ok(CircuitInput::Leaf { - witness: circuit.into(), + witness: row.into(), cells_proof, }) } @@ -199,6 +221,8 @@ impl CircuitInput { pub fn full( identifier: u64, value: U256, + mpt_metadata: HashOut, + row_unique_data: HashOut, left_proof: Vec, right_proof: Vec, cells_proof: Vec, @@ -207,6 +231,8 @@ impl CircuitInput { identifier, value, false, + mpt_metadata, + row_unique_data, left_proof, right_proof, cells_proof, @@ -216,13 +242,21 @@ impl CircuitInput { identifier: u64, value: U256, is_multiplier: bool, + mpt_metadata: HashOut, + row_unique_data: HashOut, left_proof: Vec, right_proof: Vec, cells_proof: Vec, ) -> Result { - let circuit = Cell::new(F::from_canonical_u64(identifier), value, is_multiplier); + let cell = Cell::new( + F::from_canonical_u64(identifier), + value, + is_multiplier, + mpt_metadata, + ); + let row = Row::new(cell, row_unique_data); Ok(CircuitInput::Full { - witness: circuit.into(), + witness: row.into(), left_proof, right_proof, cells_proof, @@ -232,6 +266,8 @@ impl CircuitInput { identifier: u64, value: U256, is_child_left: bool, + mpt_metadata: HashOut, + row_unique_data: HashOut, child_proof: Vec, cells_proof: Vec, ) -> Result { @@ -240,6 +276,8 @@ impl CircuitInput { value, false, is_child_left, + mpt_metadata, + row_unique_data, child_proof, cells_proof, ) @@ -249,11 +287,19 @@ impl CircuitInput { value: U256, is_multiplier: bool, is_child_left: bool, + mpt_metadata: HashOut, + row_unique_data: HashOut, child_proof: Vec, cells_proof: Vec, ) -> Result { - let tuple = Cell::new(F::from_canonical_u64(identifier), value, is_multiplier); - let witness = PartialNodeCircuit::new(tuple, is_child_left); + let cell = Cell::new( + F::from_canonical_u64(identifier), + value, + is_multiplier, + mpt_metadata, + ); + let row = Row::new(cell, row_unique_data); + let witness = PartialNodeCircuit::new(row, is_child_left); Ok(CircuitInput::Partial { witness, child_proof, @@ -264,9 +310,10 @@ impl CircuitInput { pub fn extract_hash_from_proof(proof: &[u8]) -> Result> { let p = ProofWithVK::deserialize(proof)?; - Ok(PublicInputs::from_slice(&p.proof.public_inputs).root_hash_hashout()) + Ok(PublicInputs::from_slice(&p.proof.public_inputs).root_hash()) } +/* #[cfg(test)] mod test { use crate::{cells_tree, row_tree::public_inputs::PublicInputs}; @@ -533,3 +580,4 @@ mod test { Ok(proof) } } +*/ diff --git a/verifiable-db/src/row_tree/full_node.rs b/verifiable-db/src/row_tree/full_node.rs index d672e6145..4bd983214 100644 --- a/verifiable-db/src/row_tree/full_node.rs +++ b/verifiable-db/src/row_tree/full_node.rs @@ -1,19 +1,15 @@ +use super::row::{Row, RowWire}; +use crate::cells_tree; use derive_more::{From, Into}; use mp2_common::{ - default_config, - group_hashing::{cond_circuit_hashed_scalar_mul, CircuitBuilderGroupHashing}, - poseidon::H, - proof::ProofWithVK, - public_inputs::PublicInputCommon, - u256::CircuitBuilderU256, - utils::ToTargets, - C, D, F, + default_config, group_hashing::CircuitBuilderGroupHashing, poseidon::H, proof::ProofWithVK, + public_inputs::PublicInputCommon, u256::CircuitBuilderU256, utils::ToTargets, C, D, F, }; use plonky2::{ iop::{target::Target, witness::PartialWitness}, plonk::{circuit_builder::CircuitBuilder, proof::ProofWithPublicInputsTarget}, }; -use plonky2_ecgfp5::gadgets::curve::CircuitBuilderEcGFp5; +use plonky2_ecdsa::gadgets::biguint::CircuitBuilderBiguint; use recursion_framework::{ circuit_builder::CircuitLogicWires, framework::{ @@ -21,19 +17,17 @@ use recursion_framework::{ }, }; use serde::{Deserialize, Serialize}; -use std::array::from_fn as create_array; - -use crate::cells_tree::{self, Cell, CellWire}; +use std::{array::from_fn as create_array, iter::once}; use super::public_inputs::PublicInputs; // Arity not strictly needed now but may be an easy way to increase performance // easily down the line with less recursion. Best to provide code which is easily // amenable to a different arity rather than hardcoding binary tree only #[derive(Clone, Debug, From, Into)] -pub struct FullNodeCircuit(Cell); +pub struct FullNodeCircuit(Row); #[derive(Clone, Serialize, Deserialize, From, Into)] -pub(crate) struct FullNodeWires(CellWire); +pub(crate) struct FullNodeWires(RowWire); impl FullNodeCircuit { pub(crate) fn build( @@ -42,52 +36,64 @@ impl FullNodeCircuit { right_pi: &[Target], cells_pi: &[Target], ) -> FullNodeWires { - let cells_pi = cells_tree::PublicInputs::from_slice(cells_pi); let min_child = PublicInputs::from_slice(left_pi); let max_child = PublicInputs::from_slice(right_pi); - let tuple = CellWire::new(b); - let node_min = min_child.min_value(); - let node_max = max_child.max_value(); + let cells_pi = cells_tree::PublicInputs::from_slice(cells_pi); + let row = RowWire::new(b); + let id = row.identifier(); + let value = row.value(); + let digest = row.digest(b, &cells_pi); + + // Check multiplier_vd and row_id_multiplier are the same as children proofs. + // assert multiplier_vd == p1.multiplier_vd == p2.multiplier_vd + b.connect_curve_points(digest.multiplier_vd, min_child.multiplier_digest_target()); + b.connect_curve_points(digest.multiplier_vd, max_child.multiplier_digest_target()); + // assert row_id_multiplier == p1.row_id_multiplier == p2.row_id_multiplier + b.connect_biguint( + &digest.row_id_multiplier, + &min_child.row_id_multiplier_target(), + ); + b.connect_biguint( + &digest.row_id_multiplier, + &max_child.row_id_multiplier_target(), + ); + + let node_min = min_child.min_value_target(); + let node_max = max_child.max_value_target(); // enforcing BST property let _true = b._true(); - let left_comparison = b.is_less_or_equal_than_u256(&min_child.max_value(), &tuple.value); - let right_comparison = b.is_less_or_equal_than_u256(&tuple.value, &max_child.min_value()); + let left_comparison = b.is_less_or_equal_than_u256(&min_child.max_value_target(), value); + let right_comparison = b.is_less_or_equal_than_u256(value, &max_child.min_value_target()); b.connect(left_comparison.target, _true.target); b.connect(right_comparison.target, _true.target); // Poseidon(p1.H || p2.H || node_min || node_max || index_id || index_value ||p.H)) as H let inputs = min_child - .root_hash() - .to_targets() + .root_hash_target() .iter() - .chain(max_child.root_hash().to_targets().iter()) + .chain(max_child.root_hash_target().iter()) .chain(node_min.to_targets().iter()) .chain(node_max.to_targets().iter()) - .chain(tuple.to_targets().iter()) - .chain(cells_pi.node_hash().to_targets().iter()) + .chain(once(&id)) + .chain(cells_pi.node_hash_target().iter()) .cloned() .collect::>(); let hash = b.hash_n_to_hash_no_pad::(inputs); - // final_digest = HashToInt(mul_digest) * D(ind_digest) + left.digest() + right.digest() - let split_digest = tuple.split_and_accumulate_digest(b, cells_pi.split_digest_target()); - let (row_digest, is_merge) = split_digest.cond_combine_to_row_digest(b); - - // add this row digest with the rest - let final_digest = b.curve_add(min_child.rows_digest(), max_child.rows_digest()); - let final_digest = b.curve_add(final_digest, row_digest); // assert `is_merge` is the same as the flags in children pis - b.connect(min_child.is_merge_case().target, is_merge.target); - b.connect(max_child.is_merge_case().target, is_merge.target); + b.connect(min_child.merge_flag_target().target, digest.is_merge.target); + b.connect(max_child.merge_flag_target().target, digest.is_merge.target); PublicInputs::new( &hash.to_targets(), - &final_digest.to_targets(), + &digest.individual_vd.to_targets(), + &digest.multiplier_vd.to_targets(), + &digest.row_id_multiplier.to_targets(), &node_min.to_targets(), &node_max.to_targets(), - &[is_merge.target], + &[digest.is_merge.target], ) .register(b); - FullNodeWires(tuple) + FullNodeWires(row) } fn assign(&self, pw: &mut PartialWitness, wires: &FullNodeWires) { self.0.assign_wires(pw, &wires.0); @@ -113,14 +119,14 @@ impl CircuitLogicWires for RecursiveFullWires { type Inputs = RecursiveFullInput; - const NUM_PUBLIC_INPUTS: usize = PublicInputs::::TOTAL_LEN; + const NUM_PUBLIC_INPUTS: usize = PublicInputs::::total_len(); fn circuit_logic( builder: &mut CircuitBuilder, verified_proofs: [&ProofWithPublicInputsTarget; NUM_CHILDREN], builder_parameters: Self::CircuitBuilderParams, ) -> Self { - const CELLS_IO: usize = cells_tree::PublicInputs::::TOTAL_LEN; + const CELLS_IO: usize = cells_tree::PublicInputs::::total_len(); let verifier_gadget = RecursiveCircuitsVerifierGagdet::::new( default_config(), &builder_parameters, @@ -144,6 +150,7 @@ impl CircuitLogicWires for RecursiveFullWires { } } +/* #[cfg(test)] pub(crate) mod test { @@ -185,9 +192,9 @@ pub(crate) mod test { type Wires = (FullNodeWires, Vec, Vec, Vec); fn build(c: &mut CircuitBuilder) -> Self::Wires { - let cells_pi = c.add_virtual_targets(cells_tree::PublicInputs::::TOTAL_LEN); - let left_pi = c.add_virtual_targets(PublicInputs::::TOTAL_LEN); - let right_pi = c.add_virtual_targets(PublicInputs::::TOTAL_LEN); + let cells_pi = c.add_virtual_targets(cells_tree::PublicInputs::::total_len()); + let left_pi = c.add_virtual_targets(PublicInputs::::total_len()); + let right_pi = c.add_virtual_targets(PublicInputs::::total_len()); ( FullNodeCircuit::build(c, &left_pi, &right_pi, &cells_pi), left_pi, @@ -302,3 +309,4 @@ pub(crate) mod test { test_row_tree_full_circuit(true, true); } } +*/ diff --git a/verifiable-db/src/row_tree/gadgets/row_digest_gadget.rs b/verifiable-db/src/row_tree/gadgets/row_digest_gadget.rs new file mode 100644 index 000000000..dafb2193d --- /dev/null +++ b/verifiable-db/src/row_tree/gadgets/row_digest_gadget.rs @@ -0,0 +1,227 @@ +use crate::cells_tree::{self, CellWire}; +use itertools::Itertools; +use mp2_common::{ + group_hashing::CircuitBuilderGroupHashing, + poseidon::{empty_poseidon_hash, hash_to_int_target, H, HASH_TO_INT_LEN}, + types::CBuilder, + utils::ToTargets, +}; +use plonky2::{ + hash::hash_types::HashOutTarget, + iop::target::{BoolTarget, Target}, +}; +use plonky2_ecdsa::gadgets::nonnative::CircuitBuilderNonNative; +use plonky2_ecgfp5::gadgets::curve::{CircuitBuilderEcGFp5, CurveTarget}; + +#[derive(Debug)] +pub(crate) struct RowDigestGadget<'a> { + // - `p :` - cells proof for the row associated to the current node (from `cells_build_set`) + cells_pi: cells_tree::PublicInputs<'a, Target>, + // - `is_individual : bool` - Flag specifying whether the secondary index cell should be accumulated in `individual` or `multiplier` digest + is_multiplier: BoolTarget, + // - `mpt_metadata : [4]F` - Hash of the metadata associated to the secondary index cell, as computed in MPT extraction circuits + mpt_metadata: &'a HashOutTarget, + // - `row_unique_data : Hash` : Row unique data employed to compute the row id for individual cells, the same one employed in MPT extraction circuits + row_unique_data: &'a HashOutTarget, + current_cell: &'a CellWire, +} + +impl<'a> RowDigestGadget<'a> { + pub(crate) fn new( + cells_pi: cells_tree::PublicInputs<'a, Target>, + is_multiplier: BoolTarget, + mpt_metadata: &'a HashOutTarget, + row_unique_data: &'a HashOutTarget, + current_cell: &'a CellWire, + ) -> Self { + Self { + cells_pi, + is_multiplier, + mpt_metadata, + row_unique_data, + current_cell, + } + } + + pub(crate) fn compute_row_digest( + &self, + b: &mut CBuilder, + ) -> (CurveTarget, CurveTarget, [Target; HASH_TO_INT_LEN]) { + let (individual_vd, multiplier_vd) = + self.current_cell.individual_multiplier_values_digests(b); + let (individual_md, multiplier_md) = + self.current_cell.individual_multiplier_metadata_digests(b); + + let individual_vd = b.add_curve_point(&[ + individual_vd, + self.cells_pi.individual_values_digest_target(), + ]); + let multiplier_vd = b.add_curve_point(&[ + multiplier_vd, + self.cells_pi.multiplier_values_digest_target(), + ]); + let individual_md = b.add_curve_point(&[ + individual_md, + self.cells_pi.individual_metadata_digest_target(), + ]); + let multiplier_md = b.add_curve_point(&[ + multiplier_md, + self.cells_pi.multiplier_metadata_digest_target(), + ]); + + // # compute row id for individual cells + // row_id_individual = H2Int(row_unique_data || individual_md) + let inputs = self + .row_unique_data + .to_targets() + .into_iter() + .chain(individual_md.to_targets()) + .collect(); + let hash = b.hash_n_to_hash_no_pad::(inputs); + let row_id_individual = hash_to_int_target(b, hash); + let row_id_individual = b.biguint_to_nonnative(&row_id_individual); + + // # multiply row id to individual value digest + // individual_vd = row_id_individual * individual_vd # scalar mul + let individual_vd = b.curve_scalar_mul(individual_vd, &row_id_individual); + + // # multiplier is always employed for set of scalar variables, and the + // # row_unique_data for such a set is always H(""), so we can hardocode it + // # in the circuit + // row_id_multiplier = H2Int(H("") || multiplier_md) + let empty_hash = b.constant_hash(*empty_poseidon_hash()); + let inputs = empty_hash + .to_targets() + .into_iter() + .chain(multiplier_md.to_targets()) + .collect(); + let hash = b.hash_n_to_hash_no_pad::(inputs); + let row_id_multiplier = hash_to_int_target(b, hash); + assert_eq!(row_id_multiplier.num_limbs(), HASH_TO_INT_LEN); + let row_id_multiplier = row_id_multiplier + .limbs + .into_iter() + .map(|u| u.0) + .collect_vec() + .try_into() + .unwrap(); + + (individual_vd, multiplier_vd, row_id_multiplier) + } +} + +/* +#[cfg(test)] +pub(crate) mod tests { + use super::{super::column_info::ColumnInfoTarget, *}; + use crate::{ + tests::TEST_MAX_FIELD_PER_EVM, + values_extraction::gadgets::column_info::{ + CircuitBuilderColumnInfo, WitnessWriteColumnInfo, + }, + }; + use mp2_common::{C, D}; + use mp2_test::circuit::{run_circuit, UserCircuit}; + use plonky2::iop::witness::{PartialWitness, WitnessWrite}; + use plonky2_ecgfp5::gadgets::curve::PartialWitnessCurve; + + #[derive(Clone, Debug)] + pub(crate) struct ColumnGadgetTarget { + value: [Target; MAPPING_LEAF_VALUE_LEN], + table_info: [ColumnInfoTarget; MAX_FIELD_PER_EVM], + is_extracted_columns: [BoolTarget; MAX_FIELD_PER_EVM], + } + + impl ColumnGadgetTarget { + fn column_gadget(&self) -> ColumnGadget { + ColumnGadget::new(&self.value, &self.table_info, &self.is_extracted_columns) + } + } + + pub(crate) trait CircuitBuilderColumnGadget { + /// Add a virtual column gadget target. + fn add_virtual_column_gadget_target( + &mut self, + ) -> ColumnGadgetTarget; + } + + impl CircuitBuilderColumnGadget for CBuilder { + fn add_virtual_column_gadget_target( + &mut self, + ) -> ColumnGadgetTarget { + let value = self.add_virtual_target_arr(); + let table_info = array::from_fn(|_| self.add_virtual_column_info()); + let is_extracted_columns = array::from_fn(|_| self.add_virtual_bool_target_safe()); + + ColumnGadgetTarget { + value, + table_info, + is_extracted_columns, + } + } + } + + pub(crate) trait WitnessWriteColumnGadget { + fn set_column_gadget_target( + &mut self, + target: &ColumnGadgetTarget, + value: &ColumnGadgetData, + ); + } + + impl> WitnessWriteColumnGadget for T { + fn set_column_gadget_target( + &mut self, + target: &ColumnGadgetTarget, + data: &ColumnGadgetData, + ) { + self.set_target_arr(&target.value, &data.value); + self.set_column_info_target_arr(&target.table_info, &data.table_info); + target + .is_extracted_columns + .iter() + .enumerate() + .for_each(|(i, t)| self.set_bool_target(*t, i < data.num_extracted_columns)); + } + } + + #[derive(Clone, Debug)] + struct TestColumnGadgetCircuit { + column_gadget_data: ColumnGadgetData, + expected_column_digest: Point, + } + + impl UserCircuit for TestColumnGadgetCircuit { + // Column gadget target + expected column digest + type Wires = (ColumnGadgetTarget, CurveTarget); + + fn build(b: &mut CBuilder) -> Self::Wires { + let column_gadget_target = b.add_virtual_column_gadget_target(); + let expected_column_digest = b.add_virtual_curve_target(); + + let column_digest = column_gadget_target.column_gadget().build(b); + b.connect_curve_points(column_digest, expected_column_digest); + + (column_gadget_target, expected_column_digest) + } + + fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { + pw.set_column_gadget_target(&wires.0, &self.column_gadget_data); + pw.set_curve_target(wires.1, self.expected_column_digest.to_weierstrass()); + } + } + + #[test] + fn test_values_extraction_column_gadget() { + let column_gadget_data = ColumnGadgetData::sample(); + let expected_column_digest = column_gadget_data.digest(); + + let test_circuit = TestColumnGadgetCircuit { + column_gadget_data, + expected_column_digest, + }; + + let _ = run_circuit::(test_circuit); + } +} +*/ diff --git a/verifiable-db/src/row_tree/leaf.rs b/verifiable-db/src/row_tree/leaf.rs index 4d6e0a4d9..e9c6a34f6 100644 --- a/verifiable-db/src/row_tree/leaf.rs +++ b/verifiable-db/src/row_tree/leaf.rs @@ -1,7 +1,11 @@ +use super::{ + public_inputs::PublicInputs, + row::{Row, RowWire}, +}; +use crate::cells_tree; use derive_more::{From, Into}; use mp2_common::{ default_config, - group_hashing::{cond_circuit_hashed_scalar_mul, CircuitBuilderGroupHashing}, poseidon::{empty_poseidon_hash, H}, proof::ProofWithVK, public_inputs::PublicInputCommon, @@ -9,10 +13,7 @@ use mp2_common::{ C, D, F, }; use plonky2::{ - iop::{ - target::{BoolTarget, Target}, - witness::PartialWitness, - }, + iop::{target::Target, witness::PartialWitness}, plonk::{circuit_builder::CircuitBuilder, proof::ProofWithPublicInputsTarget}, }; use recursion_framework::{ @@ -22,57 +23,50 @@ use recursion_framework::{ }, }; use serde::{Deserialize, Serialize}; - -use crate::cells_tree::{self, Cell, CellWire}; - -use super::public_inputs::PublicInputs; +use std::iter::once; // new type to implement the circuit logic on each differently // deref to access directly the same members - read only so it's ok #[derive(Clone, Debug, From, Into)] -pub struct LeafCircuit(Cell); +pub struct LeafCircuit(Row); #[derive(Clone, Serialize, Deserialize, From, Into)] -pub(crate) struct LeafWires(CellWire); +pub(crate) struct LeafWires(RowWire); impl LeafCircuit { pub(crate) fn build(b: &mut CircuitBuilder, cells_pis: &[Target]) -> LeafWires { let cells_pis = cells_tree::PublicInputs::from_slice(cells_pis); - // D(index_id||pack_u32(index_value) - let tuple = CellWire::new(b); - // set the right digest depending on the multiplier and accumulate the ones from the public - // inputs of the cell root proof - let split_digest = tuple.split_and_accumulate_digest(b, cells_pis.split_digest_target()); - // final_digest = HashToInt(D(mul_digest)) * D(ind_digest) - // NOTE This additional digest is necessary since the individual digest is supposed to be a - // full row, that is how it is extracted from MPT - let (final_digest, is_merge) = split_digest.cond_combine_to_row_digest(b); + let row = RowWire::new(b); + let id = row.identifier(); + let value = row.value().to_targets(); + let digest = row.digest(b, &cells_pis); // H(left_child_hash,right_child_hash,min,max,index_identifier,index_value,cells_tree_hash) // in our case, min == max == index_value // left_child_hash == right_child_hash == empty_hash since there is not children - let empty_hash = b.constant_hash(*empty_poseidon_hash()); + let empty_hash = b.constant_hash(*empty_poseidon_hash()).to_targets(); let inputs = empty_hash - .to_targets() - .iter() - .chain(empty_hash.to_targets().iter()) - .chain(tuple.value.to_targets().iter()) - .chain(tuple.value.to_targets().iter()) - .chain(tuple.to_targets().iter()) - .chain(cells_pis.node_hash().to_targets().iter()) - .cloned() + .clone() + .into_iter() + .chain(empty_hash) + .chain(value.clone()) + .chain(value.clone()) + .chain(once(id)) + .chain(cells_pis.node_hash_target()) .collect::>(); let row_hash = b.hash_n_to_hash_no_pad::(inputs); - let value_fields = tuple.value.to_targets(); PublicInputs::new( &row_hash.elements, - &final_digest.to_targets(), - &value_fields, - &value_fields, - &[is_merge.target], + &digest.individual_vd.to_targets(), + &digest.multiplier_vd.to_targets(), + &digest.row_id_multiplier.to_targets(), + &value, + &value, + &[digest.is_merge.target], ) .register(b); - LeafWires(tuple) + + LeafWires(row) } fn assign(&self, pw: &mut PartialWitness, wires: &LeafWires) { @@ -102,14 +96,14 @@ impl CircuitLogicWires for RecursiveLeafWires { type Inputs = RecursiveLeafInput; - const NUM_PUBLIC_INPUTS: usize = PublicInputs::::TOTAL_LEN; + const NUM_PUBLIC_INPUTS: usize = PublicInputs::::total_len(); fn circuit_logic( builder: &mut CircuitBuilder, _verified_proofs: [&ProofWithPublicInputsTarget; 0], builder_parameters: Self::CircuitBuilderParams, ) -> Self { - const CELLS_IO: usize = cells_tree::PublicInputs::::TOTAL_LEN; + const CELLS_IO: usize = cells_tree::PublicInputs::::total_len(); let verifier_gadget = RecursiveCircuitsVerifierGagdet::::new( default_config(), &builder_parameters, @@ -131,6 +125,7 @@ impl CircuitLogicWires for RecursiveLeafWires { } } +/* #[cfg(test)] mod test { @@ -243,3 +238,4 @@ mod test { test_row_tree_leaf_circuit(true, true); } } +*/ diff --git a/verifiable-db/src/row_tree/mod.rs b/verifiable-db/src/row_tree/mod.rs index c45a72292..82f0247e5 100644 --- a/verifiable-db/src/row_tree/mod.rs +++ b/verifiable-db/src/row_tree/mod.rs @@ -1,26 +1,9 @@ -use alloy::primitives::U256; -use derive_more::Constructor; -use mp2_common::{ - group_hashing::CircuitBuilderGroupHashing, - u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256}, - utils::{ToFields, ToTargets}, - D, F, -}; -use plonky2::{ - iop::{ - target::{BoolTarget, Target}, - witness::{PartialWitness, WitnessWrite}, - }, - plonk::circuit_builder::CircuitBuilder, -}; -use plonky2_ecgfp5::gadgets::curve::CurveTarget; -use serde::{Deserialize, Serialize}; - mod api; mod full_node; mod leaf; mod partial_node; mod public_inputs; +mod row; pub use api::{extract_hash_from_proof, CircuitInput, PublicParameters}; pub use public_inputs::PublicInputs; diff --git a/verifiable-db/src/row_tree/partial_node.rs b/verifiable-db/src/row_tree/partial_node.rs index 00af074c8..2b2d6bde2 100644 --- a/verifiable-db/src/row_tree/partial_node.rs +++ b/verifiable-db/src/row_tree/partial_node.rs @@ -1,8 +1,8 @@ -use plonky2::plonk::proof::ProofWithPublicInputsTarget; - +use super::row::{Row, RowWire}; +use crate::cells_tree; use mp2_common::{ default_config, - group_hashing::{cond_circuit_hashed_scalar_mul, CircuitBuilderGroupHashing}, + group_hashing::CircuitBuilderGroupHashing, hash::hash_maybe_first, poseidon::empty_poseidon_hash, proof::ProofWithVK, @@ -18,9 +18,9 @@ use plonky2::{ target::{BoolTarget, Target}, witness::{PartialWitness, WitnessWrite}, }, - plonk::circuit_builder::CircuitBuilder, + plonk::{circuit_builder::CircuitBuilder, proof::ProofWithPublicInputsTarget}, }; -use plonky2_ecgfp5::gadgets::curve::CircuitBuilderEcGFp5; +use plonky2_ecdsa::gadgets::biguint::CircuitBuilderBiguint; use recursion_framework::{ circuit_builder::CircuitLogicWires, framework::{ @@ -28,28 +28,27 @@ use recursion_framework::{ }, }; use serde::{Deserialize, Serialize}; - -use crate::cells_tree::{self, Cell, CellWire}; +use std::iter::once; use super::public_inputs::PublicInputs; #[derive(Clone, Debug)] pub struct PartialNodeCircuit { - pub(crate) tuple: Cell, + pub(crate) row: Row, pub(crate) is_child_at_left: bool, } #[derive(Clone, Debug, Serialize, Deserialize)] struct PartialNodeWires { - tuple: CellWire, + row: RowWire, #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] is_child_at_left: BoolTarget, } impl PartialNodeCircuit { - pub(crate) fn new(tuple: Cell, is_child_at_left: bool) -> Self { + pub(crate) fn new(row: Row, is_child_at_left: bool) -> Self { Self { - tuple, + row, is_child_at_left, } } @@ -58,22 +57,35 @@ impl PartialNodeCircuit { child_pi: &[Target], cells_pi: &[Target], ) -> PartialNodeWires { + let child_pi = PublicInputs::from_slice(child_pi); let cells_pi = cells_tree::PublicInputs::from_slice(cells_pi); - let tuple = CellWire::new(b); + let row = RowWire::new(b); + let id = row.identifier(); + let value = row.value(); + let digest = row.digest(b, &cells_pi); + + // Check multiplier_vd and row_id_multiplier are the same as child proof + // assert multiplier_vd == child_proof.multiplier_vd + b.connect_curve_points(digest.multiplier_vd, child_pi.multiplier_digest_target()); + //assert row_id_multiplier == child_proof.row_id_multiplier + b.connect_biguint( + &digest.row_id_multiplier, + &child_pi.row_id_multiplier_target(), + ); + // bool target range checked in poseidon gate let is_child_at_left = b.add_virtual_bool_target_unsafe(); - let child_pi = PublicInputs::from_slice(child_pi); // max_left = left ? child_proof.max : index_value // min_right = left ? index_value : child_proof.min - let max_left = b.select_u256(is_child_at_left, &child_pi.max_value(), &tuple.value); - let min_right = b.select_u256(is_child_at_left, &tuple.value, &child_pi.min_value()); + let max_left = b.select_u256(is_child_at_left, &child_pi.max_value_target(), value); + let min_right = b.select_u256(is_child_at_left, value, &child_pi.min_value_target()); let bst_enforced = b.is_less_or_equal_than_u256(&max_left, &min_right); let _true = b._true(); b.connect(bst_enforced.target, _true.target); // node_min = left ? child_proof.min : index_value // node_max = left ? index_value : child_proof.max - let node_min = b.select_u256(is_child_at_left, &child_pi.min_value(), &tuple.value); - let node_max = b.select_u256(is_child_at_left, &tuple.value, &child_pi.max_value()); + let node_min = b.select_u256(is_child_at_left, &child_pi.min_value_target(), value); + let node_max = b.select_u256(is_child_at_left, value, &child_pi.max_value_target()); let empty_hash = b.constant_hash(*empty_poseidon_hash()); // left_hash = left ? child_proof.H : H("") @@ -85,8 +97,8 @@ impl PartialNodeCircuit { .to_targets() .iter() .chain(node_max.to_targets().iter()) - .chain(tuple.to_targets().iter()) - .chain(cells_pi.node_hash().to_targets().iter()) + .chain(once(&id)) + .chain(cells_pi.node_hash_target().iter()) .cloned() .collect::>(); // if child at left, then hash should be child_proof.H || H("") || rest @@ -95,34 +107,31 @@ impl PartialNodeCircuit { b, is_child_at_left, empty_hash.elements, - child_pi.root_hash().elements, + child_pi.root_hash_target(), &rest, ); - // final_digest = HashToInt(mul_digest) * D(ind_digest) - let split_digest = tuple.split_and_accumulate_digest(b, cells_pi.split_digest_target()); - let (row_digest, is_merge) = split_digest.cond_combine_to_row_digest(b); - - // and add the digest of the row other rows - let final_digest = b.curve_add(child_pi.rows_digest(), row_digest); // assert is_merge is the same between this row and `child_pi` - b.connect(is_merge.target, child_pi.is_merge_case().target); + b.connect(digest.is_merge.target, child_pi.merge_flag_target().target); + PublicInputs::new( &node_hash, - &final_digest.to_targets(), + &digest.individual_vd.to_targets(), &node_min.to_targets(), &node_max.to_targets(), - &[is_merge.target], + &[digest.is_merge.target], + &digest.multiplier_vd.to_targets(), + &digest.row_id_multiplier.to_targets(), ) .register(b); PartialNodeWires { - tuple, + row, is_child_at_left, } } fn assign(&self, pw: &mut PartialWitness, wires: &PartialNodeWires) { - self.tuple.assign_wires(pw, &wires.tuple); + self.row.assign_wires(pw, &wires.row); pw.set_bool_target(wires.is_child_at_left, self.is_child_at_left); } } @@ -145,14 +154,14 @@ impl CircuitLogicWires for RecursivePartialWires { type Inputs = RecursivePartialInput; - const NUM_PUBLIC_INPUTS: usize = PublicInputs::::TOTAL_LEN; + const NUM_PUBLIC_INPUTS: usize = PublicInputs::::total_len(); fn circuit_logic( builder: &mut CircuitBuilder, verified_proofs: [&ProofWithPublicInputsTarget; NUM_CHILDREN], builder_parameters: Self::CircuitBuilderParams, ) -> Self { - const CELLS_IO: usize = cells_tree::PublicInputs::::TOTAL_LEN; + const CELLS_IO: usize = cells_tree::PublicInputs::::total_len(); let verifier_gadget = RecursiveCircuitsVerifierGagdet::::new( default_config(), &builder_parameters, @@ -175,6 +184,7 @@ impl CircuitLogicWires for RecursivePartialWires { } } +/* #[cfg(test)] pub mod test { use mp2_common::{ @@ -322,8 +332,8 @@ pub mod test { // node_min = left ? child_proof.min : index_value // node_max = left ? index_value : child_proof.max let (node_min, node_max) = match child_at_left { - true => (pi.min_value_u256(), tuple.value), - false => (tuple.value, pi.max_value_u256()), + true => (pi.min_value(), tuple.value), + false => (tuple.value, pi.max_value()), }; // Poseidon(p1.H || p2.H || node_min || node_max || index_id || index_value ||p.H)) as H let child_hash = PublicInputs::from_slice(&child_pi).root_hash_hashout(); @@ -352,3 +362,4 @@ pub mod test { assert_eq!(split_digest.is_merge_case(), pi.is_merge_flag()); } } +*/ diff --git a/verifiable-db/src/row_tree/public_inputs.rs b/verifiable-db/src/row_tree/public_inputs.rs index a775af1af..06eb726d1 100644 --- a/verifiable-db/src/row_tree/public_inputs.rs +++ b/verifiable-db/src/row_tree/public_inputs.rs @@ -1,183 +1,299 @@ //! Public inputs for rows trees creation circuits -//! + use alloy::primitives::U256; +use itertools::Itertools; use mp2_common::{ + poseidon::HASH_TO_INT_LEN, public_inputs::{PublicInputCommon, PublicInputRange}, - types::CURVE_TARGET_LEN, + types::{CBuilder, CURVE_TARGET_LEN}, u256::{self, UInt256Target}, utils::{FromFields, FromTargets, TryIntoBool}, - D, F, + F, }; +use num::BigUint; use plonky2::{ - hash::hash_types::{HashOut, HashOutTarget, NUM_HASH_OUT_ELTS}, + field::types::PrimeField64, + hash::hash_types::{HashOut, NUM_HASH_OUT_ELTS}, iop::target::{BoolTarget, Target}, - plonk::circuit_builder::CircuitBuilder, }; +use plonky2_crypto::u32::arithmetic_u32::U32Target; +use plonky2_ecdsa::gadgets::biguint::BigUintTarget; use plonky2_ecgfp5::{curve::curve::WeierstrassPoint, gadgets::curve::CurveTarget}; -use std::array::from_fn as create_array; - -// Contract extraction public Inputs: -// - `H : [4]F` : Poseidon hash of the leaf -// - `DR : Digest[F]` : accumulated digest of all the rows up to this node -// - `min : Uint256` : min value of the secondary index stored up to this node -// - `max : Uint256` : max value of the secondary index stored up to this node -// - `merge : bool` : Flag specifying whether we are building rows for a merge table or not -const H_RANGE: PublicInputRange = 0..NUM_HASH_OUT_ELTS; -const DR_RANGE: PublicInputRange = H_RANGE.end..H_RANGE.end + CURVE_TARGET_LEN; -const MIN_RANGE: PublicInputRange = DR_RANGE.end..DR_RANGE.end + u256::NUM_LIMBS; -const MAX_RANGE: PublicInputRange = MIN_RANGE.end..MIN_RANGE.end + u256::NUM_LIMBS; -const MERGE_RANGE: PublicInputRange = MAX_RANGE.end..MAX_RANGE.end + 1; - -/// Public inputs for contract extraction +use std::iter::once; + +pub enum RowsTreePublicInputs { + // `H : F[4]` - Poseidon hash of the leaf + RootHash, + // `individual_digest : Digest` - Cumulative digest of the values of the cells which are accumulated in individual digest + IndividualDigest, + // `multiplier_digest : Digest` - Cumulative digest of the values of the cells which are accumulated in multiplier digest + MultiplierDigest, + // `row_id_multiplier : F[4]` - `H2Int(H("") || multiplier_md)`, where `multiplier_md` is the metadata digest of cells accumulated in `multiplier_digest` + RowIdMultiplier, + // `min : Uint256` - Minimum alue of the secondary index stored up to this node + MinValue, + // `max : Uint256` - Maximum value of the secondary index stored up to this node + MaxValue, + // `merge : bool` - Flag specifying whether we are building rows for a merge table or not + MergeFlag, +} + +/// Public inputs for Rows Tree Construction #[derive(Clone, Debug)] pub struct PublicInputs<'a, T> { pub(crate) h: &'a [T], - pub(crate) dr: &'a [T], + pub(crate) individual_digest: &'a [T], + pub(crate) multiplier_digest: &'a [T], + pub(crate) row_id_multiplier: &'a [T], pub(crate) min: &'a [T], pub(crate) max: &'a [T], - pub(crate) merge: &'a [T], + pub(crate) merge: &'a T, } -impl<'a> PublicInputCommon for PublicInputs<'a, Target> { - const RANGES: &'static [PublicInputRange] = - &[H_RANGE, DR_RANGE, MIN_RANGE, MAX_RANGE, MERGE_RANGE]; +const NUM_PUBLIC_INPUTS: usize = RowsTreePublicInputs::MergeFlag as usize + 1; - fn register_args(&self, cb: &mut CircuitBuilder) { - cb.register_public_inputs(self.h); - cb.register_public_inputs(self.dr); - cb.register_public_inputs(self.min); - cb.register_public_inputs(self.max); - cb.register_public_input(self.merge[0]); - } -} +impl<'a, T: Clone> PublicInputs<'a, T> { + const PI_RANGES: [PublicInputRange; NUM_PUBLIC_INPUTS] = [ + Self::to_range(RowsTreePublicInputs::RootHash), + Self::to_range(RowsTreePublicInputs::IndividualDigest), + Self::to_range(RowsTreePublicInputs::MultiplierDigest), + Self::to_range(RowsTreePublicInputs::RowIdMultiplier), + Self::to_range(RowsTreePublicInputs::MinValue), + Self::to_range(RowsTreePublicInputs::MaxValue), + Self::to_range(RowsTreePublicInputs::MergeFlag), + ]; -// mostly used for testing -impl<'a> PublicInputs<'a, F> { - /// Get the metadata point. - pub fn rows_digest_field(&self) -> WeierstrassPoint { - WeierstrassPoint::from_fields(self.dr) - } - /// minimum index value - pub fn min_value_u256(&self) -> U256 { - U256::from_fields(self.min) - } - /// maximum index value - pub fn max_value_u256(&self) -> U256 { - U256::from_fields(self.max) - } - /// hash of the subtree at this node - pub fn root_hash_hashout(&self) -> HashOut { - HashOut { - elements: create_array(|i| self.h[i]), + const SIZES: [usize; NUM_PUBLIC_INPUTS] = [ + // Poseidon hash of the leaf + NUM_HASH_OUT_ELTS, + // Cumulative digest of the values of the cells which are accumulated in individual digest + CURVE_TARGET_LEN, + // Cumulative digest of the values of the cells which are accumulated in multiplier digest + CURVE_TARGET_LEN, + // `H2Int(H("") || multiplier_md)`, where `multiplier_md` is the metadata digest of cells accumulated in `multiplier_digest` + HASH_TO_INT_LEN, + // Minimum alue of the secondary index stored up to this node + u256::NUM_LIMBS, + // Maximum value of the secondary index stored up to this node + u256::NUM_LIMBS, + // Flag specifying whether we are building rows for a merge table or not + 1, + ]; + + pub(crate) const fn to_range(pi: RowsTreePublicInputs) -> PublicInputRange { + let mut i = 0; + let mut offset = 0; + let pi_pos = pi as usize; + while i < pi_pos { + offset += Self::SIZES[i]; + i += 1; } + offset..offset + Self::SIZES[pi_pos] } - pub fn is_merge_flag(&self) -> bool { - self.merge[0].try_into_bool().unwrap() + pub(crate) const fn total_len() -> usize { + Self::to_range(RowsTreePublicInputs::RowIdMultiplier).end } -} -impl<'a> PublicInputs<'a, Target> { - /// Get the hash corresponding to the root of the subtree of this node - pub fn root_hash(&self) -> HashOutTarget { - HashOutTarget::from_targets(self.h) + pub(crate) fn to_root_hash_raw(&self) -> &[T] { + self.h } - pub fn rows_digest(&self) -> CurveTarget { - let dv = self.dr; - CurveTarget::from_targets(dv) + pub(crate) fn to_individual_digest_raw(&self) -> &[T] { + self.individual_digest } - pub fn min_value(&self) -> UInt256Target { - UInt256Target::from_targets(self.min) + pub(crate) fn to_multiplier_digest_raw(&self) -> &[T] { + self.multiplier_digest } - pub fn max_value(&self) -> UInt256Target { - UInt256Target::from_targets(self.max) + + pub(crate) fn to_row_id_multiplier_raw(&self) -> &[T] { + self.row_id_multiplier } - pub fn is_merge_case(&self) -> BoolTarget { - BoolTarget::new_unsafe(self.merge[0]) + pub(crate) fn to_min_value_raw(&self) -> &[T] { + self.min } -} -pub const TOTAL_LEN: usize = PublicInputs::::TOTAL_LEN; + pub(crate) fn to_max_value_raw(&self) -> &[T] { + self.max + } -impl<'a, T: Copy> PublicInputs<'a, T> { - /// Total length of the public inputs - pub(crate) const TOTAL_LEN: usize = MERGE_RANGE.end; + pub(crate) fn to_merge_flag_raw(&self) -> &T { + self.merge + } - /// Create from a slice. - pub fn from_slice(pi: &'a [T]) -> Self { - assert!(pi.len() >= Self::TOTAL_LEN); + pub fn from_slice(input: &'a [T]) -> Self { + assert!( + input.len() >= Self::total_len(), + "Input slice too short to build rows tree public inputs, must be at least {} elements", + Self::total_len(), + ); Self { - h: &pi[H_RANGE], - dr: &pi[DR_RANGE], - min: &pi[MIN_RANGE], - max: &pi[MAX_RANGE], - merge: &pi[MERGE_RANGE], + h: &input[Self::PI_RANGES[0].clone()], + individual_digest: &input[Self::PI_RANGES[1].clone()], + multiplier_digest: &input[Self::PI_RANGES[2].clone()], + row_id_multiplier: &input[Self::PI_RANGES[3].clone()], + min: &input[Self::PI_RANGES[4].clone()], + max: &input[Self::PI_RANGES[5].clone()], + merge: &input[Self::PI_RANGES[6].clone()][0], } } - /// Create a new public inputs. - pub fn new(h: &'a [T], dr: &'a [T], min: &'a [T], max: &'a [T], merge: &'a [T]) -> Self { - assert_eq!(h.len(), NUM_HASH_OUT_ELTS); - assert_eq!(dr.len(), CURVE_TARGET_LEN); - assert_eq!(min.len(), u256::NUM_LIMBS); - assert_eq!(max.len(), u256::NUM_LIMBS); - assert_eq!(merge.len(), 1); + pub fn new( + h: &'a [T], + individual_digest: &'a [T], + multiplier_digest: &'a [T], + row_id_multiplier: &'a [T], + min: &'a [T], + max: &'a [T], + merge: &'a [T], + ) -> Self { Self { h, - dr, + individual_digest, + multiplier_digest, + row_id_multiplier, min, max, - merge, + merge: &merge[0], } } - /// Combine to a vector. pub fn to_vec(&self) -> Vec { self.h .iter() - .chain(self.dr) + .chain(self.individual_digest) + .chain(self.multiplier_digest) + .chain(self.row_id_multiplier) .chain(self.min) .chain(self.max) - .chain(self.merge) + .chain(once(self.merge)) .cloned() .collect() } } +impl<'a> PublicInputCommon for PublicInputs<'a, Target> { + const RANGES: &'static [PublicInputRange] = &Self::PI_RANGES; + + fn register_args(&self, cb: &mut CBuilder) { + cb.register_public_inputs(self.h); + cb.register_public_inputs(self.individual_digest); + cb.register_public_inputs(self.multiplier_digest); + cb.register_public_inputs(self.row_id_multiplier); + cb.register_public_inputs(self.min); + cb.register_public_inputs(self.max); + cb.register_public_input(*self.merge); + } +} + +impl<'a> PublicInputs<'a, Target> { + pub fn root_hash_target(&self) -> [Target; NUM_HASH_OUT_ELTS] { + self.to_root_hash_raw().try_into().unwrap() + } + + pub fn individual_digest_target(&self) -> CurveTarget { + CurveTarget::from_targets(self.individual_digest) + } + + pub fn multiplier_digest_target(&self) -> CurveTarget { + CurveTarget::from_targets(self.multiplier_digest) + } + + pub fn row_id_multiplier_target(&self) -> BigUintTarget { + let limbs = self + .row_id_multiplier + .iter() + .cloned() + .map(U32Target) + .collect(); + + BigUintTarget { limbs } + } + + pub fn min_value_target(&self) -> UInt256Target { + UInt256Target::from_targets(self.min) + } + + pub fn max_value_target(&self) -> UInt256Target { + UInt256Target::from_targets(self.max) + } + + pub fn merge_flag_target(&self) -> BoolTarget { + BoolTarget::new_unsafe(*self.merge) + } +} + +impl<'a> PublicInputs<'a, F> { + pub fn root_hash(&self) -> HashOut { + HashOut::from_partial(self.h) + } + + pub fn individual_digest_point(&self) -> WeierstrassPoint { + WeierstrassPoint::from_fields(self.individual_digest) + } + + pub fn multiplier_digest_point(&self) -> WeierstrassPoint { + WeierstrassPoint::from_fields(self.multiplier_digest) + } + + pub fn row_id_multiplier(&self) -> BigUint { + let limbs = self + .row_id_multiplier + .iter() + .map(|f| u32::try_from(f.to_canonical_u64()).unwrap()) + .collect_vec(); + + BigUint::from_slice(&limbs) + } + + pub fn min_value(&self) -> U256 { + U256::from_fields(self.min) + } + + pub fn max_value(&self) -> U256 { + U256::from_fields(self.max) + } + + pub fn merge_flag(&self) -> bool { + self.merge.try_into_bool().unwrap() + } +} + #[cfg(test)] mod tests { use super::*; - use alloy::primitives::U256; - use mp2_common::{public_inputs::PublicInputCommon, utils::ToFields, C, D, F}; - use mp2_test::circuit::{run_circuit, UserCircuit}; + use mp2_common::{utils::ToFields, C, D, F}; + use mp2_test::{ + circuit::{run_circuit, UserCircuit}, + utils::random_vector, + }; use plonky2::{ field::types::{Field, Sample}, iop::{ target::Target, witness::{PartialWitness, WitnessWrite}, }, - plonk::config::GenericHashOut, }; use plonky2_ecgfp5::curve::curve::Point; use rand::{thread_rng, Rng}; + use std::{array, slice}; #[derive(Clone, Debug)] - struct TestPICircuit<'a> { + struct TestPublicInputs<'a> { exp_pi: &'a [F], } - impl<'a> UserCircuit for TestPICircuit<'a> { + impl<'a> UserCircuit for TestPublicInputs<'a> { type Wires = Vec; - fn build(b: &mut CircuitBuilder) -> Self::Wires { - let pi = b.add_virtual_targets(PublicInputs::::TOTAL_LEN); - let pi = PublicInputs::from_slice(&pi); - pi.register(b); - pi.to_vec() + fn build(b: &mut CBuilder) -> Self::Wires { + let exp_pi = b.add_virtual_targets(PublicInputs::::total_len()); + PublicInputs::from_slice(&exp_pi).register(b); + + exp_pi } fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { @@ -187,21 +303,59 @@ mod tests { #[test] fn test_rows_tree_public_inputs() { - let mut rng = thread_rng(); + let rng = &mut thread_rng(); // Prepare the public inputs. - let h = HashOut::rand().to_vec(); - let dr = Point::sample(&mut rng); - let drw = dr.to_weierstrass().to_fields(); - let min = U256::from_limbs(rng.gen::<[u64; 4]>()).to_fields(); - let max = U256::from_limbs(rng.gen::<[u64; 4]>()).to_fields(); - let merge = [F::from_canonical_usize(rng.gen_bool(0.5) as usize)]; - let exp_pi = PublicInputs::new(&h, &drw, &min, &max, &merge); + let h = random_vector::(NUM_HASH_OUT_ELTS).to_fields(); + let [individual_digest, multiplier_digest] = + array::from_fn(|_| Point::sample(rng).to_weierstrass().to_fields()); + let row_id_multiplier = rng.gen::<[u32; 4]>().map(F::from_canonical_u32); + let [min, max] = array::from_fn(|_| U256::from_limbs(rng.gen()).to_fields()); + let merge = [F::from_bool(rng.gen_bool(0.5))]; + let exp_pi = PublicInputs::new( + &h, + &individual_digest, + &multiplier_digest, + &row_id_multiplier, + &min, + &max, + &merge, + ); let exp_pi = &exp_pi.to_vec(); - assert_eq!(exp_pi.len(), PublicInputs::::TOTAL_LEN); - let test_circuit = TestPICircuit { exp_pi }; - let proof = run_circuit::(test_circuit); + let test_circuit = TestPublicInputs { exp_pi }; + let proof = run_circuit::(test_circuit); assert_eq!(&proof.public_inputs, exp_pi); + + // Check if the public inputs are constructed correctly. + let pi = PublicInputs::from_slice(&proof.public_inputs); + assert_eq!( + &exp_pi[PublicInputs::::to_range(RowsTreePublicInputs::RootHash)], + pi.to_root_hash_raw(), + ); + assert_eq!( + &exp_pi[PublicInputs::::to_range(RowsTreePublicInputs::IndividualDigest)], + pi.to_individual_digest_raw(), + ); + assert_eq!( + &exp_pi[PublicInputs::::to_range(RowsTreePublicInputs::MultiplierDigest)], + pi.to_multiplier_digest_raw(), + ); + assert_eq!( + &exp_pi[PublicInputs::::to_range(RowsTreePublicInputs::RowIdMultiplier)], + pi.to_row_id_multiplier_raw(), + ); + assert_eq!( + &exp_pi[PublicInputs::::to_range(RowsTreePublicInputs::MinValue)], + pi.to_min_value_raw(), + ); + assert_eq!( + &exp_pi[PublicInputs::::to_range(RowsTreePublicInputs::MaxValue)], + pi.to_max_value_raw(), + ); + assert_eq!( + &exp_pi[PublicInputs::::to_range(RowsTreePublicInputs::MergeFlag)], + slice::from_ref(pi.to_merge_flag_raw()), + ); } } diff --git a/verifiable-db/src/row_tree/row.rs b/verifiable-db/src/row_tree/row.rs new file mode 100644 index 000000000..df882e5a6 --- /dev/null +++ b/verifiable-db/src/row_tree/row.rs @@ -0,0 +1,122 @@ +//! Row information for the rows tree + +use crate::cells_tree::{Cell, CellWire, PublicInputs}; +use derive_more::Constructor; +use mp2_common::{ + poseidon::{empty_poseidon_hash, hash_to_int_target, H, HASH_TO_INT_LEN}, + serialization::{deserialize, serialize}, + types::CBuilder, + u256::UInt256Target, + utils::ToTargets, + F, +}; +use plonky2::{ + hash::hash_types::{HashOut, HashOutTarget}, + iop::{ + target::{BoolTarget, Target}, + witness::{PartialWitness, WitnessWrite}, + }, +}; +use plonky2_ecdsa::gadgets::{biguint::BigUintTarget, nonnative::CircuitBuilderNonNative}; +use plonky2_ecgfp5::gadgets::curve::{CircuitBuilderEcGFp5, CurveTarget}; +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, Serialize, Deserialize, Constructor)] +pub(crate) struct Row { + pub(crate) cell: Cell, + pub(crate) row_unique_data: HashOut, +} + +impl Row { + pub(crate) fn assign_wires(&self, pw: &mut PartialWitness, wires: &RowWire) { + self.cell.assign_wires(pw, &wires.cell); + pw.set_hash_target(wires.row_unique_data, self.row_unique_data); + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub(crate) struct RowWire { + pub(crate) cell: CellWire, + #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] + pub(crate) row_unique_data: HashOutTarget, +} + +/// Row digest result +#[derive(Clone, Debug)] +pub(crate) struct RowDigest { + pub(crate) is_merge: BoolTarget, + pub(crate) row_id_multiplier: BigUintTarget, + pub(crate) individual_vd: CurveTarget, + pub(crate) multiplier_vd: CurveTarget, +} + +impl RowWire { + pub(crate) fn new(b: &mut CBuilder) -> Self { + Self { + cell: CellWire::new(b), + row_unique_data: b.add_virtual_hash(), + } + } + + pub(crate) fn identifier(&self) -> Target { + self.cell.identifier + } + + pub(crate) fn value(&self) -> &UInt256Target { + &self.cell.value + } + + pub(crate) fn digest(&self, b: &mut CBuilder, cells_pi: &PublicInputs) -> RowDigest { + let metadata_digests = self.cell.split_metadata_digest(b); + let values_digests = self.cell.split_values_digest(b); + + let metadata_digests = + metadata_digests.accumulate(b, &cells_pi.split_metadata_digest_target()); + let values_digests = values_digests.accumulate(b, &cells_pi.split_values_digest_target()); + + // Compute row ID for individual cells: + // row_id_individual = H2Int(row_unique_data || individual_md) + let inputs = self + .row_unique_data + .to_targets() + .into_iter() + .chain(metadata_digests.individual.to_targets()) + .collect(); + let hash = b.hash_n_to_hash_no_pad::(inputs); + let row_id_individual = hash_to_int_target(b, hash); + let row_id_individual = b.biguint_to_nonnative(&row_id_individual); + + // Multiply row ID to individual value digest: + // individual_vd = row_id_individual * individual_vd + let individual_vd = b.curve_scalar_mul(values_digests.individual, &row_id_individual); + + // Multiplier is always employed for set of scalar variables, and `row_unique_data` + // for such a set is always `H("")``, so we can hardocode it in the circuit: + // row_id_multiplier = H2Int(H("") || multiplier_md) + let empty_hash = b.constant_hash(*empty_poseidon_hash()); + let inputs = empty_hash + .to_targets() + .into_iter() + .chain(metadata_digests.multiplier.to_targets()) + .collect(); + let hash = b.hash_n_to_hash_no_pad::(inputs); + let row_id_multiplier = hash_to_int_target(b, hash); + assert_eq!(row_id_multiplier.num_limbs(), HASH_TO_INT_LEN); + + let is_merge = values_digests.is_merge_case(b); + let multiplier_vd = values_digests.multiplier; + + RowDigest { + is_merge, + row_id_multiplier, + individual_vd, + multiplier_vd, + } + } +} + +/* +#[cfg(test)] +mod test { +} +*/