From 0836ee369f9a2e7c97761191758ec376be9f875e Mon Sep 17 00:00:00 2001 From: Steven Gu Date: Wed, 30 Oct 2024 23:08:36 +0800 Subject: [PATCH 1/5] Replace MPT metadata with the counter. --- mp2-v1/src/values_extraction/api.rs | 7 +- .../gadgets/metadata_gadget.rs | 42 ++++-- mp2-v1/src/values_extraction/leaf_mapping.rs | 13 +- .../leaf_mapping_of_mappings.rs | 12 +- mp2-v1/src/values_extraction/leaf_single.rs | 10 +- mp2-v1/src/values_extraction/mod.rs | 20 +-- mp2-v1/tests/common/celltree.rs | 6 - mp2-v1/tests/common/rowtree.rs | 6 - verifiable-db/src/block_tree/mod.rs | 43 ++++-- verifiable-db/src/cells_tree/api.rs | 116 ++++++---------- verifiable-db/src/cells_tree/empty_node.rs | 20 ++- verifiable-db/src/cells_tree/full_node.rs | 41 +++--- verifiable-db/src/cells_tree/leaf.rs | 25 ++-- verifiable-db/src/cells_tree/mod.rs | 90 +++--------- verifiable-db/src/cells_tree/partial_node.rs | 39 +++--- verifiable-db/src/cells_tree/public_inputs.rs | 129 +++++++----------- verifiable-db/src/row_tree/api.rs | 76 ++--------- verifiable-db/src/row_tree/full_node.rs | 34 ++--- verifiable-db/src/row_tree/leaf.rs | 4 +- verifiable-db/src/row_tree/partial_node.rs | 22 ++- verifiable-db/src/row_tree/public_inputs.rs | 103 ++++++-------- verifiable-db/src/row_tree/row.rs | 109 +++++++-------- 22 files changed, 398 insertions(+), 569 deletions(-) diff --git a/mp2-v1/src/values_extraction/api.rs b/mp2-v1/src/values_extraction/api.rs index 779cd4c48..dd1d9794f 100644 --- a/mp2-v1/src/values_extraction/api.rs +++ b/mp2-v1/src/values_extraction/api.rs @@ -3,7 +3,7 @@ use super::{ branch::{BranchCircuit, BranchWires}, extension::{ExtensionNodeCircuit, ExtensionNodeWires}, - gadgets::{column_info::ColumnInfo, metadata_gadget::MetadataGadget}, + gadgets::metadata_gadget::MetadataGadget, leaf_mapping::{LeafMappingCircuit, LeafMappingWires}, leaf_mapping_of_mappings::{LeafMappingOfMappingsCircuit, LeafMappingOfMappingsWires}, leaf_single::{LeafSingleCircuit, LeafSingleWires}, @@ -887,7 +887,6 @@ mod tests { >(table_info.clone()); let values_digest = compute_leaf_single_values_digest::( - &metadata_digest, table_info, &extracted_column_identifiers, value, @@ -908,7 +907,6 @@ mod tests { ); let values_digest = compute_leaf_mapping_values_digest::( - &metadata_digest, table_info, &extracted_column_identifiers, value, @@ -936,7 +934,6 @@ mod tests { >(table_info.clone()); let values_digest = compute_leaf_single_values_digest::( - &metadata_digest, table_info, &extracted_column_identifiers, value, @@ -957,7 +954,6 @@ mod tests { ); let values_digest = compute_leaf_mapping_values_digest::( - &metadata_digest, table_info, &extracted_column_identifiers, value, @@ -993,7 +989,6 @@ mod tests { let values_digest = compute_leaf_mapping_of_mappings_values_digest::< TEST_MAX_FIELD_PER_EVM, >( - &metadata_digest, table_info, &extracted_column_identifiers, value, diff --git a/mp2-v1/src/values_extraction/gadgets/metadata_gadget.rs b/mp2-v1/src/values_extraction/gadgets/metadata_gadget.rs index 0ceacdb81..f031435ac 100644 --- a/mp2-v1/src/values_extraction/gadgets/metadata_gadget.rs +++ b/mp2-v1/src/values_extraction/gadgets/metadata_gadget.rs @@ -210,11 +210,14 @@ pub(crate) struct MetadataTarget MetadataTarget { - /// Compute the metadata digest. - pub(crate) fn digest(&self, b: &mut CBuilder, slot: Target) -> CurveTarget { + /// Compute the metadata digest and number of actual columns. + pub(crate) fn digest_info(&self, b: &mut CBuilder, slot: Target) -> (CurveTarget, Target) { + let zero = b.zero(); + let mut partial = b.curve_zero(); let mut non_extracted_column_found = b._false(); - let mut num_extracted_columns = b.zero(); + let mut num_extracted_columns = zero; + let mut num_actual_columns = zero; for i in 0..MAX_COLUMNS { let info = &self.table_info[i]; @@ -224,11 +227,12 @@ impl // If the current column has to be extracted, we check that: // - The EVM word associated to this column is the same as the EVM word we are extracting data from. // - The slot associated to this column is the same as the slot we are extracting data from. + // - Ensure that we extract only from non-dummy columns. // if is_extracted: - // evm_word == info.evm_word && slot == info.slot + // evm_word == info.evm_word && slot == info.slot && is_actual let is_evm_word_eq = b.is_equal(self.evm_word, info.evm_word); let is_slot_eq = b.is_equal(slot, info.slot); - let acc = [is_extracted, is_evm_word_eq, is_slot_eq] + let acc = [is_extracted, is_actual, is_evm_word_eq, is_slot_eq] .into_iter() .reduce(|acc, flag| b.and(acc, flag)) .unwrap(); @@ -265,6 +269,7 @@ impl non_extracted_column_found = BoolTarget::new_unsafe(acc); // num_extracted_columns += is_extracted num_extracted_columns = b.add(num_extracted_columns, is_extracted.target); + num_actual_columns = b.add(num_actual_columns, is_actual.target); // Compute the partial digest of all columns. // mpt_metadata = H(info.slot || info.evm_word || info.byte_offset || info.bit_offset || info.length) @@ -295,7 +300,7 @@ impl less_than_or_equal_to_unsafe(b, num_extracted_columns, max_field_per_evm, 8); b.assert_one(num_extracted_lt_or_eq_max.target); - partial + (partial, num_actual_columns) } } @@ -311,32 +316,45 @@ pub(crate) mod tests { struct TestMedataCircuit { metadata_gadget: MetadataGadget, slot: u8, + expected_num_actual_columns: usize, expected_metadata_digest: Point, } impl UserCircuit for TestMedataCircuit { - // Metadata target + slot + expected metadata digest + // Metadata target + slot + expected number of actual columns + expected metadata digest type Wires = ( MetadataTarget, Target, + Target, CurveTarget, ); fn build(b: &mut CBuilder) -> Self::Wires { let metadata_target = MetadataGadget::build(b); let slot = b.add_virtual_target(); + let expected_num_actual_columns = b.add_virtual_target(); let expected_metadata_digest = b.add_virtual_curve_target(); - let metadata_digest = metadata_target.digest(b, slot); + let (metadata_digest, num_actual_columns) = metadata_target.digest_info(b, slot); b.connect_curve_points(metadata_digest, expected_metadata_digest); - - (metadata_target, slot, expected_metadata_digest) + b.connect(num_actual_columns, expected_num_actual_columns); + + ( + metadata_target, + slot, + expected_num_actual_columns, + expected_metadata_digest, + ) } fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { self.metadata_gadget.assign(pw, &wires.0); pw.set_target(wires.1, F::from_canonical_u8(self.slot)); - pw.set_curve_target(wires.2, self.expected_metadata_digest.to_weierstrass()); + pw.set_target( + wires.2, + F::from_canonical_usize(self.expected_num_actual_columns), + ); + pw.set_curve_target(wires.3, self.expected_metadata_digest.to_weierstrass()); } } @@ -348,11 +366,13 @@ pub(crate) mod tests { let evm_word = rng.gen(); let metadata_gadget = MetadataGadget::sample(slot, evm_word); + let expected_num_actual_columns = metadata_gadget.num_actual_columns(); let expected_metadata_digest = metadata_gadget.digest(); let test_circuit = TestMedataCircuit { metadata_gadget, slot, + expected_num_actual_columns, expected_metadata_digest, }; diff --git a/mp2-v1/src/values_extraction/leaf_mapping.rs b/mp2-v1/src/values_extraction/leaf_mapping.rs index 965f317a9..8430c5132 100644 --- a/mp2-v1/src/values_extraction/leaf_mapping.rs +++ b/mp2-v1/src/values_extraction/leaf_mapping.rs @@ -82,6 +82,7 @@ where { pub fn build(b: &mut CBuilder) -> LeafMappingWires { let zero = b.zero(); + let one = b.one(); let key_id = b.add_virtual_target(); let metadata = MetadataGadget::build(b); @@ -99,8 +100,10 @@ where // Left pad the leaf value. let value: Array = left_pad_leaf_value(b, &wires.value); - // Compute the metadata digest. - let metadata_digest = metadata.digest(b, slot.mapping_slot); + // Compute the metadata digest and number of actual columns. + let (metadata_digest, num_actual_columns) = metadata.digest_info(b, slot.mapping_slot); + // We add key column to number of actual columns. + let num_actual_columns = b.add(num_actual_columns, one); // key_column_md = H( "\0KEY" || slot) let key_id_prefix = b.constant(F::from_canonical_u32(u32::from_be_bytes( @@ -139,11 +142,11 @@ where // Compute the unique data to identify a row is the mapping key. // row_unique_data = H(pack(left_pad32(key)) let row_unique_data = b.hash_n_to_hash_no_pad::(packed_mapping_key); - // row_id = H2int(row_unique_data || metadata_digest) + // row_id = H2int(row_unique_data || num_actual_columns) let inputs = row_unique_data .to_targets() .into_iter() - .chain(metadata_digest.to_targets()) + .chain(once(num_actual_columns)) .collect(); let hash = b.hash_n_to_hash_no_pad::(inputs); let row_id = hash_to_int_target(b, hash); @@ -222,7 +225,6 @@ where #[cfg(test)] mod tests { - use super::*; use crate::{ tests::{TEST_MAX_COLUMNS, TEST_MAX_FIELD_PER_EVM}, @@ -309,7 +311,6 @@ mod tests { >(table_info.clone(), slot, key_id); // Compute the values digest. let values_digest = compute_leaf_mapping_values_digest::( - &metadata_digest, table_info, &extracted_column_identifiers, value.clone().try_into().unwrap(), diff --git a/mp2-v1/src/values_extraction/leaf_mapping_of_mappings.rs b/mp2-v1/src/values_extraction/leaf_mapping_of_mappings.rs index 6ff6c7ff3..ae28adce9 100644 --- a/mp2-v1/src/values_extraction/leaf_mapping_of_mappings.rs +++ b/mp2-v1/src/values_extraction/leaf_mapping_of_mappings.rs @@ -91,6 +91,7 @@ where b: &mut CBuilder, ) -> LeafMappingOfMappingsWires { let zero = b.zero(); + let two = b.two(); let [outer_key_id, inner_key_id] = b.add_virtual_target_arr(); let metadata = MetadataGadget::build(b); @@ -108,8 +109,10 @@ where // Left pad the leaf value. let value: Array = left_pad_leaf_value(b, &wires.value); - // Compute the metadata digest. - let metadata_digest = metadata.digest(b, slot.mapping_slot); + // Compute the metadata digest and number of actual columns. + let (metadata_digest, num_actual_columns) = metadata.digest_info(b, slot.mapping_slot); + // Add inner key and outer key columns to the number of actual columns. + let num_actual_columns = b.add(num_actual_columns, two); // Compute the outer and inner key metadata digests. let [outer_key_digest, inner_key_digest] = [ @@ -173,11 +176,11 @@ where .chain(packed_inner_key) .collect(); let row_unique_data = b.hash_n_to_hash_no_pad::(inputs); - // row_id = H2int(row_unique_data || metadata_digest) + // row_id = H2int(row_unique_data || num_actual_columns) let inputs = row_unique_data .to_targets() .into_iter() - .chain(metadata_digest.to_targets()) + .chain(once(num_actual_columns)) .collect(); let hash = b.hash_n_to_hash_no_pad::(inputs); let row_id = hash_to_int_target(b, hash); @@ -356,7 +359,6 @@ mod tests { >(table_info.clone(), slot, outer_key_id, inner_key_id); // Compute the values digest. let values_digest = compute_leaf_mapping_of_mappings_values_digest::( - &metadata_digest, table_info, &extracted_column_identifiers, value.clone().try_into().unwrap(), diff --git a/mp2-v1/src/values_extraction/leaf_single.rs b/mp2-v1/src/values_extraction/leaf_single.rs index 435c0f4bf..fe730665b 100644 --- a/mp2-v1/src/values_extraction/leaf_single.rs +++ b/mp2-v1/src/values_extraction/leaf_single.rs @@ -29,6 +29,7 @@ use plonky2_ecdsa::gadgets::nonnative::CircuitBuilderNonNative; use plonky2_ecgfp5::gadgets::curve::CircuitBuilderEcGFp5; use recursion_framework::circuit_builder::CircuitLogicWires; use serde::{Deserialize, Serialize}; +use std::iter::once; #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] pub struct LeafSingleWires< @@ -83,8 +84,8 @@ where // Left pad the leaf value. let value: Array = left_pad_leaf_value(b, &wires.value); - // Compute the metadata digest. - let metadata_digest = metadata.digest(b, slot.base.slot); + // Compute the metadata digest and number of actual columns. + let (metadata_digest, num_actual_columns) = metadata.digest_info(b, slot.base.slot); // Compute the values digest. let values_digest = ColumnGadget::::new( @@ -94,12 +95,12 @@ where ) .build(b); - // row_id = H2int(H("") || metadata_digest) + // row_id = H2int(H("") || num_actual_columns) let empty_hash = b.constant_hash(*empty_poseidon_hash()); let inputs = empty_hash .to_targets() .into_iter() - .chain(metadata_digest.to_targets()) + .chain(once(num_actual_columns)) .collect(); let hash = b.hash_n_to_hash_no_pad::(inputs); let row_id = hash_to_int_target(b, hash); @@ -253,7 +254,6 @@ mod tests { let table_info = metadata.actual_table_info().to_vec(); let extracted_column_identifiers = metadata.extracted_column_identifiers(); let values_digest = compute_leaf_single_values_digest::( - &metadata_digest, table_info, &extracted_column_identifiers, value.clone().try_into().unwrap(), diff --git a/mp2-v1/src/values_extraction/mod.rs b/mp2-v1/src/values_extraction/mod.rs index 4dafa12c6..f5508acc8 100644 --- a/mp2-v1/src/values_extraction/mod.rs +++ b/mp2-v1/src/values_extraction/mod.rs @@ -182,20 +182,20 @@ pub fn compute_leaf_single_metadata_digest< /// Compute the values digest for single variable leaf. pub fn compute_leaf_single_values_digest( - metadata_digest: &Digest, table_info: Vec, extracted_column_identifiers: &[u64], value: [u8; MAPPING_LEAF_VALUE_LEN], ) -> Digest { + let num_actual_columns = F::from_canonical_usize(table_info.len()); let values_digest = ColumnGadgetData::::new(table_info, extracted_column_identifiers, value) .digest(); - // row_id = H2int(H("") || metadata_digest) + // row_id = H2int(H("") || num_actual_columns) let inputs = empty_poseidon_hash() .to_fields() .into_iter() - .chain(metadata_digest.to_fields()) + .chain(once(num_actual_columns)) .collect_vec(); let hash = H::hash_no_pad(&inputs); let row_id = hash_to_int_value(hash); @@ -238,7 +238,6 @@ pub fn compute_leaf_mapping_metadata_digest< /// Compute the values digest for mapping variable leaf. pub fn compute_leaf_mapping_values_digest( - metadata_digest: &Digest, table_info: Vec, extracted_column_identifiers: &[u64], value: [u8; MAPPING_LEAF_VALUE_LEN], @@ -246,6 +245,8 @@ pub fn compute_leaf_mapping_values_digest( evm_word: u32, key_id: u64, ) -> Digest { + // We add key column to number of actual columns. + let num_actual_columns = F::from_canonical_usize(table_info.len() + 1); let mut values_digest = ColumnGadgetData::::new(table_info, extracted_column_identifiers, value) .digest(); @@ -264,11 +265,11 @@ pub fn compute_leaf_mapping_values_digest( } // row_unique_data = H(pack(left_pad32(key)) let row_unique_data = H::hash_no_pad(&packed_mapping_key.collect_vec()); - // row_id = H2int(row_unique_data || metadata_digest) + // row_id = H2int(row_unique_data || num_actual_columns) let inputs = row_unique_data .to_fields() .into_iter() - .chain(metadata_digest.to_fields()) + .chain(once(num_actual_columns)) .collect_vec(); let hash = H::hash_no_pad(&inputs); let row_id = hash_to_int_value(hash); @@ -319,7 +320,6 @@ pub fn compute_leaf_mapping_of_mappings_metadata_digest< /// Compute the values digest for mapping of mappings leaf. pub fn compute_leaf_mapping_of_mappings_values_digest( - metadata_digest: &Digest, table_info: Vec, extracted_column_identifiers: &[u64], value: [u8; MAPPING_LEAF_VALUE_LEN], @@ -329,6 +329,8 @@ pub fn compute_leaf_mapping_of_mappings_values_digest Digest { + // Add inner key and outer key columns to the number of actual columns. + let num_actual_columns = F::from_canonical_usize(table_info.len() + 2); let mut values_digest = ColumnGadgetData::::new(table_info, extracted_column_identifiers, value) .digest(); @@ -360,11 +362,11 @@ pub fn compute_leaf_mapping_of_mappings_values_digest(inputs); + let row_id_multiplier = hash_to_int_target(b, hash); // multiplier_digest = rows_tree_proof.row_id_multiplier * rows_tree_proof.multiplier_vd let multiplier_vd = rows_tree_pi.multiplier_digest_target(); - let row_id_multiplier = b.biguint_to_nonnative(&rows_tree_pi.row_id_multiplier_target()); + let row_id_multiplier = b.biguint_to_nonnative(&row_id_multiplier); let multiplier_digest = b.curve_scalar_mul(multiplier_vd, &row_id_multiplier); // rows_digest_merge = multiplier_digest * rows_tree_proof.DR let individual_digest = rows_tree_pi.individual_digest_target(); @@ -110,7 +133,6 @@ pub(crate) mod tests { use alloy::primitives::U256; use mp2_common::{ keccak::PACKED_HASH_LEN, - poseidon::HASH_TO_INT_LEN, types::CBuilder, utils::{FromFields, ToFields}, C, F, @@ -119,7 +141,6 @@ pub(crate) mod tests { circuit::{run_circuit, UserCircuit}, utils::random_vector, }; - use num::BigUint; use plonky2::{ field::types::{Field, Sample}, hash::hash_types::NUM_HASH_OUT_ELTS, @@ -170,9 +191,9 @@ pub(crate) mod tests { } else { Point::NEUTRAL }; - let row_id_multiplier = BigUint::from_slice(&random_vector::(HASH_TO_INT_LEN)); + let mulitplier_cnt = rng.gen_range(1..100); - row_tree::PublicInputs::sample(multiplier_digest, row_id_multiplier, min, max) + row_tree::PublicInputs::sample(multiplier_digest, min, max, mulitplier_cnt) } /// Generate a random extraction public inputs. diff --git a/verifiable-db/src/cells_tree/api.rs b/verifiable-db/src/cells_tree/api.rs index 5353ad197..3eb707a43 100644 --- a/verifiable-db/src/cells_tree/api.rs +++ b/verifiable-db/src/cells_tree/api.rs @@ -39,13 +39,12 @@ impl CircuitInput { /// Create a circuit input for proving a leaf node. /// It is not considered a multiplier column. Please use `leaf_multiplier` for registering a /// multiplier column. - pub fn leaf(identifier: u64, value: U256, mpt_metadata: HashOut) -> Self { + pub fn leaf(identifier: u64, value: U256) -> Self { CircuitInput::Leaf( Cell { identifier: F::from_canonical_u64(identifier), value, is_multiplier: false, - mpt_metadata, } .into(), ) @@ -53,18 +52,12 @@ impl CircuitInput { /// Create a circuit input for proving a leaf node whose value is considered as a multiplier /// depending on the boolean value. /// i.e. it means it's one of the repeated value amongst all the rows - pub fn leaf_multiplier( - identifier: u64, - value: U256, - is_multiplier: bool, - mpt_metadata: HashOut, - ) -> Self { + pub fn leaf_multiplier(identifier: u64, value: U256, is_multiplier: bool) -> Self { CircuitInput::Leaf( Cell { identifier: F::from_canonical_u64(identifier), value, is_multiplier, - mpt_metadata, } .into(), ) @@ -73,17 +66,11 @@ impl CircuitInput { /// Create a circuit input for proving a full node of 2 children. /// It is not considered a multiplier column. Please use `full_multiplier` for registering a /// multiplier column. - pub fn full( - identifier: u64, - value: U256, - mpt_metadata: HashOut, - child_proofs: [Vec; 2], - ) -> Self { + pub fn full(identifier: u64, value: U256, child_proofs: [Vec; 2]) -> Self { CircuitInput::FullNode(new_child_input( F::from_canonical_u64(identifier), value, false, - mpt_metadata, child_proofs.to_vec(), )) } @@ -93,31 +80,23 @@ impl CircuitInput { identifier: u64, value: U256, is_multiplier: bool, - mpt_metadata: HashOut, child_proofs: [Vec; 2], ) -> Self { CircuitInput::FullNode(new_child_input( F::from_canonical_u64(identifier), value, is_multiplier, - mpt_metadata, child_proofs.to_vec(), )) } /// Create a circuit input for proving a partial node of 1 child. /// It is not considered a multiplier column. Please use `partial_multiplier` for registering a /// multiplier column. - pub fn partial( - identifier: u64, - value: U256, - mpt_metadata: HashOut, - child_proof: Vec, - ) -> Self { + pub fn partial(identifier: u64, value: U256, child_proof: Vec) -> Self { CircuitInput::PartialNode(new_child_input( F::from_canonical_u64(identifier), value, false, - mpt_metadata, vec![child_proof], )) } @@ -125,14 +104,12 @@ impl CircuitInput { identifier: u64, value: U256, is_multiplier: bool, - mpt_metadata: HashOut, child_proof: Vec, ) -> Self { CircuitInput::PartialNode(new_child_input( F::from_canonical_u64(identifier), value, is_multiplier, - mpt_metadata, vec![child_proof], )) } @@ -143,7 +120,6 @@ fn new_child_input( identifier: F, value: U256, is_multiplier: bool, - mpt_metadata: HashOut, serialized_child_proofs: Vec>, ) -> ChildInput { ChildInput { @@ -151,7 +127,6 @@ fn new_child_input( identifier, value, is_multiplier, - mpt_metadata, }, serialized_child_proofs, } @@ -308,13 +283,12 @@ mod tests { fn generate_leaf_proof(params: &PublicParameters) -> Vec { // Build the circuit input. - let cell = Cell::sample(false); + let is_multiplier = false; + let cell = Cell::sample(is_multiplier); let id = cell.identifier; let value = cell.value; - let mpt_metadata = cell.mpt_metadata; let values_digests = cell.split_values_digest(); - let metadata_digests = cell.split_metadata_digest(); - let input = CircuitInput::leaf(id.to_canonical_u64(), value, mpt_metadata); + let input = CircuitInput::leaf(id.to_canonical_u64(), value); // Generate proof. let proof = params.generate_proof(input).unwrap(); @@ -351,16 +325,11 @@ mod tests { pi.multiplier_values_digest_point(), values_digests.multiplier.to_weierstrass(), ); - // Check individual metadata digest - assert_eq!( - pi.individual_metadata_digest_point(), - metadata_digests.individual.to_weierstrass(), - ); - // Check multiplier metadata digest - assert_eq!( - pi.multiplier_metadata_digest_point(), - metadata_digests.multiplier.to_weierstrass(), - ); + // Check individual counter + let individual_cnt = if is_multiplier { F::ZERO } else { F::ONE }; + assert_eq!(pi.individual_counter(), individual_cnt); + // Check multiplier counter + assert_eq!(pi.multiplier_counter(), F::ONE - individual_cnt); proof } @@ -389,16 +358,10 @@ mod tests { pi.multiplier_values_digest_point(), WeierstrassPoint::NEUTRAL ); - // Check individual metadata digest - assert_eq!( - pi.individual_metadata_digest_point(), - WeierstrassPoint::NEUTRAL - ); - // Check multiplier metadata digest - assert_eq!( - pi.multiplier_metadata_digest_point(), - WeierstrassPoint::NEUTRAL - ); + // Check individual counter + assert_eq!(pi.individual_counter(), F::ZERO); + // Check multiplier counter + assert_eq!(pi.multiplier_counter(), F::ZERO); proof } @@ -415,13 +378,12 @@ mod tests { .collect(); // Build the circuit input. - let cell = Cell::sample(false); + let is_multiplier = false; + let cell = Cell::sample(is_multiplier); let id = cell.identifier; let value = cell.value; - let mpt_metadata = cell.mpt_metadata; let values_digests = cell.split_values_digest(); - let metadata_digests = cell.split_metadata_digest(); - let input = CircuitInput::full(id.to_canonical_u64(), value, mpt_metadata, child_proofs); + let input = CircuitInput::full(id.to_canonical_u64(), value, child_proofs); // Generate proof. let proof = params.generate_proof(input).unwrap(); @@ -436,9 +398,6 @@ mod tests { let values_digests = child_pis.iter().fold(values_digests, |acc, pi| { acc.accumulate(&pi.split_values_digest_point()) }); - let metadata_digests = child_pis.iter().fold(metadata_digests, |acc, pi| { - acc.accumulate(&pi.split_metadata_digest_point()) - }); // Check the node hash { @@ -465,15 +424,19 @@ mod tests { pi.multiplier_values_digest_point(), values_digests.multiplier.to_weierstrass(), ); - // Check individual metadata digest + // Check individual counter + let individual_cnt = if is_multiplier { F::ZERO } else { F::ONE }; assert_eq!( - pi.individual_metadata_digest_point(), - metadata_digests.individual.to_weierstrass(), + pi.individual_counter(), + child_pis + .iter() + .fold(individual_cnt, |acc, pi| acc + pi.individual_counter()), ); - // Check multiplier metadata digest + // Check multiplier counter assert_eq!( - pi.multiplier_metadata_digest_point(), - metadata_digests.multiplier.to_weierstrass(), + pi.multiplier_counter(), + child_pis.iter().fold(F::ONE - individual_cnt, |acc, pi| acc + + pi.multiplier_counter()), ); proof @@ -488,13 +451,12 @@ mod tests { let child_pi = PublicInputs::from_slice(&child_pi); // Build the circuit input. - let cell = Cell::sample(false); + let is_multiplier = false; + let cell = Cell::sample(is_multiplier); let id = cell.identifier; let value = cell.value; - let mpt_metadata = cell.mpt_metadata; let values_digests = cell.split_values_digest(); - let metadata_digests = cell.split_metadata_digest(); - let input = CircuitInput::partial(id.to_canonical_u64(), value, mpt_metadata, child_proof); + let input = CircuitInput::partial(id.to_canonical_u64(), value, child_proof); // Generate proof. let proof = params.generate_proof(input).unwrap(); @@ -507,7 +469,6 @@ mod tests { let pi = PublicInputs::from_slice(&pi); let values_digests = values_digests.accumulate(&child_pi.split_values_digest_point()); - let metadata_digests = metadata_digests.accumulate(&child_pi.split_metadata_digest_point()); // Check the node hash { @@ -535,15 +496,16 @@ mod tests { pi.multiplier_values_digest_point(), values_digests.multiplier.to_weierstrass(), ); - // Check individual metadata digest + // Check individual counter + let individual_cnt = if is_multiplier { F::ZERO } else { F::ONE }; assert_eq!( - pi.individual_metadata_digest_point(), - metadata_digests.individual.to_weierstrass(), + pi.individual_counter(), + individual_cnt + child_pi.individual_counter(), ); - // Check multiplier metadata digest + // Check multiplier counter assert_eq!( - pi.multiplier_metadata_digest_point(), - metadata_digests.multiplier.to_weierstrass(), + pi.multiplier_counter(), + F::ONE - individual_cnt + child_pi.multiplier_counter(), ); proof diff --git a/verifiable-db/src/cells_tree/empty_node.rs b/verifiable-db/src/cells_tree/empty_node.rs index f1f936ddb..212a297a8 100644 --- a/verifiable-db/src/cells_tree/empty_node.rs +++ b/verifiable-db/src/cells_tree/empty_node.rs @@ -23,11 +23,14 @@ impl EmptyNodeCircuit { let empty_hash = empty_poseidon_hash(); let h = b.constant_hash(*empty_hash).elements; + // ZERO + let zero = b.zero(); + // CURVE_ZERO let curve_zero = b.curve_zero().to_targets(); // Register the public inputs. - PublicInputs::new(&h, &curve_zero, &curve_zero, &curve_zero, &curve_zero).register(b); + PublicInputs::new(&h, &curve_zero, &curve_zero, &zero, &zero).register(b); EmptyNodeWires } @@ -59,6 +62,7 @@ mod tests { use super::*; use mp2_common::C; use mp2_test::circuit::{run_circuit, UserCircuit}; + use plonky2::field::types::Field; use plonky2_ecgfp5::curve::curve::WeierstrassPoint; impl UserCircuit for EmptyNodeCircuit { @@ -91,15 +95,9 @@ mod tests { pi.multiplier_values_digest_point(), WeierstrassPoint::NEUTRAL ); - // Check individual metadata digest - assert_eq!( - pi.individual_metadata_digest_point(), - WeierstrassPoint::NEUTRAL - ); - // Check multiplier metadata digest - assert_eq!( - pi.multiplier_metadata_digest_point(), - WeierstrassPoint::NEUTRAL - ); + // Check individual counter + assert_eq!(pi.individual_counter(), F::ZERO); + // Check multiplier counter + assert_eq!(pi.multiplier_counter(), F::ZERO); } } diff --git a/verifiable-db/src/cells_tree/full_node.rs b/verifiable-db/src/cells_tree/full_node.rs index 6dcc64754..79a87df95 100644 --- a/verifiable-db/src/cells_tree/full_node.rs +++ b/verifiable-db/src/cells_tree/full_node.rs @@ -25,13 +25,22 @@ impl FullNodeCircuit { let [p1, p2] = child_proofs; let cell = CellWire::new(b); - let metadata_digests = - cell.split_and_accumulate_metadata_digest(b, &p1.split_metadata_digest_target()); let values_digests = cell.split_and_accumulate_values_digest(b, &p1.split_values_digest_target()); - let metadata_digests = metadata_digests.accumulate(b, &p2.split_metadata_digest_target()); let values_digests = values_digests.accumulate(b, &p2.split_values_digest_target()); + let is_individual = cell.is_individual(b); + let individual_cnt = b.add_many([ + is_individual.target, + p1.individual_counter_target(), + p2.individual_counter_target(), + ]); + let multiplier_cnt = b.add_many([ + cell.is_multiplier().target, + p1.multiplier_counter_target(), + p2.multiplier_counter_target(), + ]); + // H(p1.H || p2.H || identifier || value) let inputs = p1 .node_hash_target() @@ -47,8 +56,8 @@ impl FullNodeCircuit { &h.to_targets(), &values_digests.individual.to_targets(), &values_digests.multiplier.to_targets(), - &metadata_digests.individual.to_targets(), - &metadata_digests.multiplier.to_targets(), + &individual_cnt, + &multiplier_cnt, ) .register(b); @@ -91,7 +100,7 @@ mod tests { use itertools::Itertools; use mp2_common::{poseidon::H, utils::ToFields, C}; use mp2_test::circuit::{run_circuit, UserCircuit}; - use plonky2::{iop::witness::WitnessWrite, plonk::config::Hasher}; + use plonky2::{field::types::Field, iop::witness::WitnessWrite, plonk::config::Hasher}; #[derive(Clone, Debug)] struct TestFullNodeCircuit<'a> { @@ -162,7 +171,6 @@ mod tests { let id = cell.identifier; let value = cell.value; let values_digests = cell.split_values_digest(); - let metadata_digests = cell.split_metadata_digest(); let child_pis = &[ PublicInputs::::sample(is_left_child_multiplier), @@ -184,9 +192,6 @@ mod tests { let values_digests = child_pis.iter().fold(values_digests, |acc, pi| { acc.accumulate(&pi.split_values_digest_point()) }); - let metadata_digests = child_pis.iter().fold(metadata_digests, |acc, pi| { - acc.accumulate(&pi.split_metadata_digest_point()) - }); // Check the node hash { @@ -212,15 +217,19 @@ mod tests { pi.multiplier_values_digest_point(), values_digests.multiplier.to_weierstrass(), ); - // Check individual metadata digest + // Check individual counter + let individual_cnt = if is_multiplier { F::ZERO } else { F::ONE }; assert_eq!( - pi.individual_metadata_digest_point(), - metadata_digests.individual.to_weierstrass(), + pi.individual_counter(), + child_pis + .iter() + .fold(individual_cnt, |acc, pi| acc + pi.individual_counter()), ); - // Check multiplier metadata digest + // Check multiplier counter assert_eq!( - pi.multiplier_metadata_digest_point(), - metadata_digests.multiplier.to_weierstrass(), + pi.multiplier_counter(), + child_pis.iter().fold(F::ONE - individual_cnt, |acc, pi| acc + + pi.multiplier_counter()), ); } } diff --git a/verifiable-db/src/cells_tree/leaf.rs b/verifiable-db/src/cells_tree/leaf.rs index 4d8d7663e..564f22a17 100644 --- a/verifiable-db/src/cells_tree/leaf.rs +++ b/verifiable-db/src/cells_tree/leaf.rs @@ -26,8 +26,9 @@ pub struct LeafCircuit(Cell); impl LeafCircuit { fn build(b: &mut CBuilder) -> LeafWires { let cell = CellWire::new(b); - let metadata_digests = cell.split_metadata_digest(b); let values_digests = cell.split_values_digest(b); + let individual_cnt = cell.is_individual(b).target; + let multiplier_cnt = cell.is_multiplier().target; // H(H("") || H("") || identifier || pack_u32(value)) let empty_hash = b.constant_hash(*empty_poseidon_hash()).to_targets(); @@ -45,8 +46,8 @@ impl LeafCircuit { &h.to_targets(), &values_digests.individual.to_targets(), &values_digests.multiplier.to_targets(), - &metadata_digests.individual.to_targets(), - &metadata_digests.multiplier.to_targets(), + &individual_cnt, + &multiplier_cnt, ) .register(b); @@ -87,7 +88,7 @@ mod tests { use itertools::Itertools; use mp2_common::{poseidon::H, utils::ToFields, C}; use mp2_test::circuit::{run_circuit, UserCircuit}; - use plonky2::plonk::config::Hasher; + use plonky2::{field::types::Field, plonk::config::Hasher}; impl UserCircuit for LeafCircuit { type Wires = LeafWires; @@ -112,7 +113,6 @@ mod tests { let id = cell.identifier; let value = cell.value; let values_digests = cell.split_values_digest(); - let metadata_digests = cell.split_metadata_digest(); let test_circuit: LeafCircuit = cell.into(); let proof = run_circuit::(test_circuit); @@ -143,15 +143,10 @@ mod tests { pi.multiplier_values_digest_point(), values_digests.multiplier.to_weierstrass(), ); - // Check individual metadata digest - assert_eq!( - pi.individual_metadata_digest_point(), - metadata_digests.individual.to_weierstrass(), - ); - // Check multiplier metadata digest - assert_eq!( - pi.multiplier_metadata_digest_point(), - metadata_digests.multiplier.to_weierstrass(), - ); + // Check individual counter + let individual_cnt = if is_multiplier { F::ZERO } else { F::ONE }; + assert_eq!(pi.individual_counter(), individual_cnt); + // Check multiplier counter + assert_eq!(pi.multiplier_counter(), F::ONE - individual_cnt); } } diff --git a/verifiable-db/src/cells_tree/mod.rs b/verifiable-db/src/cells_tree/mod.rs index 13a974dbb..3ac65fc1b 100644 --- a/verifiable-db/src/cells_tree/mod.rs +++ b/verifiable-db/src/cells_tree/mod.rs @@ -21,12 +21,9 @@ use mp2_common::{ use serde::{Deserialize, Serialize}; use std::iter::once; -use plonky2::{ - hash::hash_types::{HashOut, HashOutTarget}, - iop::{ - target::{BoolTarget, Target}, - witness::{PartialWitness, WitnessWrite}, - }, +use plonky2::iop::{ + target::{BoolTarget, Target}, + witness::{PartialWitness, WitnessWrite}, }; use plonky2_ecgfp5::gadgets::curve::CurveTarget; pub use public_inputs::PublicInputs; @@ -41,8 +38,6 @@ pub struct Cell { pub(crate) value: U256, /// is the secondary value should be included in multiplier digest or not pub(crate) is_multiplier: bool, - /// Hash of the metadata associated to this cell, as computed in MPT extraction circuits - pub(crate) mpt_metadata: HashOut, } impl Cell { @@ -50,23 +45,17 @@ impl Cell { pw.set_u256_target(&wires.value, self.value); pw.set_target(wires.identifier, self.identifier); pw.set_bool_target(wires.is_multiplier, self.is_multiplier); - pw.set_hash_target(wires.mpt_metadata, self.mpt_metadata); } - pub fn split_metadata_digest(&self) -> SplitDigestPoint { - let digest = self.metadata_digest(); - SplitDigestPoint::from_single_digest_point(digest, self.is_multiplier) + pub fn is_multiplier(&self) -> bool { + self.is_multiplier + } + pub fn is_individual(&self) -> bool { + !self.is_multiplier } pub fn split_values_digest(&self) -> SplitDigestPoint { let digest = self.values_digest(); SplitDigestPoint::from_single_digest_point(digest, self.is_multiplier) } - pub fn split_and_accumulate_metadata_digest( - &self, - child_digest: SplitDigestPoint, - ) -> SplitDigestPoint { - let split_digest = self.split_metadata_digest(); - split_digest.accumulate(&child_digest) - } pub fn split_and_accumulate_values_digest( &self, child_digest: SplitDigestPoint, @@ -74,17 +63,6 @@ impl Cell { let split_digest = self.split_values_digest(); split_digest.accumulate(&child_digest) } - fn metadata_digest(&self) -> Digest { - // D(mpt_metadata || identifier) - let inputs = self - .mpt_metadata - .to_fields() - .into_iter() - .chain(once(self.identifier)) - .collect_vec(); - - map_to_curve_point(&inputs) - } fn values_digest(&self) -> Digest { // D(identifier || pack_u32(value)) let inputs = once(self.identifier) @@ -102,8 +80,6 @@ pub struct CellWire { pub(crate) identifier: Target, #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] pub(crate) is_multiplier: BoolTarget, - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - pub(crate) mpt_metadata: HashOutTarget, } impl CellWire { @@ -112,25 +88,18 @@ impl CellWire { value: b.add_virtual_u256(), identifier: b.add_virtual_target(), is_multiplier: b.add_virtual_bool_target_safe(), - mpt_metadata: b.add_virtual_hash(), } } - pub fn split_metadata_digest(&self, b: &mut CBuilder) -> SplitDigestTarget { - let digest = self.metadata_digest(b); - SplitDigestTarget::from_single_digest_target(b, digest, self.is_multiplier) + pub fn is_multiplier(&self) -> BoolTarget { + self.is_multiplier + } + pub fn is_individual(&self, b: &mut CBuilder) -> BoolTarget { + b.not(self.is_multiplier) } pub fn split_values_digest(&self, b: &mut CBuilder) -> SplitDigestTarget { let digest = self.values_digest(b); SplitDigestTarget::from_single_digest_target(b, digest, self.is_multiplier) } - pub fn split_and_accumulate_metadata_digest( - &self, - b: &mut CBuilder, - child_digest: &SplitDigestTarget, - ) -> SplitDigestTarget { - let split_digest = self.split_metadata_digest(b); - split_digest.accumulate(b, child_digest) - } pub fn split_and_accumulate_values_digest( &self, b: &mut CBuilder, @@ -139,17 +108,6 @@ impl CellWire { let split_digest = self.split_values_digest(b); split_digest.accumulate(b, child_digest) } - fn metadata_digest(&self, b: &mut CBuilder) -> CurveTarget { - // D(mpt_metadata || identifier) - let inputs = self - .mpt_metadata - .to_targets() - .into_iter() - .chain(once(self.identifier)) - .collect_vec(); - - b.map_to_curve_point(&inputs) - } fn values_digest(&self, b: &mut CBuilder) -> CurveTarget { // D(identifier || pack_u32(value)) let inputs = once(self.identifier) @@ -183,9 +141,8 @@ pub(crate) mod tests { let identifier = rng.gen::().to_field(); let value = U256::from_limbs(rng.gen()); - let mpt_metadata = HashOut::rand(); - Cell::new(identifier, value, is_multiplier, mpt_metadata) + Cell::new(identifier, value, is_multiplier) } } @@ -215,13 +172,9 @@ pub(crate) mod tests { let cell = CellWire::new(b); let values_digest = cell.split_and_accumulate_values_digest(b, &child_values_digest); - let metadata_digest = - cell.split_and_accumulate_metadata_digest(b, &child_metadata_digest); b.register_curve_public_input(values_digest.individual); b.register_curve_public_input(values_digest.multiplier); - b.register_curve_public_input(metadata_digest.individual); - b.register_curve_public_input(metadata_digest.multiplier); (cell, child_values_digest, child_metadata_digest) } @@ -264,9 +217,7 @@ pub(crate) mod tests { let cell = &Cell::sample(rng.gen()); let values_digests = cell.split_values_digest(); - let metadata_digests = cell.split_metadata_digest(); let exp_values_digests = values_digests.accumulate(child_values_digest); - let exp_metadata_digests = metadata_digests.accumulate(child_metadata_digest); let test_circuit = TestCellCircuit { cell, @@ -276,16 +227,13 @@ pub(crate) mod tests { let proof = run_circuit::(test_circuit); - let [values_individual, values_multiplier, metadata_individual, metadata_multiplier] = - array::from_fn(|i| { - Point::from_fields( - &proof.public_inputs[i * CURVE_TARGET_LEN..(i + 1) * CURVE_TARGET_LEN], - ) - }); + let [values_individual, values_multiplier] = array::from_fn(|i| { + Point::from_fields( + &proof.public_inputs[i * CURVE_TARGET_LEN..(i + 1) * CURVE_TARGET_LEN], + ) + }); assert_eq!(values_individual, exp_values_digests.individual); assert_eq!(values_multiplier, exp_values_digests.multiplier); - assert_eq!(metadata_individual, exp_metadata_digests.individual); - assert_eq!(metadata_multiplier, exp_metadata_digests.multiplier); } } diff --git a/verifiable-db/src/cells_tree/partial_node.rs b/verifiable-db/src/cells_tree/partial_node.rs index 2724e5554..d8c081b75 100644 --- a/verifiable-db/src/cells_tree/partial_node.rs +++ b/verifiable-db/src/cells_tree/partial_node.rs @@ -27,11 +27,13 @@ pub struct PartialNodeCircuit(Cell); impl PartialNodeCircuit { pub fn build(b: &mut CBuilder, p: PublicInputs) -> PartialNodeWires { let cell = CellWire::new(b); - let metadata_digests = - cell.split_and_accumulate_metadata_digest(b, &p.split_metadata_digest_target()); let values_digests = cell.split_and_accumulate_values_digest(b, &p.split_values_digest_target()); + let is_individual = cell.is_individual(b); + let individual_cnt = b.add(is_individual.target, p.individual_counter_target()); + let multiplier_cnt = b.add(cell.is_multiplier().target, p.multiplier_counter_target()); + /* # since there is no sorting constraint among the nodes of this tree, to simplify # the circuits, when we build a node with only one child, we can always place @@ -54,8 +56,8 @@ impl PartialNodeCircuit { &h.to_targets(), &values_digests.individual.to_targets(), &values_digests.multiplier.to_targets(), - &metadata_digests.individual.to_targets(), - &metadata_digests.multiplier.to_targets(), + &individual_cnt, + &multiplier_cnt, ) .register(b); @@ -97,7 +99,7 @@ mod tests { use itertools::Itertools; use mp2_common::{poseidon::H, utils::ToFields, C}; use mp2_test::circuit::{run_circuit, UserCircuit}; - use plonky2::{iop::witness::WitnessWrite, plonk::config::Hasher}; + use plonky2::{field::types::Field, iop::witness::WitnessWrite, plonk::config::Hasher}; #[derive(Clone, Debug)] struct TestPartialNodeCircuit<'a> { @@ -138,22 +140,20 @@ mod tests { let cell = Cell::sample(is_multiplier); let id = cell.identifier; let value = cell.value; - let values_digests = cell.split_values_digest(); - let metadata_digests = cell.split_metadata_digest(); - let child_pi = &PublicInputs::::sample(is_child_multiplier); + let child_proof = &PublicInputs::::sample(is_child_multiplier); + let child_pi = PublicInputs::from_slice(child_proof); + + let values_digests = + cell.split_and_accumulate_values_digest(child_pi.split_values_digest_point()); let test_circuit = TestPartialNodeCircuit { c: cell.into(), - child_pi, + child_pi: child_proof, }; let proof = run_circuit::(test_circuit); let pi = PublicInputs::from_slice(&proof.public_inputs); - let child_pi = PublicInputs::from_slice(child_pi); - - let values_digests = values_digests.accumulate(&child_pi.split_values_digest_point()); - let metadata_digests = metadata_digests.accumulate(&child_pi.split_metadata_digest_point()); // Check the node hash { @@ -180,15 +180,16 @@ mod tests { pi.multiplier_values_digest_point(), values_digests.multiplier.to_weierstrass(), ); - // Check individual metadata digest + // Check individual counter + let individual_cnt = if is_multiplier { F::ZERO } else { F::ONE }; assert_eq!( - pi.individual_metadata_digest_point(), - metadata_digests.individual.to_weierstrass(), + pi.individual_counter(), + individual_cnt + child_pi.individual_counter(), ); - // Check multiplier metadata digest + // Check multiplier counter assert_eq!( - pi.multiplier_metadata_digest_point(), - metadata_digests.multiplier.to_weierstrass(), + pi.multiplier_counter(), + F::ONE - individual_cnt + child_pi.multiplier_counter(), ); } } diff --git a/verifiable-db/src/cells_tree/public_inputs.rs b/verifiable-db/src/cells_tree/public_inputs.rs index 422cfbc38..e39764f0e 100644 --- a/verifiable-db/src/cells_tree/public_inputs.rs +++ b/verifiable-db/src/cells_tree/public_inputs.rs @@ -13,6 +13,7 @@ use plonky2::{ iop::target::Target, }; use plonky2_ecgfp5::{curve::curve::WeierstrassPoint, gadgets::curve::CurveTarget}; +use std::iter::once; pub enum CellsTreePublicInputs { // `H : F[4]` - Poseidon hash of the subtree at this node @@ -21,10 +22,10 @@ pub enum CellsTreePublicInputs { IndividualValuesDigest, // - `multiplier_vd : Digest` - Cumulative digest of values of cells accumulated as multiplier MultiplierValuesDigest, - // - `individual_md : Digest` - Cumulative digest of metadata of cells accumulated as individual - IndividualMetadataDigest, - // - `multiplier_md : Digest` - Cumulative digest of metadata of cells accumulated as multiplier - MultiplierMetadataDigest, + // - `individual_counter : F` - Counter of the number of cells accumulated so far as individual + IndividualCounter, + // - `multiplier_counter : F` - Counter of the number of cells accumulated so far as multiplier + MultiplierCounter, } /// Public inputs for Cells Tree Construction @@ -33,19 +34,19 @@ pub struct PublicInputs<'a, T> { pub(crate) h: &'a [T], pub(crate) individual_vd: &'a [T], pub(crate) multiplier_vd: &'a [T], - pub(crate) individual_md: &'a [T], - pub(crate) multiplier_md: &'a [T], + pub(crate) individual_cnt: &'a T, + pub(crate) multiplier_cnt: &'a T, } -const NUM_PUBLIC_INPUTS: usize = CellsTreePublicInputs::MultiplierMetadataDigest as usize + 1; +const NUM_PUBLIC_INPUTS: usize = CellsTreePublicInputs::MultiplierCounter as usize + 1; impl<'a, T: Clone> PublicInputs<'a, T> { const PI_RANGES: [PublicInputRange; NUM_PUBLIC_INPUTS] = [ Self::to_range(CellsTreePublicInputs::NodeHash), Self::to_range(CellsTreePublicInputs::IndividualValuesDigest), Self::to_range(CellsTreePublicInputs::MultiplierValuesDigest), - Self::to_range(CellsTreePublicInputs::IndividualMetadataDigest), - Self::to_range(CellsTreePublicInputs::MultiplierMetadataDigest), + Self::to_range(CellsTreePublicInputs::IndividualCounter), + Self::to_range(CellsTreePublicInputs::MultiplierCounter), ]; const SIZES: [usize; NUM_PUBLIC_INPUTS] = [ @@ -55,10 +56,10 @@ impl<'a, T: Clone> PublicInputs<'a, T> { CURVE_TARGET_LEN, // Cumulative digest of values of cells accumulated as multiplier CURVE_TARGET_LEN, - // Cumulative digest of metadata of cells accumulated as individual - CURVE_TARGET_LEN, - // Cumulative digest of metadata of cells accumulated as multiplier - CURVE_TARGET_LEN, + // Counter of the number of cells accumulated so far as individual + 1, + // Counter of the number of cells accumulated so far as multiplier + 1, ]; pub(crate) const fn to_range(pi: CellsTreePublicInputs) -> PublicInputRange { @@ -73,7 +74,7 @@ impl<'a, T: Clone> PublicInputs<'a, T> { } pub const fn total_len() -> usize { - Self::to_range(CellsTreePublicInputs::MultiplierMetadataDigest).end + Self::to_range(CellsTreePublicInputs::MultiplierCounter).end } pub fn to_node_hash_raw(&self) -> &[T] { @@ -88,12 +89,12 @@ impl<'a, T: Clone> PublicInputs<'a, T> { self.multiplier_vd } - pub fn to_individual_metadata_digest_raw(&self) -> &[T] { - self.individual_md + pub fn to_individual_counter_raw(&self) -> &T { + self.individual_cnt } - pub fn to_multiplier_metadata_digest_raw(&self) -> &[T] { - self.multiplier_md + pub fn to_multiplier_counter_raw(&self) -> &T { + self.multiplier_cnt } pub fn from_slice(input: &'a [T]) -> Self { @@ -107,8 +108,8 @@ impl<'a, T: Clone> PublicInputs<'a, T> { h: &input[Self::PI_RANGES[0].clone()], individual_vd: &input[Self::PI_RANGES[1].clone()], multiplier_vd: &input[Self::PI_RANGES[2].clone()], - individual_md: &input[Self::PI_RANGES[3].clone()], - multiplier_md: &input[Self::PI_RANGES[4].clone()], + individual_cnt: &input[Self::PI_RANGES[3].clone()][0], + multiplier_cnt: &input[Self::PI_RANGES[4].clone()][0], } } @@ -116,15 +117,15 @@ impl<'a, T: Clone> PublicInputs<'a, T> { h: &'a [T], individual_vd: &'a [T], multiplier_vd: &'a [T], - individual_md: &'a [T], - multiplier_md: &'a [T], + individual_cnt: &'a T, + multiplier_cnt: &'a T, ) -> Self { Self { h, individual_vd, multiplier_vd, - individual_md, - multiplier_md, + individual_cnt, + multiplier_cnt, } } @@ -133,8 +134,8 @@ impl<'a, T: Clone> PublicInputs<'a, T> { .iter() .chain(self.individual_vd) .chain(self.multiplier_vd) - .chain(self.individual_md) - .chain(self.multiplier_md) + .chain(once(self.individual_cnt)) + .chain(once(self.multiplier_cnt)) .cloned() .collect() } @@ -147,8 +148,8 @@ impl<'a> PublicInputCommon for PublicInputs<'a, Target> { cb.register_public_inputs(self.h); cb.register_public_inputs(self.individual_vd); cb.register_public_inputs(self.multiplier_vd); - cb.register_public_inputs(self.individual_md); - cb.register_public_inputs(self.multiplier_md); + cb.register_public_input(*self.individual_cnt); + cb.register_public_input(*self.multiplier_cnt); } } @@ -165,14 +166,6 @@ impl<'a> PublicInputs<'a, Target> { CurveTarget::from_targets(self.multiplier_vd) } - pub fn individual_metadata_digest_target(&self) -> CurveTarget { - CurveTarget::from_targets(self.individual_md) - } - - pub fn multiplier_metadata_digest_target(&self) -> CurveTarget { - CurveTarget::from_targets(self.multiplier_md) - } - pub fn split_values_digest_target(&self) -> SplitDigestTarget { SplitDigestTarget { individual: self.individual_values_digest_target(), @@ -180,11 +173,12 @@ impl<'a> PublicInputs<'a, Target> { } } - pub fn split_metadata_digest_target(&self) -> SplitDigestTarget { - SplitDigestTarget { - individual: self.individual_metadata_digest_target(), - multiplier: self.multiplier_metadata_digest_target(), - } + pub fn individual_counter_target(&self) -> Target { + *self.to_individual_counter_raw() + } + + pub fn multiplier_counter_target(&self) -> Target { + *self.to_multiplier_counter_raw() } } @@ -201,14 +195,6 @@ impl<'a> PublicInputs<'a, F> { WeierstrassPoint::from_fields(self.multiplier_vd) } - pub fn individual_metadata_digest_point(&self) -> WeierstrassPoint { - WeierstrassPoint::from_fields(self.individual_md) - } - - pub fn multiplier_metadata_digest_point(&self) -> WeierstrassPoint { - WeierstrassPoint::from_fields(self.multiplier_md) - } - pub fn split_values_digest_point(&self) -> SplitDigestPoint { SplitDigestPoint { individual: weierstrass_to_point(&self.individual_values_digest_point()), @@ -216,11 +202,12 @@ impl<'a> PublicInputs<'a, F> { } } - pub fn split_metadata_digest_point(&self) -> SplitDigestPoint { - SplitDigestPoint { - individual: weierstrass_to_point(&self.individual_metadata_digest_point()), - multiplier: weierstrass_to_point(&self.multiplier_metadata_digest_point()), - } + pub fn individual_counter(&self) -> F { + *self.to_individual_counter_raw() + } + + pub fn multiplier_counter(&self) -> F { + *self.to_multiplier_counter_raw() } } @@ -241,7 +228,7 @@ pub(crate) mod tests { }; use plonky2_ecgfp5::curve::curve::Point; use rand::{thread_rng, Rng}; - use std::array; + use std::slice; impl<'a> PublicInputs<'a, F> { pub(crate) fn sample(is_multiplier: bool) -> Vec { @@ -250,30 +237,20 @@ pub(crate) mod tests { let h = random_vector::(NUM_HASH_OUT_ELTS).to_fields(); let point_zero = WeierstrassPoint::NEUTRAL.to_fields(); - let [values_digest, metadata_digest] = - array::from_fn(|_| Point::sample(rng).to_weierstrass().to_fields()); - let [individual_vd, multiplier_vd, individual_md, multiplier_md] = if is_multiplier { - [ - point_zero.clone(), - values_digest, - point_zero, - metadata_digest, - ] + let values_digest = Point::sample(rng).to_weierstrass().to_fields(); + let [individual_vd, multiplier_vd] = if is_multiplier { + [point_zero.clone(), values_digest] } else { - [ - values_digest, - point_zero.clone(), - metadata_digest, - point_zero, - ] + [values_digest, point_zero] }; + let [individual_cnt, multiplier_cnt] = F::rand_array(); PublicInputs::new( &h, &individual_vd, &multiplier_vd, - &individual_md, - &multiplier_md, + &individual_cnt, + &multiplier_cnt, ) .to_vec() } @@ -324,12 +301,12 @@ pub(crate) mod tests { pi.to_multiplier_values_digest_raw(), ); assert_eq!( - &exp_pi[PublicInputs::::to_range(CellsTreePublicInputs::IndividualMetadataDigest)], - pi.to_individual_metadata_digest_raw(), + &exp_pi[PublicInputs::::to_range(CellsTreePublicInputs::IndividualCounter)], + slice::from_ref(pi.to_individual_counter_raw()), ); assert_eq!( - &exp_pi[PublicInputs::::to_range(CellsTreePublicInputs::MultiplierMetadataDigest)], - pi.to_multiplier_metadata_digest_raw(), + &exp_pi[PublicInputs::::to_range(CellsTreePublicInputs::MultiplierCounter)], + slice::from_ref(pi.to_multiplier_counter_raw()), ); } } diff --git a/verifiable-db/src/row_tree/api.rs b/verifiable-db/src/row_tree/api.rs index 6fd85ffb2..f9273699c 100644 --- a/verifiable-db/src/row_tree/api.rs +++ b/verifiable-db/src/row_tree/api.rs @@ -184,33 +184,19 @@ impl CircuitInput { pub fn leaf( identifier: u64, value: U256, - mpt_metadata: HashOut, row_unique_data: HashOut, cells_proof: Vec, ) -> Result { - Self::leaf_multiplier( - identifier, - value, - false, - mpt_metadata, - row_unique_data, - cells_proof, - ) + Self::leaf_multiplier(identifier, value, false, row_unique_data, cells_proof) } pub fn leaf_multiplier( identifier: u64, value: U256, is_multiplier: bool, - mpt_metadata: HashOut, row_unique_data: HashOut, cells_proof: Vec, ) -> Result { - let cell = Cell::new( - F::from_canonical_u64(identifier), - value, - is_multiplier, - mpt_metadata, - ); + let cell = Cell::new(F::from_canonical_u64(identifier), value, is_multiplier); let row = Row::new(cell, row_unique_data); Ok(CircuitInput::Leaf { witness: row.into(), @@ -221,7 +207,6 @@ impl CircuitInput { pub fn full( identifier: u64, value: U256, - mpt_metadata: HashOut, row_unique_data: HashOut, left_proof: Vec, right_proof: Vec, @@ -231,7 +216,6 @@ impl CircuitInput { identifier, value, false, - mpt_metadata, row_unique_data, left_proof, right_proof, @@ -242,18 +226,12 @@ impl CircuitInput { identifier: u64, value: U256, is_multiplier: bool, - mpt_metadata: HashOut, row_unique_data: HashOut, left_proof: Vec, right_proof: Vec, cells_proof: Vec, ) -> Result { - let cell = Cell::new( - F::from_canonical_u64(identifier), - value, - is_multiplier, - mpt_metadata, - ); + let cell = Cell::new(F::from_canonical_u64(identifier), value, is_multiplier); let row = Row::new(cell, row_unique_data); Ok(CircuitInput::Full { witness: row.into(), @@ -266,7 +244,6 @@ impl CircuitInput { identifier: u64, value: U256, is_child_left: bool, - mpt_metadata: HashOut, row_unique_data: HashOut, child_proof: Vec, cells_proof: Vec, @@ -276,7 +253,6 @@ impl CircuitInput { value, false, is_child_left, - mpt_metadata, row_unique_data, child_proof, cells_proof, @@ -287,17 +263,11 @@ impl CircuitInput { value: U256, is_multiplier: bool, is_child_left: bool, - mpt_metadata: HashOut, row_unique_data: HashOut, child_proof: Vec, cells_proof: Vec, ) -> Result { - let cell = Cell::new( - F::from_canonical_u64(identifier), - value, - is_multiplier, - mpt_metadata, - ); + let cell = Cell::new(F::from_canonical_u64(identifier), value, is_multiplier); let row = Row::new(cell, row_unique_data); let witness = PartialNodeCircuit::new(row, is_child_left); Ok(CircuitInput::Partial { @@ -376,22 +346,10 @@ mod test { params, cells_proof: cells_proof[0].clone(), cells_vk, - leaf1: Row::new( - Cell::new(identifier, v1, false, HashOut::rand()), - HashOut::rand(), - ), - leaf2: Row::new( - Cell::new(identifier, v2, false, HashOut::rand()), - HashOut::rand(), - ), - full: Row::new( - Cell::new(identifier, v_full, false, HashOut::rand()), - HashOut::rand(), - ), - partial: Row::new( - Cell::new(identifier, v_partial, false, HashOut::rand()), - HashOut::rand(), - ), + leaf1: Row::new(Cell::new(identifier, v1, false), HashOut::rand()), + leaf2: Row::new(Cell::new(identifier, v2, false), HashOut::rand()), + full: Row::new(Cell::new(identifier, v_full, false), HashOut::rand()), + partial: Row::new(Cell::new(identifier, v_partial, false), HashOut::rand()), }) } @@ -429,7 +387,6 @@ mod test { let row = &p.partial; let id = row.cell.identifier; let value = row.cell.value; - let mpt_metadata = row.cell.mpt_metadata; let row_unique_data = row.row_unique_data; let row_digest = row.digest(&p.cells_pi()); @@ -444,7 +401,6 @@ mod test { id.to_canonical_u64(), value, is_left, - mpt_metadata, row_unique_data, child_proof_buff.clone(), p.cells_proof_vk().serialize()?, @@ -491,12 +447,12 @@ mod test { pi.multiplier_digest_point(), row_digest.multiplier_vd.to_weierstrass() ); - // Check row ID multiplier - assert_eq!(pi.row_id_multiplier(), row_digest.row_id_multiplier); // Check minimum value assert_eq!(pi.min_value(), value.min(child_min)); // Check maximum value assert_eq!(pi.max_value(), value.max(child_max)); + // Check multiplier counter + assert_eq!(pi.multiplier_counter(), row_digest.multiplier_cnt); Ok(vec![]) } @@ -505,14 +461,12 @@ mod test { let row = &p.full; let id = row.cell.identifier; let value = row.cell.value; - let mpt_metadata = row.cell.mpt_metadata; let row_unique_data = row.row_unique_data; let row_digest = row.digest(&p.cells_pi()); let input = CircuitInput::full( id.to_canonical_u64(), value, - mpt_metadata, row_unique_data, child_proof[0].to_vec(), child_proof[1].to_vec(), @@ -557,8 +511,8 @@ mod test { pi.multiplier_digest_point(), row_digest.multiplier_vd.to_weierstrass() ); - // Check row ID multiplier - assert_eq!(pi.row_id_multiplier(), row_digest.row_id_multiplier); + // Check multiplier counter + assert_eq!(pi.multiplier_counter(), row_digest.multiplier_cnt); Ok(proof) } @@ -566,7 +520,6 @@ mod test { fn generate_leaf_proof(p: &TestParams, row: &Row) -> Result> { let id = row.cell.identifier; let value = row.cell.value; - let mpt_metadata = row.cell.mpt_metadata; let row_unique_data = row.row_unique_data; let row_digest = row.digest(&p.cells_pi()); @@ -574,7 +527,6 @@ mod test { let input = CircuitInput::leaf( id.to_canonical_u64(), value, - mpt_metadata, row_unique_data, p.cells_proof_vk().serialize()?, )?; @@ -615,12 +567,12 @@ mod test { pi.multiplier_digest_point(), row_digest.multiplier_vd.to_weierstrass() ); - // Check row ID multiplier - assert_eq!(pi.row_id_multiplier(), row_digest.row_id_multiplier); // Check minimum value assert_eq!(pi.min_value(), value); // Check maximum value assert_eq!(pi.max_value(), value); + // Check multiplier counter + assert_eq!(pi.multiplier_counter(), row_digest.multiplier_cnt); Ok(proof) } diff --git a/verifiable-db/src/row_tree/full_node.rs b/verifiable-db/src/row_tree/full_node.rs index 36c5a3760..633910b0d 100644 --- a/verifiable-db/src/row_tree/full_node.rs +++ b/verifiable-db/src/row_tree/full_node.rs @@ -9,7 +9,6 @@ use plonky2::{ iop::{target::Target, witness::PartialWitness}, plonk::{circuit_builder::CircuitBuilder, proof::ProofWithPublicInputsTarget}, }; -use plonky2_ecdsa::gadgets::biguint::CircuitBuilderBiguint; use recursion_framework::{ circuit_builder::CircuitLogicWires, framework::{ @@ -44,19 +43,13 @@ impl FullNodeCircuit { let value = row.value(); let digest = row.digest(b, &cells_pi); - // Check multiplier_vd and row_id_multiplier are the same as children proofs. + // Check multiplier_vd and multiplier_counter are the same as children proofs. // assert multiplier_vd == p1.multiplier_vd == p2.multiplier_vd b.connect_curve_points(digest.multiplier_vd, min_child.multiplier_digest_target()); b.connect_curve_points(digest.multiplier_vd, max_child.multiplier_digest_target()); - // assert row_id_multiplier == p1.row_id_multiplier == p2.row_id_multiplier - b.connect_biguint( - &digest.row_id_multiplier, - &min_child.row_id_multiplier_target(), - ); - b.connect_biguint( - &digest.row_id_multiplier, - &max_child.row_id_multiplier_target(), - ); + // assert multiplier_counter == p1.multiplier_counter == p2.multiplier_counter + b.connect(digest.multiplier_cnt, min_child.multiplier_counter_target()); + b.connect(digest.multiplier_cnt, max_child.multiplier_counter_target()); let node_min = min_child.min_value_target(); let node_max = max_child.max_value_target(); @@ -85,9 +78,9 @@ impl FullNodeCircuit { &hash.to_targets(), &digest.individual_vd.to_targets(), &digest.multiplier_vd.to_targets(), - &digest.row_id_multiplier.to_targets(), &node_min.to_targets(), &node_max.to_targets(), + &digest.multiplier_cnt, ) .register(b); FullNodeWires(row) @@ -154,7 +147,7 @@ pub(crate) mod test { use itertools::Itertools; use mp2_common::{utils::ToFields, C, D, F}; use mp2_test::circuit::{run_circuit, UserCircuit}; - use plonky2::{iop::witness::WitnessWrite, plonk::config::Hasher}; + use plonky2::{field::types::PrimeField64, iop::witness::WitnessWrite, plonk::config::Hasher}; #[derive(Clone, Debug)] struct TestFullNodeCircuit { @@ -199,17 +192,14 @@ pub(crate) mod test { let (left_min, left_max) = (10, 15); // this should work since we allow multipleicities of indexes in the row tree let (right_min, right_max) = (18, 30); - let left_pi = PublicInputs::sample( - row_digest.multiplier_vd, - row_digest.row_id_multiplier.clone(), - left_min, - left_max, - ); + let multiplier_cnt = row_digest.multiplier_cnt.to_canonical_u64(); + let left_pi = + PublicInputs::sample(row_digest.multiplier_vd, left_min, left_max, multiplier_cnt); let right_pi = PublicInputs::sample( row_digest.multiplier_vd, - row_digest.row_id_multiplier.clone(), right_min, right_max, + multiplier_cnt, ); let test_circuit = TestFullNodeCircuit { circuit: node_circuit, @@ -250,12 +240,12 @@ pub(crate) mod test { pi.multiplier_digest_point(), row_digest.multiplier_vd.to_weierstrass() ); - // Check row ID multiplier - assert_eq!(pi.row_id_multiplier(), row_digest.row_id_multiplier); // Check minimum value assert_eq!(pi.min_value(), U256::from(left_min)); // Check maximum value assert_eq!(pi.max_value(), U256::from(right_max)); + // Check multiplier counter + assert_eq!(pi.multiplier_counter(), row_digest.multiplier_cnt); } #[test] diff --git a/verifiable-db/src/row_tree/leaf.rs b/verifiable-db/src/row_tree/leaf.rs index 4738c537a..5ffa4ea87 100644 --- a/verifiable-db/src/row_tree/leaf.rs +++ b/verifiable-db/src/row_tree/leaf.rs @@ -60,9 +60,9 @@ impl LeafCircuit { &row_hash.elements, &digest.individual_vd.to_targets(), &digest.multiplier_vd.to_targets(), - &digest.row_id_multiplier.to_targets(), &value, &value, + &digest.multiplier_cnt, ) .register(b); @@ -206,7 +206,7 @@ mod test { row_digest.multiplier_vd.to_weierstrass() ); // Check row ID multiplier - assert_eq!(pi.row_id_multiplier(), row_digest.row_id_multiplier); + assert_eq!(pi.multiplier_counter(), row_digest.multiplier_cnt); // Check minimum value assert_eq!(pi.min_value(), value); // Check maximum value diff --git a/verifiable-db/src/row_tree/partial_node.rs b/verifiable-db/src/row_tree/partial_node.rs index 047bab775..ef631614a 100644 --- a/verifiable-db/src/row_tree/partial_node.rs +++ b/verifiable-db/src/row_tree/partial_node.rs @@ -20,7 +20,6 @@ use plonky2::{ }, plonk::{circuit_builder::CircuitBuilder, proof::ProofWithPublicInputsTarget}, }; -use plonky2_ecdsa::gadgets::biguint::CircuitBuilderBiguint; use recursion_framework::{ circuit_builder::CircuitLogicWires, framework::{ @@ -64,14 +63,11 @@ impl PartialNodeCircuit { let value = row.value(); let digest = row.digest(b, &cells_pi); - // Check multiplier_vd and row_id_multiplier are the same as child proof - // assert multiplier_vd == child_proof.multiplier_vd + // Check multiplier_vd and multiplier_counter are the same as children proof. + // assert multiplier_vd == p.multiplier_vd b.connect_curve_points(digest.multiplier_vd, child_pi.multiplier_digest_target()); - //assert row_id_multiplier == child_proof.row_id_multiplier - b.connect_biguint( - &digest.row_id_multiplier, - &child_pi.row_id_multiplier_target(), - ); + // assert multiplier_counter == p.multiplier_counter + b.connect(digest.multiplier_cnt, child_pi.multiplier_counter_target()); // bool target range checked in poseidon gate let is_child_at_left = b.add_virtual_bool_target_unsafe(); @@ -116,9 +112,9 @@ impl PartialNodeCircuit { &node_hash, &digest.individual_vd.to_targets(), &digest.multiplier_vd.to_targets(), - &digest.row_id_multiplier.to_targets(), &node_min.to_targets(), &node_max.to_targets(), + &digest.multiplier_cnt, ) .register(b); PartialNodeWires { @@ -193,7 +189,7 @@ pub mod test { C, D, F, }; use mp2_test::circuit::{run_circuit, UserCircuit}; - use plonky2::plonk::config::Hasher; + use plonky2::{field::types::PrimeField64, plonk::config::Hasher}; use std::iter::once; #[derive(Clone, Debug)] @@ -292,9 +288,9 @@ pub mod test { let node_circuit = PartialNodeCircuit::new(row.clone(), child_at_left); let child_pi = PublicInputs::sample( row_digest.multiplier_vd, - row_digest.row_id_multiplier.clone(), child_min.to(), child_max.to(), + row_digest.multiplier_cnt.to_canonical_u64(), ); let test_circuit = TestPartialNodeCircuit { circuit: node_circuit, @@ -343,11 +339,11 @@ pub mod test { pi.multiplier_digest_point(), row_digest.multiplier_vd.to_weierstrass() ); - // Check row ID multiplier - assert_eq!(pi.row_id_multiplier(), row_digest.row_id_multiplier); // Check minimum value assert_eq!(pi.min_value(), value.min(child_min)); // Check maximum value assert_eq!(pi.max_value(), value.max(child_max)); + // Check multiplier counter + assert_eq!(pi.multiplier_counter(), row_digest.multiplier_cnt); } } diff --git a/verifiable-db/src/row_tree/public_inputs.rs b/verifiable-db/src/row_tree/public_inputs.rs index 3415ca0bd..192f2fbf2 100644 --- a/verifiable-db/src/row_tree/public_inputs.rs +++ b/verifiable-db/src/row_tree/public_inputs.rs @@ -1,24 +1,19 @@ //! Public inputs for rows trees creation circuits use alloy::primitives::U256; -use itertools::Itertools; use mp2_common::{ - poseidon::HASH_TO_INT_LEN, public_inputs::{PublicInputCommon, PublicInputRange}, types::{CBuilder, CURVE_TARGET_LEN}, u256::{self, UInt256Target}, utils::{FromFields, FromTargets}, F, }; -use num::BigUint; use plonky2::{ - field::types::PrimeField64, hash::hash_types::{HashOut, NUM_HASH_OUT_ELTS}, iop::target::Target, }; -use plonky2_crypto::u32::arithmetic_u32::U32Target; -use plonky2_ecdsa::gadgets::biguint::BigUintTarget; use plonky2_ecgfp5::{curve::curve::WeierstrassPoint, gadgets::curve::CurveTarget}; +use std::iter::once; pub enum RowsTreePublicInputs { // `H : F[4]` - Poseidon hash of the leaf @@ -27,12 +22,12 @@ pub enum RowsTreePublicInputs { IndividualDigest, // `multiplier_digest : Digest` - Cumulative digest of the values of the cells which are accumulated in multiplier digest MultiplierDigest, - // `row_id_multiplier : F[4]` - `H2Int(H("") || multiplier_md)`, where `multiplier_md` is the metadata digest of cells accumulated in `multiplier_digest` - RowIdMultiplier, // `min : Uint256` - Minimum alue of the secondary index stored up to this node MinValue, // `max : Uint256` - Maximum value of the secondary index stored up to this node MaxValue, + // `multiplier_counter : F` - Number of cells accumulated as multiplier + MultiplierCounter, } /// Public inputs for Rows Tree Construction @@ -41,21 +36,21 @@ pub struct PublicInputs<'a, T> { pub(crate) h: &'a [T], pub(crate) individual_digest: &'a [T], pub(crate) multiplier_digest: &'a [T], - pub(crate) row_id_multiplier: &'a [T], pub(crate) min: &'a [T], pub(crate) max: &'a [T], + pub(crate) multiplier_cnt: &'a T, } -const NUM_PUBLIC_INPUTS: usize = RowsTreePublicInputs::MaxValue as usize + 1; +const NUM_PUBLIC_INPUTS: usize = RowsTreePublicInputs::MultiplierCounter as usize + 1; impl<'a, T: Clone> PublicInputs<'a, T> { const PI_RANGES: [PublicInputRange; NUM_PUBLIC_INPUTS] = [ Self::to_range(RowsTreePublicInputs::RootHash), Self::to_range(RowsTreePublicInputs::IndividualDigest), Self::to_range(RowsTreePublicInputs::MultiplierDigest), - Self::to_range(RowsTreePublicInputs::RowIdMultiplier), Self::to_range(RowsTreePublicInputs::MinValue), Self::to_range(RowsTreePublicInputs::MaxValue), + Self::to_range(RowsTreePublicInputs::MultiplierCounter), ]; const SIZES: [usize; NUM_PUBLIC_INPUTS] = [ @@ -65,12 +60,12 @@ impl<'a, T: Clone> PublicInputs<'a, T> { CURVE_TARGET_LEN, // Cumulative digest of the values of the cells which are accumulated in multiplier digest CURVE_TARGET_LEN, - // `H2Int(H("") || multiplier_md)`, where `multiplier_md` is the metadata digest of cells accumulated in `multiplier_digest` - HASH_TO_INT_LEN, // Minimum value of the secondary index stored up to this node u256::NUM_LIMBS, // Maximum value of the secondary index stored up to this node u256::NUM_LIMBS, + // Counter of the number of cells accumulated so far as multiplier + 1, ]; pub(crate) const fn to_range(pi: RowsTreePublicInputs) -> PublicInputRange { @@ -85,7 +80,7 @@ impl<'a, T: Clone> PublicInputs<'a, T> { } pub const fn total_len() -> usize { - Self::to_range(RowsTreePublicInputs::MaxValue).end + Self::to_range(RowsTreePublicInputs::MultiplierCounter).end } pub fn to_root_hash_raw(&self) -> &[T] { @@ -100,10 +95,6 @@ impl<'a, T: Clone> PublicInputs<'a, T> { self.multiplier_digest } - pub fn to_row_id_multiplier_raw(&self) -> &[T] { - self.row_id_multiplier - } - pub fn to_min_value_raw(&self) -> &[T] { self.min } @@ -112,6 +103,10 @@ impl<'a, T: Clone> PublicInputs<'a, T> { self.max } + pub fn to_multiplier_counter_raw(&self) -> &T { + self.multiplier_cnt + } + pub fn from_slice(input: &'a [T]) -> Self { assert!( input.len() >= Self::total_len(), @@ -123,9 +118,9 @@ impl<'a, T: Clone> PublicInputs<'a, T> { h: &input[Self::PI_RANGES[0].clone()], individual_digest: &input[Self::PI_RANGES[1].clone()], multiplier_digest: &input[Self::PI_RANGES[2].clone()], - row_id_multiplier: &input[Self::PI_RANGES[3].clone()], - min: &input[Self::PI_RANGES[4].clone()], - max: &input[Self::PI_RANGES[5].clone()], + min: &input[Self::PI_RANGES[3].clone()], + max: &input[Self::PI_RANGES[4].clone()], + multiplier_cnt: &input[Self::PI_RANGES[5].clone()][0], } } @@ -133,17 +128,17 @@ impl<'a, T: Clone> PublicInputs<'a, T> { h: &'a [T], individual_digest: &'a [T], multiplier_digest: &'a [T], - row_id_multiplier: &'a [T], min: &'a [T], max: &'a [T], + multiplier_cnt: &'a T, ) -> Self { Self { h, individual_digest, multiplier_digest, - row_id_multiplier, min, max, + multiplier_cnt, } } @@ -152,9 +147,9 @@ impl<'a, T: Clone> PublicInputs<'a, T> { .iter() .chain(self.individual_digest) .chain(self.multiplier_digest) - .chain(self.row_id_multiplier) .chain(self.min) .chain(self.max) + .chain(once(self.multiplier_cnt)) .cloned() .collect() } @@ -167,9 +162,9 @@ impl<'a> PublicInputCommon for PublicInputs<'a, Target> { cb.register_public_inputs(self.h); cb.register_public_inputs(self.individual_digest); cb.register_public_inputs(self.multiplier_digest); - cb.register_public_inputs(self.row_id_multiplier); cb.register_public_inputs(self.min); cb.register_public_inputs(self.max); + cb.register_public_input(*self.multiplier_cnt); } } @@ -186,17 +181,6 @@ impl<'a> PublicInputs<'a, Target> { CurveTarget::from_targets(self.multiplier_digest) } - pub fn row_id_multiplier_target(&self) -> BigUintTarget { - let limbs = self - .row_id_multiplier - .iter() - .cloned() - .map(U32Target) - .collect(); - - BigUintTarget { limbs } - } - pub fn min_value_target(&self) -> UInt256Target { UInt256Target::from_targets(self.min) } @@ -204,6 +188,10 @@ impl<'a> PublicInputs<'a, Target> { pub fn max_value_target(&self) -> UInt256Target { UInt256Target::from_targets(self.max) } + + pub fn multiplier_counter_target(&self) -> Target { + *self.to_multiplier_counter_raw() + } } impl<'a> PublicInputs<'a, F> { @@ -219,16 +207,6 @@ impl<'a> PublicInputs<'a, F> { WeierstrassPoint::from_fields(self.multiplier_digest) } - pub fn row_id_multiplier(&self) -> BigUint { - let limbs = self - .row_id_multiplier - .iter() - .map(|f| u32::try_from(f.to_canonical_u64()).unwrap()) - .collect_vec(); - - BigUint::from_slice(&limbs) - } - pub fn min_value(&self) -> U256 { U256::from_fields(self.min) } @@ -236,16 +214,17 @@ impl<'a> PublicInputs<'a, F> { pub fn max_value(&self) -> U256 { U256::from_fields(self.max) } + + pub fn multiplier_counter(&self) -> F { + *self.to_multiplier_counter_raw() + } } #[cfg(test)] pub(crate) mod tests { use super::*; use mp2_common::{utils::ToFields, C, D, F}; - use mp2_test::{ - circuit::{run_circuit, UserCircuit}, - utils::random_vector, - }; + use mp2_test::circuit::{run_circuit, UserCircuit}; use plonky2::{ field::types::{Field, Sample}, iop::{ @@ -255,32 +234,28 @@ pub(crate) mod tests { }; use plonky2_ecgfp5::curve::curve::Point; use rand::{thread_rng, Rng}; - use std::array; + use std::{array, slice}; impl<'a> PublicInputs<'a, F> { pub(crate) fn sample( multiplier_digest: Point, - row_id_multiplier: BigUint, min: usize, max: usize, + multiplier_cnt: u64, ) -> Vec { let h = HashOut::rand().to_fields(); let individual_digest = Point::rand(); let [individual_digest, multiplier_digest] = [individual_digest, multiplier_digest].map(|p| p.to_weierstrass().to_fields()); - let row_id_multiplier = row_id_multiplier - .to_u32_digits() - .into_iter() - .map(F::from_canonical_u32) - .collect_vec(); let [min, max] = [min, max].map(|v| U256::from(v).to_fields()); + let multiplier_cnt = F::from_canonical_u64(multiplier_cnt); PublicInputs::new( &h, &individual_digest, &multiplier_digest, - &row_id_multiplier, &min, &max, + &multiplier_cnt, ) .to_vec() } @@ -312,9 +287,9 @@ pub(crate) mod tests { // Prepare the public inputs. let multiplier_digest = Point::sample(rng); - let row_id_multiplier = BigUint::from_slice(&random_vector::(HASH_TO_INT_LEN)); let [min, max] = array::from_fn(|_| rng.gen()); - let exp_pi = PublicInputs::sample(multiplier_digest, row_id_multiplier, min, max); + let multiplier_cnt = rng.gen(); + let exp_pi = PublicInputs::sample(multiplier_digest, min, max, multiplier_cnt); let exp_pi = &exp_pi.to_vec(); let test_circuit = TestPublicInputs { exp_pi }; @@ -335,10 +310,6 @@ pub(crate) mod tests { &exp_pi[PublicInputs::::to_range(RowsTreePublicInputs::MultiplierDigest)], pi.to_multiplier_digest_raw(), ); - assert_eq!( - &exp_pi[PublicInputs::::to_range(RowsTreePublicInputs::RowIdMultiplier)], - pi.to_row_id_multiplier_raw(), - ); assert_eq!( &exp_pi[PublicInputs::::to_range(RowsTreePublicInputs::MinValue)], pi.to_min_value_raw(), @@ -347,5 +318,9 @@ pub(crate) mod tests { &exp_pi[PublicInputs::::to_range(RowsTreePublicInputs::MaxValue)], pi.to_max_value_raw(), ); + assert_eq!( + &exp_pi[PublicInputs::::to_range(RowsTreePublicInputs::MultiplierCounter)], + slice::from_ref(pi.to_multiplier_counter_raw()), + ); } } diff --git a/verifiable-db/src/row_tree/row.rs b/verifiable-db/src/row_tree/row.rs index ca57a86c4..0926d0a5a 100644 --- a/verifiable-db/src/row_tree/row.rs +++ b/verifiable-db/src/row_tree/row.rs @@ -4,16 +4,15 @@ use crate::cells_tree::{Cell, CellWire, PublicInputs as CellsPublicInputs}; use derive_more::Constructor; use itertools::Itertools; use mp2_common::{ - poseidon::{empty_poseidon_hash, hash_to_int_target, hash_to_int_value, H, HASH_TO_INT_LEN}, + poseidon::{hash_to_int_target, hash_to_int_value, H}, serialization::{deserialize, serialize}, types::{CBuilder, CURVE_TARGET_LEN}, u256::UInt256Target, utils::{FromFields, ToFields, ToTargets}, F, }; -use num::BigUint; use plonky2::{ - field::types::{Field, PrimeField64}, + field::types::Field, hash::hash_types::{HashOut, HashOutTarget}, iop::{ target::Target, @@ -21,16 +20,17 @@ use plonky2::{ }, plonk::config::Hasher, }; -use plonky2_ecdsa::gadgets::{biguint::BigUintTarget, nonnative::CircuitBuilderNonNative}; +use plonky2_ecdsa::gadgets::nonnative::CircuitBuilderNonNative; use plonky2_ecgfp5::{ curve::{curve::Point, scalar_field::Scalar}, gadgets::curve::{CircuitBuilderEcGFp5, CurveTarget}, }; use serde::{Deserialize, Serialize}; +use std::iter::once; #[derive(Clone, Debug, Eq, PartialEq)] pub(crate) struct RowDigest { - pub(crate) row_id_multiplier: BigUint, + pub(crate) multiplier_cnt: F, pub(crate) individual_vd: Point, pub(crate) multiplier_vd: Point, } @@ -39,13 +39,8 @@ impl FromFields for RowDigest { fn from_fields(t: &[F]) -> Self { let mut pos = 0; - let row_id_multiplier = BigUint::new( - t[pos..pos + HASH_TO_INT_LEN] - .iter() - .map(|f| u32::try_from(f.to_canonical_u64()).unwrap()) - .collect_vec(), - ); - pos += HASH_TO_INT_LEN; + let multiplier_cnt = t[0]; + pos += 1; let individual_vd = Point::from_fields(&t[pos..pos + CURVE_TARGET_LEN]); pos += CURVE_TARGET_LEN; @@ -53,7 +48,7 @@ impl FromFields for RowDigest { let multiplier_vd = Point::from_fields(&t[pos..pos + CURVE_TARGET_LEN]); Self { - row_id_multiplier, + multiplier_cnt, individual_vd, multiplier_vd, } @@ -62,7 +57,7 @@ impl FromFields for RowDigest { #[derive(Clone, Debug)] pub(crate) struct RowDigestTarget { - pub(crate) row_id_multiplier: BigUintTarget, + pub(crate) multiplier_cnt: Target, pub(crate) individual_vd: CurveTarget, pub(crate) multiplier_vd: CurveTarget, } @@ -79,20 +74,42 @@ impl Row { pw.set_hash_target(wires.row_unique_data, self.row_unique_data); } - pub(crate) fn digest(&self, cells_pi: &CellsPublicInputs) -> RowDigest { - let metadata_digests = self.cell.split_metadata_digest(); - let values_digests = self.cell.split_values_digest(); + pub fn is_individual(&self) -> bool { + self.cell.is_individual() + } + + pub fn is_multiplier(&self) -> bool { + self.cell.is_multiplier() + } - let metadata_digests = metadata_digests.accumulate(&cells_pi.split_metadata_digest_point()); - let values_digests = values_digests.accumulate(&cells_pi.split_values_digest_point()); + pub(crate) fn digest(&self, cells_pi: &CellsPublicInputs) -> RowDigest { + let values_digests = self + .cell + .split_and_accumulate_values_digest(cells_pi.split_values_digest_point()); + + // individual_counter = p.individual_counter + is_individual + let individual_cnt = cells_pi.individual_counter() + + if self.cell.is_individual() { + F::ONE + } else { + F::ZERO + }; + + // multiplier_counter = p.multiplier_counter + not is_individual + let multiplier_cnt = cells_pi.multiplier_counter() + + if self.cell.is_multiplier() { + F::ONE + } else { + F::ZERO + }; // Compute row ID for individual cells: - // row_id_individual = H2Int(row_unique_data || individual_md) + // row_id_individual = H2Int(row_unique_data || individual_counter) let inputs = self .row_unique_data .to_fields() .into_iter() - .chain(metadata_digests.individual.to_fields()) + .chain(once(individual_cnt)) .collect_vec(); let hash = H::hash_no_pad(&inputs); let row_id_individual = hash_to_int_value(hash); @@ -102,22 +119,10 @@ impl Row { // individual_vd = row_id_individual * individual_vd let individual_vd = values_digests.individual * row_id_individual; - // Multiplier is always employed for set of scalar variables, and `row_unique_data` - // for such a set is always `H("")``, so we can hardocode it in the circuit: - // row_id_multiplier = H2Int(H("") || multiplier_md) - let empty_hash = empty_poseidon_hash(); - let inputs = empty_hash - .to_fields() - .into_iter() - .chain(metadata_digests.multiplier.to_fields()) - .collect_vec(); - let hash = H::hash_no_pad(&inputs); - let row_id_multiplier = hash_to_int_value(hash); - let multiplier_vd = values_digests.multiplier; RowDigest { - row_id_multiplier, + multiplier_cnt, individual_vd, multiplier_vd, } @@ -152,20 +157,25 @@ impl RowWire { b: &mut CBuilder, cells_pi: &CellsPublicInputs, ) -> RowDigestTarget { - let metadata_digests = self.cell.split_metadata_digest(b); - let values_digests = self.cell.split_values_digest(b); + let values_digests = self + .cell + .split_and_accumulate_values_digest(b, &cells_pi.split_values_digest_target()); - let metadata_digests = - metadata_digests.accumulate(b, &cells_pi.split_metadata_digest_target()); - let values_digests = values_digests.accumulate(b, &cells_pi.split_values_digest_target()); + // individual_counter = p.individual_counter + is_individual + let is_individual = self.cell.is_individual(b); + let individual_cnt = b.add(cells_pi.individual_counter_target(), is_individual.target); + + // multiplier_counter = p.multiplier_counter + not is_individual + let is_multiplier = self.cell.is_multiplier(); + let multiplier_cnt = b.add(cells_pi.multiplier_counter_target(), is_multiplier.target); // Compute row ID for individual cells: - // row_id_individual = H2Int(row_unique_data || individual_md) + // row_id_individual = H2Int(row_unique_data || individual_counter) let inputs = self .row_unique_data .to_targets() .into_iter() - .chain(metadata_digests.individual.to_targets()) + .chain(once(individual_cnt)) .collect(); let hash = b.hash_n_to_hash_no_pad::(inputs); let row_id_individual = hash_to_int_target(b, hash); @@ -175,23 +185,10 @@ impl RowWire { // individual_vd = row_id_individual * individual_vd let individual_vd = b.curve_scalar_mul(values_digests.individual, &row_id_individual); - // Multiplier is always employed for set of scalar variables, and `row_unique_data` - // for such a set is always `H("")``, so we can hardocode it in the circuit: - // row_id_multiplier = H2Int(H("") || multiplier_md) - let empty_hash = b.constant_hash(*empty_poseidon_hash()); - let inputs = empty_hash - .to_targets() - .into_iter() - .chain(metadata_digests.multiplier.to_targets()) - .collect(); - let hash = b.hash_n_to_hash_no_pad::(inputs); - let row_id_multiplier = hash_to_int_target(b, hash); - assert_eq!(row_id_multiplier.num_limbs(), HASH_TO_INT_LEN); - let multiplier_vd = values_digests.multiplier; RowDigestTarget { - row_id_multiplier, + multiplier_cnt, individual_vd, multiplier_vd, } @@ -232,7 +229,7 @@ pub(crate) mod tests { let digest = row.digest(b, &cells_pi); - b.register_public_inputs(&digest.row_id_multiplier.to_targets()); + b.register_public_inputs(&digest.multiplier_cnt.to_targets()); b.register_public_inputs(&digest.individual_vd.to_targets()); b.register_public_inputs(&digest.multiplier_vd.to_targets()); From aa84a07f9b0b2a750c3e9b74ead6c5aa2ab5634b Mon Sep 17 00:00:00 2001 From: Steven Gu Date: Thu, 31 Oct 2024 20:27:59 +0800 Subject: [PATCH 2/5] Remove `child_metadata_digest` in the cells tree test. --- verifiable-db/src/cells_tree/mod.rs | 27 ++++----------------------- 1 file changed, 4 insertions(+), 23 deletions(-) diff --git a/verifiable-db/src/cells_tree/mod.rs b/verifiable-db/src/cells_tree/mod.rs index 3ac65fc1b..867d9f018 100644 --- a/verifiable-db/src/cells_tree/mod.rs +++ b/verifiable-db/src/cells_tree/mod.rs @@ -150,25 +150,20 @@ pub(crate) mod tests { struct TestCellCircuit<'a> { cell: &'a Cell, child_values_digest: &'a SplitDigestPoint, - child_metadata_digest: &'a SplitDigestPoint, } impl<'a> UserCircuit for TestCellCircuit<'a> { // Cell wire + child values digest + child metadata digest - type Wires = (CellWire, SplitDigestTarget, SplitDigestTarget); + type Wires = (CellWire, SplitDigestTarget); fn build(b: &mut CBuilder) -> Self::Wires { - let [values_individual, values_multiplier, metadata_individual, metadata_multiplier] = + let [values_individual, values_multiplier] = array::from_fn(|_| b.add_virtual_curve_target()); let child_values_digest = SplitDigestTarget { individual: values_individual, multiplier: values_multiplier, }; - let child_metadata_digest = SplitDigestTarget { - individual: metadata_individual, - multiplier: metadata_multiplier, - }; let cell = CellWire::new(b); let values_digest = cell.split_and_accumulate_values_digest(b, &child_values_digest); @@ -176,7 +171,7 @@ pub(crate) mod tests { b.register_curve_public_input(values_digest.individual); b.register_curve_public_input(values_digest.multiplier); - (cell, child_values_digest, child_metadata_digest) + (cell, child_values_digest) } fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { @@ -189,14 +184,6 @@ pub(crate) mod tests { wires.1.multiplier, self.child_values_digest.multiplier.to_weierstrass(), ); - pw.set_curve_target( - wires.2.individual, - self.child_metadata_digest.individual.to_weierstrass(), - ); - pw.set_curve_target( - wires.2.multiplier, - self.child_metadata_digest.multiplier.to_weierstrass(), - ); } } @@ -204,16 +191,11 @@ pub(crate) mod tests { fn test_cells_tree_cell_circuit() { let rng = &mut thread_rng(); - let [values_individual, values_multiplier, metadata_individual, metadata_multiplier] = - array::from_fn(|_| Point::sample(rng)); + let [values_individual, values_multiplier] = array::from_fn(|_| Point::sample(rng)); let child_values_digest = &SplitDigestPoint { individual: values_individual, multiplier: values_multiplier, }; - let child_metadata_digest = &SplitDigestPoint { - individual: metadata_individual, - multiplier: metadata_multiplier, - }; let cell = &Cell::sample(rng.gen()); let values_digests = cell.split_values_digest(); @@ -222,7 +204,6 @@ pub(crate) mod tests { let test_circuit = TestCellCircuit { cell, child_values_digest, - child_metadata_digest, }; let proof = run_circuit::(test_circuit); From 3a686f95d6c1f6aceb2ac0a8ef87cd1905041725 Mon Sep 17 00:00:00 2001 From: Steven Gu Date: Thu, 31 Oct 2024 20:37:49 +0800 Subject: [PATCH 3/5] Fix to use `F::from_bool`. --- verifiable-db/src/cells_tree/api.rs | 24 ++++++++++---------- verifiable-db/src/cells_tree/full_node.rs | 12 +++++----- verifiable-db/src/cells_tree/leaf.rs | 6 ++--- verifiable-db/src/cells_tree/partial_node.rs | 6 ++--- verifiable-db/src/row_tree/row.rs | 16 ++++--------- 5 files changed, 28 insertions(+), 36 deletions(-) diff --git a/verifiable-db/src/cells_tree/api.rs b/verifiable-db/src/cells_tree/api.rs index 3eb707a43..63626601c 100644 --- a/verifiable-db/src/cells_tree/api.rs +++ b/verifiable-db/src/cells_tree/api.rs @@ -326,10 +326,10 @@ mod tests { values_digests.multiplier.to_weierstrass(), ); // Check individual counter - let individual_cnt = if is_multiplier { F::ZERO } else { F::ONE }; - assert_eq!(pi.individual_counter(), individual_cnt); + let multiplier_cnt = F::from_bool(is_multiplier); + assert_eq!(pi.individual_counter(), F::ONE - multiplier_cnt); // Check multiplier counter - assert_eq!(pi.multiplier_counter(), F::ONE - individual_cnt); + assert_eq!(pi.multiplier_counter(), multiplier_cnt); proof } @@ -425,18 +425,18 @@ mod tests { values_digests.multiplier.to_weierstrass(), ); // Check individual counter - let individual_cnt = if is_multiplier { F::ZERO } else { F::ONE }; + let multiplier_cnt = F::from_bool(is_multiplier); assert_eq!( pi.individual_counter(), - child_pis - .iter() - .fold(individual_cnt, |acc, pi| acc + pi.individual_counter()), + child_pis.iter().fold(F::ONE - multiplier_cnt, |acc, pi| acc + + pi.individual_counter()), ); // Check multiplier counter assert_eq!( pi.multiplier_counter(), - child_pis.iter().fold(F::ONE - individual_cnt, |acc, pi| acc - + pi.multiplier_counter()), + child_pis + .iter() + .fold(multiplier_cnt, |acc, pi| acc + pi.multiplier_counter()), ); proof @@ -497,15 +497,15 @@ mod tests { values_digests.multiplier.to_weierstrass(), ); // Check individual counter - let individual_cnt = if is_multiplier { F::ZERO } else { F::ONE }; + let multiplier_cnt = F::from_bool(is_multiplier); assert_eq!( pi.individual_counter(), - individual_cnt + child_pi.individual_counter(), + F::ONE - multiplier_cnt + child_pi.individual_counter(), ); // Check multiplier counter assert_eq!( pi.multiplier_counter(), - F::ONE - individual_cnt + child_pi.multiplier_counter(), + multiplier_cnt + child_pi.multiplier_counter(), ); proof diff --git a/verifiable-db/src/cells_tree/full_node.rs b/verifiable-db/src/cells_tree/full_node.rs index 79a87df95..ccfac89ef 100644 --- a/verifiable-db/src/cells_tree/full_node.rs +++ b/verifiable-db/src/cells_tree/full_node.rs @@ -218,18 +218,18 @@ mod tests { values_digests.multiplier.to_weierstrass(), ); // Check individual counter - let individual_cnt = if is_multiplier { F::ZERO } else { F::ONE }; + let multiplier_cnt = F::from_bool(is_multiplier); assert_eq!( pi.individual_counter(), - child_pis - .iter() - .fold(individual_cnt, |acc, pi| acc + pi.individual_counter()), + child_pis.iter().fold(F::ONE - multiplier_cnt, |acc, pi| acc + + pi.individual_counter()), ); // Check multiplier counter assert_eq!( pi.multiplier_counter(), - child_pis.iter().fold(F::ONE - individual_cnt, |acc, pi| acc - + pi.multiplier_counter()), + child_pis + .iter() + .fold(multiplier_cnt, |acc, pi| acc + pi.multiplier_counter()), ); } } diff --git a/verifiable-db/src/cells_tree/leaf.rs b/verifiable-db/src/cells_tree/leaf.rs index 564f22a17..5d97f69c8 100644 --- a/verifiable-db/src/cells_tree/leaf.rs +++ b/verifiable-db/src/cells_tree/leaf.rs @@ -144,9 +144,9 @@ mod tests { values_digests.multiplier.to_weierstrass(), ); // Check individual counter - let individual_cnt = if is_multiplier { F::ZERO } else { F::ONE }; - assert_eq!(pi.individual_counter(), individual_cnt); + let multiplier_cnt = F::from_bool(is_multiplier); + assert_eq!(pi.individual_counter(), F::ONE - multiplier_cnt); // Check multiplier counter - assert_eq!(pi.multiplier_counter(), F::ONE - individual_cnt); + assert_eq!(pi.multiplier_counter(), multiplier_cnt); } } diff --git a/verifiable-db/src/cells_tree/partial_node.rs b/verifiable-db/src/cells_tree/partial_node.rs index d8c081b75..ca1c236a1 100644 --- a/verifiable-db/src/cells_tree/partial_node.rs +++ b/verifiable-db/src/cells_tree/partial_node.rs @@ -181,15 +181,15 @@ mod tests { values_digests.multiplier.to_weierstrass(), ); // Check individual counter - let individual_cnt = if is_multiplier { F::ZERO } else { F::ONE }; + let multiplier_cnt = F::from_bool(is_multiplier); assert_eq!( pi.individual_counter(), - individual_cnt + child_pi.individual_counter(), + F::ONE - multiplier_cnt + child_pi.individual_counter(), ); // Check multiplier counter assert_eq!( pi.multiplier_counter(), - F::ONE - individual_cnt + child_pi.multiplier_counter(), + multiplier_cnt + child_pi.multiplier_counter(), ); } } diff --git a/verifiable-db/src/row_tree/row.rs b/verifiable-db/src/row_tree/row.rs index 0926d0a5a..c07f9f556 100644 --- a/verifiable-db/src/row_tree/row.rs +++ b/verifiable-db/src/row_tree/row.rs @@ -88,20 +88,12 @@ impl Row { .split_and_accumulate_values_digest(cells_pi.split_values_digest_point()); // individual_counter = p.individual_counter + is_individual - let individual_cnt = cells_pi.individual_counter() - + if self.cell.is_individual() { - F::ONE - } else { - F::ZERO - }; + let individual_cnt = + cells_pi.individual_counter() + F::from_bool(self.cell.is_individual()); // multiplier_counter = p.multiplier_counter + not is_individual - let multiplier_cnt = cells_pi.multiplier_counter() - + if self.cell.is_multiplier() { - F::ONE - } else { - F::ZERO - }; + let multiplier_cnt = + cells_pi.multiplier_counter() + F::from_bool(self.cell.is_multiplier()); // Compute row ID for individual cells: // row_id_individual = H2Int(row_unique_data || individual_counter) From d5c9e24f00020105ae39154c669b325c6cd2b7b1 Mon Sep 17 00:00:00 2001 From: Steven Gu Date: Thu, 31 Oct 2024 21:05:36 +0800 Subject: [PATCH 4/5] Replace `HashOut` with `HashOutput` in the API functions. --- mp2-common/src/types.rs | 12 ++++++++++++ verifiable-db/src/row_tree/api.rs | 28 ++++++++++++++-------------- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/mp2-common/src/types.rs b/mp2-common/src/types.rs index 2b9d056a0..1855bd8d4 100644 --- a/mp2-common/src/types.rs +++ b/mp2-common/src/types.rs @@ -111,3 +111,15 @@ impl From> for HashOutput { value.to_bytes().try_into().unwrap() } } + +impl From for HashOut { + fn from(value: HashOutput) -> Self { + Self::from_bytes(&value.0) + } +} + +impl From<&HashOutput> for HashOut { + fn from(value: &HashOutput) -> Self { + Self::from_bytes(&value.0) + } +} diff --git a/verifiable-db/src/row_tree/api.rs b/verifiable-db/src/row_tree/api.rs index f9273699c..f6b59fc66 100644 --- a/verifiable-db/src/row_tree/api.rs +++ b/verifiable-db/src/row_tree/api.rs @@ -1,7 +1,7 @@ use alloy::primitives::U256; use anyhow::Result; -use mp2_common::{default_config, proof::ProofWithVK, C, D, F}; -use plonky2::{field::types::Field, hash::hash_types::HashOut}; +use mp2_common::{default_config, proof::ProofWithVK, types::HashOutput, C, D, F}; +use plonky2::{field::types::Field, hash::hash_types::HashOut, plonk::config::GenericHashOut}; use recursion_framework::{ circuit_builder::{CircuitWithUniversalVerifier, CircuitWithUniversalVerifierBuilder}, framework::{prepare_recursive_circuit_for_circuit_set as p, RecursiveCircuits}, @@ -184,7 +184,7 @@ impl CircuitInput { pub fn leaf( identifier: u64, value: U256, - row_unique_data: HashOut, + row_unique_data: HashOutput, cells_proof: Vec, ) -> Result { Self::leaf_multiplier(identifier, value, false, row_unique_data, cells_proof) @@ -193,11 +193,11 @@ impl CircuitInput { identifier: u64, value: U256, is_multiplier: bool, - row_unique_data: HashOut, + row_unique_data: HashOutput, cells_proof: Vec, ) -> Result { let cell = Cell::new(F::from_canonical_u64(identifier), value, is_multiplier); - let row = Row::new(cell, row_unique_data); + let row = Row::new(cell, row_unique_data.into()); Ok(CircuitInput::Leaf { witness: row.into(), cells_proof, @@ -207,7 +207,7 @@ impl CircuitInput { pub fn full( identifier: u64, value: U256, - row_unique_data: HashOut, + row_unique_data: HashOutput, left_proof: Vec, right_proof: Vec, cells_proof: Vec, @@ -226,13 +226,13 @@ impl CircuitInput { identifier: u64, value: U256, is_multiplier: bool, - row_unique_data: HashOut, + row_unique_data: HashOutput, left_proof: Vec, right_proof: Vec, cells_proof: Vec, ) -> Result { let cell = Cell::new(F::from_canonical_u64(identifier), value, is_multiplier); - let row = Row::new(cell, row_unique_data); + let row = Row::new(cell, row_unique_data.into()); Ok(CircuitInput::Full { witness: row.into(), left_proof, @@ -244,7 +244,7 @@ impl CircuitInput { identifier: u64, value: U256, is_child_left: bool, - row_unique_data: HashOut, + row_unique_data: HashOutput, child_proof: Vec, cells_proof: Vec, ) -> Result { @@ -263,12 +263,12 @@ impl CircuitInput { value: U256, is_multiplier: bool, is_child_left: bool, - row_unique_data: HashOut, + row_unique_data: HashOutput, child_proof: Vec, cells_proof: Vec, ) -> Result { let cell = Cell::new(F::from_canonical_u64(identifier), value, is_multiplier); - let row = Row::new(cell, row_unique_data); + let row = Row::new(cell, row_unique_data.into()); let witness = PartialNodeCircuit::new(row, is_child_left); Ok(CircuitInput::Partial { witness, @@ -387,7 +387,7 @@ mod test { let row = &p.partial; let id = row.cell.identifier; let value = row.cell.value; - let row_unique_data = row.row_unique_data; + let row_unique_data = row.row_unique_data.into(); let row_digest = row.digest(&p.cells_pi()); let child_proof = ProofWithVK::deserialize(&child_proof_buff)?; @@ -461,7 +461,7 @@ mod test { let row = &p.full; let id = row.cell.identifier; let value = row.cell.value; - let row_unique_data = row.row_unique_data; + let row_unique_data = row.row_unique_data.into(); let row_digest = row.digest(&p.cells_pi()); let input = CircuitInput::full( @@ -520,7 +520,7 @@ mod test { fn generate_leaf_proof(p: &TestParams, row: &Row) -> Result> { let id = row.cell.identifier; let value = row.cell.value; - let row_unique_data = row.row_unique_data; + let row_unique_data = row.row_unique_data.into(); let row_digest = row.digest(&p.cells_pi()); // generate row leaf proof From 5ce5ca5358b0a3aed0906fa29bb13e0b8d00a29e Mon Sep 17 00:00:00 2001 From: Steven Gu Date: Thu, 31 Oct 2024 21:12:36 +0800 Subject: [PATCH 5/5] Fix test --- mp2-v1/tests/common/rowtree.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mp2-v1/tests/common/rowtree.rs b/mp2-v1/tests/common/rowtree.rs index 329966885..c6f252db8 100644 --- a/mp2-v1/tests/common/rowtree.rs +++ b/mp2-v1/tests/common/rowtree.rs @@ -152,7 +152,7 @@ impl TestContext { value, multiplier, // TODO: row_unique_data - HashOut::rand(), + HashOut::rand().into(), cell_tree_proof, ) .unwrap(), @@ -190,7 +190,7 @@ impl TestContext { multiplier, context.left.is_some(), // TODO: row_unique_data - HashOut::rand(), + HashOut::rand().into(), child_proof, cell_tree_proof, ) @@ -234,7 +234,7 @@ impl TestContext { value, multiplier, // TODO: row_unique_data - HashOut::rand(), + HashOut::rand().into(), left_proof, right_proof, cell_tree_proof,