From 0b9c80f5337b5167d383bf79eabc5d55947fc235 Mon Sep 17 00:00:00 2001 From: nicholas-mainardi Date: Fri, 6 Dec 2024 09:51:00 +0100 Subject: [PATCH 01/12] Replace query params with batching circuits + use batching public inputs for tabular queries --- verifiable-db/src/api.rs | 60 +- verifiable-db/src/lib.rs | 1 + verifiable-db/src/query/aggregation/mod.rs | 13 +- .../query/aggregation/output_computation.rs | 62 +- verifiable-db/src/query/aggregation/utils.rs | 1 - verifiable-db/src/query/api.rs | 1946 ++++------------- .../src/query/batching/circuits/api.rs | 269 --- .../batching/circuits/chunk_aggregation.rs | 10 +- .../src/query/batching/circuits/mod.rs | 3 +- .../query/batching/circuits/non_existence.rs | 24 +- .../batching/circuits/row_chunk_processing.rs | 51 +- verifiable-db/src/query/batching/mod.rs | 4 - .../src/query/batching/public_inputs.rs | 695 ------ .../batching/row_chunk/aggregate_chunks.rs | 6 +- .../src/query/batching/row_chunk/mod.rs | 5 +- .../{ => row_chunk}/row_process_gadget.rs | 10 +- verifiable-db/src/query/mod.rs | 4 +- verifiable-db/src/query/public_inputs.rs | 594 +++-- .../universal_query_circuit.rs | 334 +-- .../universal_query_gadget.rs | 83 +- .../results_tree/binding/binding_results.rs | 11 +- verifiable-db/src/results_tree/mod.rs | 76 + .../src/results_tree/old_public_inputs.rs | 539 +++++ verifiable-db/src/revelation/api.rs | 94 +- verifiable-db/src/revelation/mod.rs | 19 +- .../revelation/revelation_unproven_offset.rs | 289 ++- .../revelation_without_results_tree.rs | 331 +-- verifiable-db/src/test_utils.rs | 173 +- 28 files changed, 2270 insertions(+), 3437 deletions(-) delete mode 100644 verifiable-db/src/query/batching/public_inputs.rs rename verifiable-db/src/query/batching/{ => row_chunk}/row_process_gadget.rs (97%) create mode 100644 verifiable-db/src/results_tree/old_public_inputs.rs diff --git a/verifiable-db/src/api.rs b/verifiable-db/src/api.rs index d71636d13..8088f605e 100644 --- a/verifiable-db/src/api.rs +++ b/verifiable-db/src/api.rs @@ -7,11 +7,11 @@ use crate::{ extraction::{ExtractionPI, ExtractionPIWrap}, ivc, query::{ - self, api::Parameters as QueryParams, batching::circuits::api::num_io as batching_num_io, + self, api::Parameters as QueryParams, pi_len as query_pi_len, }, revelation::{ - self, api::Parameters as RevelationParams, num_query_io, num_query_io_no_results_tree, + self, api::Parameters as RevelationParams, pi_len as revelation_pi_len, }, row_tree::{self}, @@ -218,15 +218,12 @@ pub struct QueryParameters< [(); MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS]:, [(); MAX_NUM_ITEMS_PER_OUTPUT - 1]:, [(); query_pi_len::()]:, - [(); num_query_io::()]:, [(); 2 * (MAX_NUM_PREDICATE_OPS + MAX_NUM_RESULT_OPS)]:, [(); ROW_TREE_MAX_DEPTH - 1]:, [(); INDEX_TREE_MAX_DEPTH - 1]:, [(); MAX_NUM_ITEMS_PER_OUTPUT * MAX_NUM_OUTPUTS]:, - [(); num_query_io_no_results_tree::()]:, { - #[cfg(feature = "batching_circuits")] - batching_query_params: BatchingQueryParams< + query_params: QueryParams< NUM_CHUNKS, NUM_ROWS, ROW_TREE_MAX_DEPTH, @@ -236,12 +233,6 @@ pub struct QueryParameters< MAX_NUM_RESULT_OPS, MAX_NUM_ITEMS_PER_OUTPUT, >, - query_params: QueryParams< - MAX_NUM_COLUMNS, - MAX_NUM_PREDICATE_OPS, - MAX_NUM_RESULT_OPS, - MAX_NUM_ITEMS_PER_OUTPUT, - >, revelation_params: RevelationParams< ROW_TREE_MAX_DEPTH, INDEX_TREE_MAX_DEPTH, @@ -275,9 +266,8 @@ pub enum QueryCircuitInput< [(); INDEX_TREE_MAX_DEPTH - 1]:, [(); MAX_NUM_ITEMS_PER_OUTPUT * MAX_NUM_OUTPUTS]:, { - #[cfg(feature = "batching_circuits")] - BatchingQuery( - 
query::batching::circuits::api::CircuitInput< + Query( + query::api::CircuitInput< NUM_CHUNKS, NUM_ROWS, ROW_TREE_MAX_DEPTH, @@ -288,14 +278,6 @@ pub enum QueryCircuitInput< MAX_NUM_ITEMS_PER_OUTPUT, >, ), - Query( - query::api::CircuitInput< - MAX_NUM_COLUMNS, - MAX_NUM_PREDICATE_OPS, - MAX_NUM_RESULT_OPS, - MAX_NUM_ITEMS_PER_OUTPUT, - >, - ), Revelation( revelation::api::CircuitInput< ROW_TREE_MAX_DEPTH, @@ -337,34 +319,21 @@ impl< where [(); MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS]:, [(); MAX_NUM_ITEMS_PER_OUTPUT - 1]:, - [(); num_query_io_no_results_tree::()]:, [(); 2 * (MAX_NUM_PREDICATE_OPS + MAX_NUM_RESULT_OPS)]:, [(); ROW_TREE_MAX_DEPTH - 1]:, [(); INDEX_TREE_MAX_DEPTH - 1]:, [(); MAX_NUM_ITEMS_PER_OUTPUT * MAX_NUM_OUTPUTS]:, - [(); batching_num_io::()]:, - [(); num_query_io::()]:, [(); query_pi_len::()]:, [(); revelation_pi_len::()]:, { /// Build `QueryParameters` from serialized `ParamsInfo` of `PublicParamaters` pub fn build_params(preprocessing_params_info: &[u8]) -> Result { let params_info: ParamsInfo = bincode::deserialize(preprocessing_params_info)?; - #[cfg(feature = "batching_circuits")] - let batching_query_params = BatchingQueryParams::build(); let query_params = QueryParams::build(); info!("Building the revelation circuit parameters..."); - #[cfg(feature = "batching_circuits")] - let revelation_params = RevelationParams::build( - batching_query_params.get_circuit_set(), - query_params.get_circuit_set(), - ¶ms_info.preprocessing_circuit_set, - ¶ms_info.preprocessing_vk, - ); - #[cfg(not(feature = "batching_circuits"))] let revelation_params = RevelationParams::build( query_params.get_circuit_set(), // unused, so we provide same query params - query_params.get_circuit_set(), + query_params.get_universal_circuit().data.verifier_data(), ¶ms_info.preprocessing_circuit_set, ¶ms_info.preprocessing_vk, ); @@ -372,8 +341,6 @@ where let wrap_circuit = WrapCircuitParams::build(revelation_params.get_circuit_set()); info!("All QUERY parameters built !"); Ok(Self { - #[cfg(feature = "batching_circuits")] - batching_query_params, query_params, revelation_params, wrap_circuit, @@ -396,25 +363,12 @@ where >, ) -> Result> { match input { - #[cfg(feature = "batching_circuits")] - QueryCircuitInput::BatchingQuery(input) => { - self.batching_query_params.generate_proof(input) - } QueryCircuitInput::Query(input) => self.query_params.generate_proof(input), QueryCircuitInput::Revelation(input) => { - #[cfg(feature = "batching_circuits")] - let proof = self.revelation_params.generate_proof( - input, - self.batching_query_params.get_circuit_set(), - self.query_params.get_circuit_set(), - Some(&self.query_params), - )?; - #[cfg(not(feature = "batching_circuits"))] let proof = self.revelation_params.generate_proof( input, - self.query_params.get_circuit_set(), // unused, so we provide a dummy one self.query_params.get_circuit_set(), - Some(&self.query_params), + Some(self.query_params.get_universal_circuit()), )?; self.wrap_circuit.generate_proof( self.revelation_params.get_circuit_set(), diff --git a/verifiable-db/src/lib.rs b/verifiable-db/src/lib.rs index 1cac73092..3a678639e 100644 --- a/verifiable-db/src/lib.rs +++ b/verifiable-db/src/lib.rs @@ -16,4 +16,5 @@ pub mod results_tree; /// Module for the query revelation circuits pub mod revelation; pub mod row_tree; +#[cfg(test)] pub mod test_utils; diff --git a/verifiable-db/src/query/aggregation/mod.rs b/verifiable-db/src/query/aggregation/mod.rs index f179ff60b..3afa18042 100644 --- a/verifiable-db/src/query/aggregation/mod.rs +++ 
b/verifiable-db/src/query/aggregation/mod.rs @@ -28,22 +28,13 @@ use plonky2::{ }; use serde::{Deserialize, Serialize}; -pub(crate) mod child_proven_single_path_node; -pub(crate) mod embedded_tree_proven_single_path_node; -pub(crate) mod full_node_index_leaf; -pub(crate) mod full_node_with_one_child; -pub(crate) mod full_node_with_two_children; -pub(crate) mod non_existence_inter; pub(crate) mod output_computation; -pub(crate) mod partial_node; -mod utils; use super::{ - api::CircuitInput, computational_hash_ids::{ColumnIDs, Identifiers, PlaceholderIdentifier}, universal_circuit::{ universal_circuit_inputs::{BasicOperation, PlaceholderId, Placeholders, ResultStructure}, - universal_query_circuit::{placeholder_hash, placeholder_hash_without_query_bounds}, + universal_query_circuit::{placeholder_hash, placeholder_hash_without_query_bounds, UniversalCircuitInput}, universal_query_gadget::QueryBound, ComputationalHash, PlaceholderHash, }, @@ -474,7 +465,7 @@ impl QueryHashNonExistenceCircuits { .into(), ) }; - let placeholder_hash_ids = CircuitInput::< + let placeholder_hash_ids = UniversalCircuitInput::< MAX_NUM_COLUMNS, MAX_NUM_PREDICATE_OPS, MAX_NUM_RESULT_OPS, diff --git a/verifiable-db/src/query/aggregation/output_computation.rs b/verifiable-db/src/query/aggregation/output_computation.rs index 49ede7394..04b71f2bc 100644 --- a/verifiable-db/src/query/aggregation/output_computation.rs +++ b/verifiable-db/src/query/aggregation/output_computation.rs @@ -2,11 +2,9 @@ use crate::query::{ computational_hash_ids::{AggregationOperation, Identifiers}, - public_inputs::PublicInputs, universal_circuit::universal_query_gadget::{CurveOrU256Target, OutputValuesTarget}, }; use alloy::primitives::U256; -use itertools::Itertools; use mp2_common::{ array::ToField, group_hashing::CircuitBuilderGroupHashing, @@ -153,42 +151,18 @@ where } } -/// Compute the node output item at the specified index by the proofs, -/// and return the output item with the overflow number. -pub(crate) fn compute_output_item( - b: &mut CBuilder, - i: usize, - proofs: &[&PublicInputs], -) -> (Vec, Target) -where - [(); S - 1]:, -{ - let proof0 = &proofs[0]; - let op = proof0.operation_ids_target()[i]; - - // Check that the all proofs are employing the same aggregation operation. - proofs[1..] - .iter() - .for_each(|p| b.connect(p.operation_ids_target()[i], op)); - - let outputs = proofs - .iter() - .map(|p| OutputValuesTarget::from_targets(p.to_values_raw())) - .collect_vec(); - - OutputValuesTarget::aggregate_outputs(b, &outputs, op, i) -} - #[cfg(test)] pub(crate) mod tests { use super::*; use crate::{ query::{ aggregation::tests::compute_output_item_value, pi_len, + public_inputs::PublicInputs, universal_circuit::universal_query_gadget::CurveOrU256, }, - test_utils::{random_aggregation_operations, random_aggregation_public_inputs}, + test_utils::random_aggregation_operations, }; + use itertools::Itertools; use mp2_common::{types::CURVE_TARGET_LEN, u256::NUM_LIMBS, utils::ToFields, C, D, F}; use mp2_test::circuit::{run_circuit, UserCircuit}; use plonky2::{ @@ -198,6 +172,32 @@ pub(crate) mod tests { use plonky2_ecgfp5::curve::curve::Point; use std::array; + /// Compute the node output item at the specified index by the proofs, + /// and return the output item with the overflow number. 
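+    ///
+    /// All input proofs must agree on the aggregation operation at index `i`; the
+    /// second returned target carries the overflow counter. Call shape (test-only
+    /// sketch; `b` is the circuit builder and `proofs` are the public inputs of the
+    /// proofs being aggregated):
+    /// ```ignore
+    /// let (item, overflow) = compute_output_item::<S>(b, i, &proofs);
+    /// ```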
+ pub(crate) fn compute_output_item( + b: &mut CBuilder, + i: usize, + proofs: &[&PublicInputs], + ) -> (Vec, Target) + where + [(); S - 1]:, + { + let proof0 = &proofs[0]; + let op = proof0.operation_ids_target()[i]; + + // Check that the all proofs are employing the same aggregation operation. + proofs[1..] + .iter() + .for_each(|p| b.connect(p.operation_ids_target()[i], op)); + + let outputs = proofs + .iter() + .map(|p| OutputValuesTarget::from_targets(p.to_values_raw())) + .collect_vec(); + + OutputValuesTarget::aggregate_outputs(b, &outputs, op, i) + } + /// Compute the dummy values for each of the `S` values to be returned as output. /// It's the test function corresponding to `compute_dummy_output_targets`. pub(crate) fn compute_dummy_output_values(ops: &[F; S]) -> Vec { @@ -333,7 +333,7 @@ pub(crate) mod tests { let ops: [_; S] = random_aggregation_operations(); // Build the input proofs. - let inputs = random_aggregation_public_inputs(&ops); + let inputs = PublicInputs::::sample_from_ops(&ops); // Construct the test circuit. let test_circuit = TestOutputComputationCircuit::::new(inputs); @@ -354,7 +354,7 @@ pub(crate) mod tests { ops[0] = Identifiers::AggregationOperations(AggregationOperation::IdOp).to_field(); // Build the input proofs. - let inputs = random_aggregation_public_inputs(&ops); + let inputs = PublicInputs::::sample_from_ops(&ops); // Construct the test circuit. let test_circuit = TestOutputComputationCircuit::::new(inputs); diff --git a/verifiable-db/src/query/aggregation/utils.rs b/verifiable-db/src/query/aggregation/utils.rs index e9f1f4454..eff88308f 100644 --- a/verifiable-db/src/query/aggregation/utils.rs +++ b/verifiable-db/src/query/aggregation/utils.rs @@ -97,7 +97,6 @@ pub(crate) fn constrain_input_proofs( #[cfg(test)] pub(crate) mod tests { use super::*; - use crate::query::public_inputs::QueryPublicInputs; use alloy::primitives::U256; use mp2_common::utils::ToFields; diff --git a/verifiable-db/src/query/api.rs b/verifiable-db/src/query/api.rs index b4c26270a..322b90042 100644 --- a/verifiable-db/src/query/api.rs +++ b/verifiable-db/src/query/api.rs @@ -1,85 +1,161 @@ -use std::iter::repeat; +use std::iter::{repeat, repeat_with}; -use crate::query::aggregation::full_node_index_leaf::FullNodeIndexLeafCircuit; +use anyhow::{bail, ensure, Result}; -use super::{ - aggregation::{ - child_proven_single_path_node::{ - ChildProvenSinglePathNodeCircuit, ChildProvenSinglePathNodeWires, - NUM_VERIFIED_PROOFS as NUM_PROOFS_CHILD, - }, - embedded_tree_proven_single_path_node::{ - EmbeddedTreeProvenSinglePathNodeCircuit, EmbeddedTreeProvenSinglePathNodeWires, - NUM_VERIFIED_PROOFS as NUM_PROOFS_EMBEDDED, - }, - full_node_index_leaf::{FullNodeIndexLeafWires, NUM_VERIFIED_PROOFS as NUM_PROOFS_LEAF}, - full_node_with_one_child::{ - FullNodeWithOneChildCircuit, FullNodeWithOneChildWires, - NUM_VERIFIED_PROOFS as NUM_PROOFS_FN1, - }, - full_node_with_two_children::{ - FullNodeWithTwoChildrenCircuit, FullNodeWithTwoChildrenWires, - NUM_VERIFIED_PROOFS as NUM_PROOFS_FN2, - }, - non_existence_inter::{ - NonExistenceInterNodeCircuit, NonExistenceInterNodeWires, - NUM_VERIFIED_PROOFS as NUM_PROOFS_NE_INTER, - }, - partial_node::{ - PartialNodeCircuit, PartialNodeWires, NUM_VERIFIED_PROOFS as NUM_PROOFS_PN, - }, - ChildPosition, ChildProof, CommonInputs, NodeInfo, NonExistenceInput, - OneProvenChildNodeInput, QueryBounds, QueryHashNonExistenceCircuits, SinglePathInput, - SubProof, TwoProvenChildNodeInput, - }, - computational_hash_ids::{AggregationOperation, HashPermutation, 
Output}, - pi_len, - universal_circuit::{ - output_no_aggregation::Circuit as NoAggOutputCircuit, - output_with_aggregation::Circuit as AggOutputCircuit, - universal_circuit_inputs::{ - BasicOperation, PlaceholderId, Placeholders, ResultStructure, RowCells, - }, - universal_query_circuit::{ - placeholder_hash, UniversalCircuitInput, UniversalQueryCircuitInputs, - UniversalQueryCircuitWires, - }, - universal_query_gadget::QueryBound, - }, -}; -use alloy::primitives::U256; -use anyhow::{ensure, Result}; use itertools::Itertools; -use log::info; -use mp2_common::{ - array::ToField, - default_config, - poseidon::H, - proof::ProofWithVK, - types::HashOutput, - utils::{Fieldable, ToFields}, - C, D, F, -}; -use plonky2::{ - hash::hashing::hash_n_to_hash_no_pad, - plonk::config::{GenericHashOut, Hasher}, +use mp2_common::{array::ToField, default_config, poseidon::{HashPermutation, H}, proof::{serialize_proof, ProofWithVK}, types::HashOutput, utils::ToFields, C, D, F}; +use plonky2::{hash::hashing::hash_n_to_hash_no_pad, plonk::config::{GenericHashOut, Hasher}}; +use recursion_framework::{ + circuit_builder::{CircuitWithUniversalVerifier, CircuitWithUniversalVerifierBuilder}, framework::{prepare_recursive_circuit_for_circuit_set, RecursiveCircuits}, }; +use serde::{Deserialize, Serialize}; + +#[cfg(feature = "batching_circuits")] +use mp2_common::{default_config, poseidon::H}; +#[cfg(feature = "batching_circuits")] +use plonky2::plonk::config::Hasher; +#[cfg(feature = "batching_circuits")] use recursion_framework::{ - circuit_builder::{CircuitWithUniversalVerifier, CircuitWithUniversalVerifierBuilder}, - framework::{ - prepare_recursive_circuit_for_circuit_set, RecursiveCircuitInfo, RecursiveCircuits, + circuit_builder::CircuitWithUniversalVerifierBuilder, + framework::prepare_recursive_circuit_for_circuit_set, +}; + +use crate::query::{ + aggregation::{ChildPosition, NodeInfo, QueryBounds, QueryHashNonExistenceCircuits}, + batching::{ + circuits::{ + chunk_aggregation::{ChunkAggregationCircuit, ChunkAggregationInputs, ChunkAggregationWires}, + non_existence::{NonExistenceCircuit, NonExistenceWires}, + row_chunk_processing::{RowChunkProcessingCircuit, RowChunkProcessingWires}, + }, + row_chunk::row_process_gadget::RowProcessingGadgetInputs, + }, + computational_hash_ids::{AggregationOperation, ColumnIDs, Identifiers}, + universal_circuit::{ + output_with_aggregation::Circuit as OutputAggCircuit, + output_no_aggregation::Circuit as OutputNoAggCircuit, + universal_circuit_inputs::{BasicOperation, Placeholders, ResultStructure, RowCells}, }, }; -use serde::{Deserialize, Serialize}; -#[derive(Clone, Debug, Serialize, Deserialize)] -#[allow(clippy::large_enum_variant)] // we need to clone data if we fix by put variants inside a `Box` +use super::{computational_hash_ids::Output, pi_len, universal_circuit::{universal_circuit_inputs::PlaceholderId, universal_query_circuit::{placeholder_hash, UniversalCircuitInput, UniversalQueryCircuitParams}}}; + +/// Data structure containing all the information needed to verify the membership of +/// a node in a tree and to compute info about its predecessor/successor +#[derive(Clone, Debug, Serialize, Deserialize, Default, PartialEq, Eq)] +pub struct TreePathInputs { + /// Info about the node + pub(crate) node_info: NodeInfo, + /// Info about the nodes in the path from the node up to the root of the tree; The `ChildPosition` refers to + /// the position of the previous node in the path as a child of the current node + pub(crate) path: Vec<(NodeInfo, ChildPosition)>, + 
/// Hash of the siblings of the nodes in the path (except for the root)
+    pub(crate) siblings: Vec<Option<HashOutput>>,
+    /// Info about the children of the node
+    pub(crate) children: [Option<NodeInfo>; 2],
+}
+
+impl TreePathInputs {
+    /// Instantiate a new instance of `TreePathInputs` for a given node from the following input data:
+    /// - `node_info`: data about the given node
+    /// - `path`: data about the nodes in the path from the node up to the root of the tree;
+    ///    the `ChildPosition` refers to the position of the previous node in the path as a child of the current node
+    /// - `children`: data about the children of the given node
+    ///
+    /// The sibling hashes are not provided as input: they are computed from the nodes in `path`
+    pub fn new(
+        node_info: NodeInfo,
+        path: Vec<(NodeInfo, ChildPosition)>,
+        children: [Option<NodeInfo>; 2],
+    ) -> Self {
+        let siblings = path
+            .iter()
+            .map(|(node, child_pos)| {
+                let sibling_index = match *child_pos {
+                    ChildPosition::Left => 1,
+                    ChildPosition::Right => 0,
+                };
+                Some(HashOutput::from(node.child_hashes[sibling_index]))
+            })
+            .collect_vec();
+        Self {
+            node_info,
+            path,
+            siblings,
+            children,
+        }
+    }
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize, Default, PartialEq, Eq)]
+/// Data structure containing the information about the paths in both the rows tree
+/// and the index tree for a node in a rows tree
+pub struct NodePath {
+    /// Path of the node in the rows tree storing the row
+    pub(crate) row_tree_path: TreePathInputs,
+    /// Info about the node of the index tree storing the rows tree containing the row
+    pub(crate) index_tree_path: TreePathInputs,
+}
+
+impl NodePath {
+    /// Instantiate a new instance of `NodePath` for a given proven row from the following input data:
+    /// - `row_path`: path from the node to the root of the rows tree storing the node
+    /// - `index_path`: path from the index tree node storing the rows tree containing the node, up to the
+    ///    root of the index tree
+    pub fn new(row_path: TreePathInputs, index_path: TreePathInputs) -> Self {
+        Self {
+            row_tree_path: row_path,
+            index_tree_path: index_path,
+        }
+    }
+}
+
+#[derive(Clone, Debug)]
+/// Data structure containing the inputs necessary to prove a query for a row
+/// of the DB table.
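+///
+/// A minimal construction sketch, assuming the cells of the row and the tree paths of the
+/// corresponding rows tree node are already available (identifiers here are illustrative,
+/// not part of this API):
+/// ```ignore
+/// let path = NodePath::new(row_tree_path, index_tree_path);
+/// let row = RowInput::new(&row_cells, &path);
+/// let input = CircuitInput::new_row_chunks_input(
+///     &[row],
+///     &predicate_operations,
+///     &placeholders,
+///     &query_bounds,
+///     &results,
+/// )?;
+/// ```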
+pub struct RowInput { + pub(crate) cells: RowCells, + pub(crate) path: NodePath, +} + +impl RowInput { + /// Initialize `RowInput` from the set of cells of the given row and the path + /// in the tree of the node of the rows tree associated to the given row + pub fn new(cells: &RowCells, path: &NodePath) -> Self { + Self { + cells: cells.clone(), + path: path.clone(), + } + } +} + +#[derive(Serialize, Deserialize)] +#[allow(clippy::large_enum_variant)] pub enum CircuitInput< + const NUM_CHUNKS: usize, + const NUM_ROWS: usize, + const ROW_TREE_MAX_DEPTH: usize, + const INDEX_TREE_MAX_DEPTH: usize, const MAX_NUM_COLUMNS: usize, const MAX_NUM_PREDICATE_OPS: usize, const MAX_NUM_RESULT_OPS: usize, const MAX_NUM_RESULTS: usize, -> { +> where + [(); ROW_TREE_MAX_DEPTH - 1]:, + [(); INDEX_TREE_MAX_DEPTH - 1]:, +{ + RowChunkWithAggregation( + RowChunkProcessingCircuit< + NUM_ROWS, + ROW_TREE_MAX_DEPTH, + INDEX_TREE_MAX_DEPTH, + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + OutputAggCircuit, + >, + ), + ChunkAggregation(ChunkAggregationInputs), + NonExistence(NonExistenceCircuit), /// Inputs for the universal query circuit UniversalCircuit( UniversalCircuitInput< @@ -89,27 +165,160 @@ pub enum CircuitInput< MAX_NUM_RESULTS, >, ), - /// Inputs for circuits with 2 proven children and a proven embedded tree - TwoProvenChildNode(TwoProvenChildNodeInput), - /// Inputs for circuits proving a node with one proven child and a proven embedded tree - OneProvenChildNode(OneProvenChildNodeInput), - /// Inputs for circuits proving a node with only one proven subtree (either a proven child or the embedded tree) - SinglePath(SinglePathInput), - /// Inputs for circuits to prove non-existence of results for the current query - NonExistence(NonExistenceInput), } impl< + const NUM_CHUNKS: usize, + const NUM_ROWS: usize, + const ROW_TREE_MAX_DEPTH: usize, + const INDEX_TREE_MAX_DEPTH: usize, const MAX_NUM_COLUMNS: usize, const MAX_NUM_PREDICATE_OPS: usize, const MAX_NUM_RESULT_OPS: usize, const MAX_NUM_RESULTS: usize, - > CircuitInput + > + CircuitInput< + NUM_CHUNKS, + NUM_ROWS, + ROW_TREE_MAX_DEPTH, + INDEX_TREE_MAX_DEPTH, + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + > where + [(); ROW_TREE_MAX_DEPTH - 1]:, + [(); INDEX_TREE_MAX_DEPTH - 1]:, [(); MAX_NUM_RESULTS - 1]:, [(); MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS]:, [(); 2 * (MAX_NUM_PREDICATE_OPS + MAX_NUM_RESULT_OPS)]:, { + /// Construct the input necessary to prove a query over a chunk of rows provided as input. 
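+    /// Each row of the chunk is processed with the row processing gadget, and the per-row
+    /// outputs are then aggregated into a single proof for the whole chunk.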
+    /// At least 1 row must be provided; in case there are no rows to be proven,
+    /// `Self::new_non_existence_input` should be used instead
+    pub fn new_row_chunks_input(
+        rows: &[RowInput],
+        predicate_operations: &[BasicOperation],
+        placeholders: &Placeholders,
+        query_bounds: &QueryBounds,
+        results: &ResultStructure,
+    ) -> Result<Self> {
+        ensure!(
+            !rows.is_empty(),
+            "there must be at least 1 row to be proven"
+        );
+        ensure!(
+            rows.len() <= NUM_ROWS,
+            format!(
+                "Found {} rows provided as input, maximum allowed is {NUM_ROWS}",
+                rows.len()
+            )
+        );
+        let column_ids = &rows[0].cells.column_ids();
+        ensure!(
+            rows.iter()
+                .all(|row| row.cells.column_ids().to_vec() == column_ids.to_vec()),
+            "Rows provided as input don't have the same column ids",
+        );
+        let row_inputs = rows
+            .iter()
+            .map(RowProcessingGadgetInputs::try_from)
+            .collect::<Result<Vec<_>>>()?;
+
+        Ok(Self::RowChunkWithAggregation(
+            RowChunkProcessingCircuit::new(
+                row_inputs,
+                column_ids,
+                predicate_operations,
+                placeholders,
+                query_bounds,
+                results,
+            )?,
+        ))
+    }
+
+    /// Construct the input necessary to aggregate 2 or more row chunks already proven.
+    /// It requires at least 2 chunks to be aggregated
+    pub fn new_chunk_aggregation_input(chunks_proofs: &[Vec<u8>]) -> Result<Self> {
+        ensure!(
+            chunks_proofs.len() >= 2,
+            "At least 2 chunk proofs must be provided"
+        );
+        let num_proofs = chunks_proofs.len();
+        ensure!(
+            num_proofs <= NUM_CHUNKS,
+            format!("Found {num_proofs} proofs provided as input, maximum allowed is {NUM_CHUNKS}")
+        );
+        // deserialize `chunks_proofs` and pad to NUM_CHUNKS proofs by replicating the last proof in `chunks_proofs`
+        let last_proof = chunks_proofs.last().unwrap();
+        let proofs = chunks_proofs
+            .iter()
+            .map(|p| ProofWithVK::deserialize(p))
+            .chain(repeat_with(|| ProofWithVK::deserialize(last_proof)))
+            .take(NUM_CHUNKS)
+            .collect::<Result<Vec<_>>>()?;
+
+        Ok(Self::ChunkAggregation(ChunkAggregationInputs {
+            chunk_proofs: proofs.try_into().unwrap(),
+            circuit: ChunkAggregationCircuit {
+                num_non_dummy_chunks: num_proofs,
+            },
+        }))
+    }
+
+    /// Construct the input to prove a query in case there are no rows with a primary index value
+    /// in the primary query range.
The circuit employed to prove the non-existence of such a row + /// requires to provide a specific node of the index tree, as described in the docs + /// https://www.notion.so/lagrangelabs/Batching-Query-10628d1c65a880b1b151d4ac017fa445?pvs=4#10e28d1c65a880498f41cd1cad0c61c3 + pub fn new_non_existence_input( + index_node_path: TreePathInputs, + column_ids: &ColumnIDs, + predicate_operations: &[BasicOperation], + results: &ResultStructure, + placeholders: &Placeholders, + query_bounds: &QueryBounds, + ) -> Result { + let QueryHashNonExistenceCircuits { + computational_hash, + placeholder_hash, + } = QueryHashNonExistenceCircuits::new::< + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + >( + column_ids, + predicate_operations, + results, + placeholders, + query_bounds, + false, + )?; + + let aggregation_operations = results + .aggregation_operations() + .into_iter() + .chain(repeat( + Identifiers::AggregationOperations(AggregationOperation::default()).to_field(), + )) + .take(MAX_NUM_RESULTS) + .collect_vec() + .try_into() + .unwrap(); + + Ok(Self::NonExistence(NonExistenceCircuit::new( + &index_node_path, + column_ids.primary, + aggregation_operations, + computational_hash, + placeholder_hash, + query_bounds, + )?)) + } + pub const fn num_placeholders_ids() -> usize { 2 * (MAX_NUM_PREDICATE_OPS + MAX_NUM_RESULT_OPS) } @@ -136,14 +345,9 @@ where ) -> Result { Ok(CircuitInput::UniversalCircuit( match results.output_variant { - Output::Aggregation => UniversalCircuitInput::new_query_with_agg( - column_cells, - predicate_operations, - placeholders, - is_leaf, - query_bounds, - results, - )?, + Output::Aggregation => bail!( + "Universal query circuit should only be used for queries with no aggregation" + ), Output::NoAggregation => UniversalCircuitInput::new_query_no_agg( column_cells, predicate_operations, @@ -156,145 +360,6 @@ where )) } - /// Initialize input to prove a full node from the following inputs: - /// - `left_child_proof`: proof for the left child of the node being proven - /// - `right_child_proof`: proof for the right child of the node being proven - /// - `embedded_tree_proof`: proof for the embedded tree stored in the full node: can be either the proof for a single - /// row (if proving a rows tree node) of the proof for the root node of a rows tree (if proving an index tree node) - /// - `is_rows_tree_node`: flag specifying whether the full node belongs to the rows tree or to the index tree - /// - `query_bounds`: bounds on primary and secondary indexes specified in the query - pub fn new_full_node( - left_child_proof: Vec, - right_child_proof: Vec, - embedded_tree_proof: Vec, - is_rows_tree_node: bool, - query_bounds: &QueryBounds, - ) -> Result { - Ok(CircuitInput::TwoProvenChildNode(TwoProvenChildNodeInput { - left_child_proof: ProofWithVK::deserialize(&left_child_proof)?, - right_child_proof: ProofWithVK::deserialize(&right_child_proof)?, - embedded_tree_proof: ProofWithVK::deserialize(&embedded_tree_proof)?, - common: CommonInputs::new(is_rows_tree_node, query_bounds), - })) - } - - /// Initialize input to prove a partial node from the following inputs: - /// - `proven_child_proof`: Proof for the child being a proven node - /// - `embedded_tree_proof`: Proof for the embedded tree stored in the partial node: can be either the proof - /// for a single row (if proving a rows tree node) of the proof for the root node of a rows - /// tree (if proving an index tree node) - /// - `unproven_child`: Data about the child not being a proven node; 
if the node has only one child, - /// then, this parameter must be `None` - /// - `proven_child_position`: Enum specifying whether the proven child is the left or right child - /// of the partial node being proven - /// - `is_rows_tree_node`: flag specifying whether the full node belongs to the rows tree or to the index tree - /// - `query_bounds`: bounds on primary and secondary indexes specified in the query - pub fn new_partial_node( - proven_child_proof: Vec, - embedded_tree_proof: Vec, - unproven_child: Option, - proven_child_position: ChildPosition, - is_rows_tree_node: bool, - query_bounds: &QueryBounds, - ) -> Result { - Ok(CircuitInput::OneProvenChildNode(OneProvenChildNodeInput { - unproven_child, - proven_child_proof: ChildProof { - proof: ProofWithVK::deserialize(&proven_child_proof)?, - child_position: proven_child_position, - }, - embedded_tree_proof: ProofWithVK::deserialize(&embedded_tree_proof)?, - common: CommonInputs::new(is_rows_tree_node, query_bounds), - })) - } - /// Initialize input to prove a single path node from the following inputs: - /// - `subtree_proof`: Proof of either a child node or of the embedded tree stored in the current node - /// - `left_child`: Data about the left child of the current node, if any; must be `None` if the node has - /// no left child - /// - `right_child`: Data about the right child of the current node, if any; must be `None` if the node has - /// no right child - /// - `node_info`: Data about the current node being proven - /// - `is_rows_tree_node`: flag specifying whether the full node belongs to the rows tree or to the index tree - /// - `query_bounds`: bounds on primary and secondary indexes specified in the query - pub fn new_single_path( - subtree_proof: SubProof, - left_child: Option, - right_child: Option, - node_info: NodeInfo, - is_rows_tree_node: bool, - query_bounds: &QueryBounds, - ) -> Result { - Ok(CircuitInput::SinglePath(SinglePathInput { - left_child, - right_child, - node_info, - subtree_proof, - common: CommonInputs::new(is_rows_tree_node, query_bounds), - })) - } - /// Initialize input to prove a node storing a value of the primary or secondary index which - /// is outside of the query bounds, from the following inputs: - /// - `node_info`: Data about the node being proven - /// - `left_child_info`: Data aboout the left child of the node being proven; must be `None` if - /// the node being proven has no left child - /// - `right_child_info`: Data aboout the right child of the node being proven; must be `None` if - /// the node being proven has no right child - /// - `primary_index_value`: Value of the primary index associated to the current node - /// - `index_ids`: Identifiers of the primary and secondary index columns - /// - `aggregation_ops`: Set of aggregation operations employed to aggregate the results of the query - /// - `query_hashes`: Computational hash and placeholder hash associated to the query; can be computed with the `new` - /// method of `QueryHashNonExistenceCircuits` data structure - /// - `is_rows_tree_node`: flag specifying whether the full node belongs to the rows tree or to the index tree - /// - `query_bounds`: bounds on primary and secondary indexes specified in the query - #[allow(clippy::too_many_arguments)] // doesn't make sense to aggregate arguments - pub fn new_non_existence_input( - node_info: NodeInfo, - left_child_info: Option, - right_child_info: Option, - primary_index_value: U256, - index_ids: &[u64; 2], - aggregation_ops: &[AggregationOperation], - query_hashes: 
QueryHashNonExistenceCircuits, - is_rows_tree_node: bool, - query_bounds: &QueryBounds, - placeholders: &Placeholders, - ) -> Result { - let aggregation_ops = aggregation_ops - .iter() - .map(|op| op.to_field()) - .chain(repeat(AggregationOperation::default().to_field())) - .take(MAX_NUM_RESULTS) - .collect_vec(); - let min_query = if is_rows_tree_node { - QueryBound::new_secondary_index_bound(placeholders, query_bounds.min_query_secondary()) - } else { - QueryBound::new_primary_index_bound(placeholders, true) - }?; - let max_query = if is_rows_tree_node { - QueryBound::new_secondary_index_bound(placeholders, query_bounds.max_query_secondary()) - } else { - QueryBound::new_primary_index_bound(placeholders, false) - }?; - Ok(CircuitInput::NonExistence(NonExistenceInput { - node_info, - left_child_info, - right_child_info, - primary_index_value, - index_ids: index_ids - .iter() - .map(|id| id.to_field()) - .collect_vec() - .try_into() - .unwrap(), - computational_hash: query_hashes.computational_hash(), - placeholder_hash: query_hashes.placeholder_hash(), - aggregation_ops: aggregation_ops.try_into().unwrap(), - is_rows_tree_node, - min_query, - max_query, - })) - } - /// This method returns the ids of the placeholders employed to compute the placeholder hash, /// in the same order, so that those ids can be provided as input to other circuits that need /// to recompute this hash @@ -304,45 +369,12 @@ where placeholders: &Placeholders, query_bounds: &QueryBounds, ) -> Result<[PlaceholderId; 2 * (MAX_NUM_PREDICATE_OPS + MAX_NUM_RESULT_OPS)]> { - let row_cells = &RowCells::default(); - Ok(match results.output_variant { - Output::Aggregation => { - let circuit = UniversalQueryCircuitInputs::< - MAX_NUM_COLUMNS, - MAX_NUM_PREDICATE_OPS, - MAX_NUM_RESULT_OPS, - MAX_NUM_RESULTS, - AggOutputCircuit, - >::new( - row_cells, - predicate_operations, - placeholders, - false, // doesn't matter for placeholder hash computation - query_bounds, - results, - )?; - circuit.ids_for_placeholder_hash() - } - Output::NoAggregation => { - let circuit = UniversalQueryCircuitInputs::< - MAX_NUM_COLUMNS, - MAX_NUM_PREDICATE_OPS, - MAX_NUM_RESULT_OPS, - MAX_NUM_RESULTS, - NoAggOutputCircuit, - >::new( - row_cells, - predicate_operations, - placeholders, - false, // doesn't matter for placeholder hash computation - query_bounds, - results, - )?; - circuit.ids_for_placeholder_hash() - } - } - .try_into() - .unwrap()) + UniversalCircuitInput::< + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + >::ids_for_placeholder_hash(predicate_operations, results, placeholders, query_bounds) } /// Compute the `placeholder_hash` associated to a query @@ -373,1324 +405,212 @@ where ) } } -#[derive(Serialize, Deserialize)] -pub struct Parameters< + +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct Parameters< + const NUM_CHUNKS: usize, + const NUM_ROWS: usize, + const ROW_TREE_MAX_DEPTH: usize, + const INDEX_TREE_MAX_DEPTH: usize, const MAX_NUM_COLUMNS: usize, const MAX_NUM_PREDICATE_OPS: usize, const MAX_NUM_RESULT_OPS: usize, const MAX_NUM_RESULTS: usize, > where - [(); MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS]:, + [(); ROW_TREE_MAX_DEPTH - 1]:, + [(); INDEX_TREE_MAX_DEPTH - 1]:, [(); MAX_NUM_RESULTS - 1]:, + [(); MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS]:, { - circuit_with_agg: CircuitWithUniversalVerifier< + row_chunk_agg_circuit: CircuitWithUniversalVerifier< F, C, D, 0, - UniversalQueryCircuitWires< + RowChunkProcessingWires< + NUM_ROWS, + ROW_TREE_MAX_DEPTH, + INDEX_TREE_MAX_DEPTH, 
MAX_NUM_COLUMNS, MAX_NUM_PREDICATE_OPS, MAX_NUM_RESULT_OPS, MAX_NUM_RESULTS, - AggOutputCircuit, + OutputAggCircuit, >, >, - circuit_no_agg: CircuitWithUniversalVerifier< + //ToDo: add row_chunk_circuit for queries without aggregation, once we integrate results tree + aggregation_circuit: CircuitWithUniversalVerifier< F, C, D, - 0, - UniversalQueryCircuitWires< - MAX_NUM_COLUMNS, - MAX_NUM_PREDICATE_OPS, - MAX_NUM_RESULT_OPS, - MAX_NUM_RESULTS, - NoAggOutputCircuit, - >, + NUM_CHUNKS, + ChunkAggregationWires, >, - full_node_two_children: CircuitWithUniversalVerifier< + non_existence_circuit: CircuitWithUniversalVerifier< F, C, D, - NUM_PROOFS_FN2, - FullNodeWithTwoChildrenWires, - >, - full_node_one_child: CircuitWithUniversalVerifier< - F, - C, - D, - NUM_PROOFS_FN1, - FullNodeWithOneChildWires, - >, - full_node_leaf: CircuitWithUniversalVerifier< - F, - C, - D, - NUM_PROOFS_LEAF, - FullNodeIndexLeafWires, - >, - partial_node: - CircuitWithUniversalVerifier>, - single_path_proven_child: CircuitWithUniversalVerifier< - F, - C, - D, - NUM_PROOFS_CHILD, - ChildProvenSinglePathNodeWires, - >, - single_path_embedded_tree: CircuitWithUniversalVerifier< - F, - C, - D, - NUM_PROOFS_EMBEDDED, - EmbeddedTreeProvenSinglePathNodeWires, + 0, + NonExistenceWires, >, - non_existence_intermediate: CircuitWithUniversalVerifier< - F, - C, - D, - NUM_PROOFS_NE_INTER, - NonExistenceInterNodeWires, + universal_circuit: UniversalQueryCircuitParams< + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + OutputNoAggCircuit, >, circuit_set: RecursiveCircuits, } -const QUERY_CIRCUIT_SET_SIZE: usize = 10; impl< - const MAX_NUM_COLUMNS: usize, - const MAX_NUM_PREDICATE_OPS: usize, - const MAX_NUM_RESULT_OPS: usize, - const MAX_NUM_RESULTS: usize, - > Parameters + const NUM_CHUNKS: usize, + const NUM_ROWS: usize, + const ROW_TREE_MAX_DEPTH: usize, + const INDEX_TREE_MAX_DEPTH: usize, + const MAX_NUM_COLUMNS: usize, + const MAX_NUM_PREDICATE_OPS: usize, + const MAX_NUM_RESULT_OPS: usize, + const MAX_NUM_RESULTS: usize, +> + Parameters< + NUM_CHUNKS, + NUM_ROWS, + ROW_TREE_MAX_DEPTH, + INDEX_TREE_MAX_DEPTH, + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + > where - [(); MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS]:, + [(); ROW_TREE_MAX_DEPTH - 1]:, + [(); INDEX_TREE_MAX_DEPTH - 1]:, [(); MAX_NUM_RESULTS - 1]:, - [(); pi_len::()]:, + [(); MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS]:, [(); >::HASH_SIZE]:, + [(); pi_len::()]:, { - /// Build `Parameters` for query circuits - pub fn build() -> Self { + const CIRCUIT_SET_SIZE: usize = 3; + + pub(crate) fn build() -> Self { let builder = CircuitWithUniversalVerifierBuilder::() }>::new::( default_config(), - QUERY_CIRCUIT_SET_SIZE, + Self::CIRCUIT_SET_SIZE, ); - info!("Building the query circuits parameters..."); - info!("Building universal circuits..."); - let circuit_with_agg = builder.build_circuit(()); - let circuit_no_agg = builder.build_circuit(()); - info!("Building aggregation circuits.."); - let full_node_two_children = builder.build_circuit(()); - let full_node_one_child = builder.build_circuit(()); - let full_node_leaf = builder.build_circuit(()); - let partial_node = builder.build_circuit(()); - let single_path_proven_child = builder.build_circuit(()); - let single_path_embedded_tree = builder.build_circuit(()); - info!("Building non-existence circuits.."); - let non_existence_intermediate = builder.build_circuit(()); + let row_chunk_agg_circuit = builder.build_circuit(()); + let aggregation_circuit = 
builder.build_circuit(()); + let non_existence_circuit = builder.build_circuit(()); let circuits = vec![ - prepare_recursive_circuit_for_circuit_set(&circuit_with_agg), - prepare_recursive_circuit_for_circuit_set(&circuit_no_agg), - prepare_recursive_circuit_for_circuit_set(&full_node_two_children), - prepare_recursive_circuit_for_circuit_set(&full_node_one_child), - prepare_recursive_circuit_for_circuit_set(&full_node_leaf), - prepare_recursive_circuit_for_circuit_set(&partial_node), - prepare_recursive_circuit_for_circuit_set(&single_path_proven_child), - prepare_recursive_circuit_for_circuit_set(&single_path_embedded_tree), - prepare_recursive_circuit_for_circuit_set(&non_existence_intermediate), + prepare_recursive_circuit_for_circuit_set(&row_chunk_agg_circuit), + prepare_recursive_circuit_for_circuit_set(&aggregation_circuit), + prepare_recursive_circuit_for_circuit_set(&non_existence_circuit), ]; - let circuit_set = RecursiveCircuits::new(circuits); + let universal_circuit = UniversalQueryCircuitParams::build(default_config()); + Self { - circuit_with_agg, - circuit_no_agg, + row_chunk_agg_circuit, + aggregation_circuit, + non_existence_circuit, + universal_circuit, circuit_set, - full_node_two_children, - full_node_one_child, - full_node_leaf, - partial_node, - single_path_proven_child, - single_path_embedded_tree, - non_existence_intermediate, } } - pub fn generate_proof( + pub(crate) fn generate_proof( &self, input: CircuitInput< + NUM_CHUNKS, + NUM_ROWS, + ROW_TREE_MAX_DEPTH, + INDEX_TREE_MAX_DEPTH, MAX_NUM_COLUMNS, MAX_NUM_PREDICATE_OPS, MAX_NUM_RESULT_OPS, MAX_NUM_RESULTS, >, ) -> Result> { - let proof = ProofWithVK::from(match input { - CircuitInput::UniversalCircuit(input) => match input { - UniversalCircuitInput::QueryWithAgg(input) => ( - self.circuit_set - .generate_proof(&self.circuit_with_agg, [], [], input)?, - self.circuit_with_agg.circuit_data().verifier_only.clone(), - ), - UniversalCircuitInput::QueryNoAgg(input) => ( - self.circuit_set - .generate_proof(&self.circuit_no_agg, [], [], input)?, - self.circuit_no_agg.circuit_data().verifier_only.clone(), - ), - }, - CircuitInput::TwoProvenChildNode(TwoProvenChildNodeInput { - left_child_proof, - right_child_proof, - embedded_tree_proof, - common, - }) => { - let (left_proof, left_vk) = left_child_proof.into(); - let (right_proof, right_vk) = right_child_proof.into(); - let (embedded_proof, embedded_vk) = embedded_tree_proof.into(); - let input = FullNodeWithTwoChildrenCircuit { - is_rows_tree_node: common.is_rows_tree_node, - min_query: common.min_query, - max_query: common.max_query, - }; - ( + match input { + CircuitInput::RowChunkWithAggregation(row_chunk_processing_circuit) => + ProofWithVK::serialize( + &( + self.circuit_set.generate_proof( + &self.row_chunk_agg_circuit, + [], + [], + row_chunk_processing_circuit, + )?, + self.row_chunk_agg_circuit + .circuit_data() + .verifier_only + .clone(), + ).into()), + CircuitInput::ChunkAggregation(chunk_aggregation_inputs) => { + let ChunkAggregationInputs { + chunk_proofs, + circuit, + } = chunk_aggregation_inputs; + let input_vd = chunk_proofs + .iter() + .map(|p| p.verifier_data()) + .cloned() + .collect_vec(); + let input_proofs = chunk_proofs.map(|p| p.proof); + ProofWithVK::serialize( + &( self.circuit_set.generate_proof( - &self.full_node_two_children, - [embedded_proof, left_proof, right_proof], - [&embedded_vk, &left_vk, &right_vk], - input, + &self.aggregation_circuit, + input_proofs, + input_vd.iter().collect_vec().try_into().unwrap(), + circuit, )?, - 
self.full_node_two_children + self.aggregation_circuit .circuit_data() .verifier_only .clone(), ) + .into()) } - CircuitInput::OneProvenChildNode(OneProvenChildNodeInput { - unproven_child, - proven_child_proof, - embedded_tree_proof, - common, - }) => { - let ChildProof { - proof, - child_position, - } = proven_child_proof; - let (child_proof, child_vk) = proof.into(); - let (embedded_proof, embedded_vk) = embedded_tree_proof.into(); - match unproven_child { - Some(child_node) => { - // the node has 2 children, so we use the partial node circuit - let input = PartialNodeCircuit { - is_rows_tree_node: common.is_rows_tree_node, - is_left_child: child_position.to_flag(), - sibling_tree_hash: child_node.embedded_tree_hash, - sibling_child_hashes: child_node.child_hashes, - sibling_value: child_node.value, - sibling_min: child_node.min, - sibling_max: child_node.max, - min_query: common.min_query, - max_query: common.max_query, - }; - ( - self.circuit_set.generate_proof( - &self.partial_node, - [embedded_proof, child_proof], - [&embedded_vk, &child_vk], - input, - )?, - self.partial_node.get_verifier_data().clone(), - ) - } - None => { - // the node has 1 child, so use the circuit for full node with 1 child - let input = FullNodeWithOneChildCircuit { - is_rows_tree_node: common.is_rows_tree_node, - is_left_child: child_position.to_flag(), - min_query: common.min_query, - max_query: common.max_query, - }; - ( - self.circuit_set.generate_proof( - &self.full_node_one_child, - [embedded_proof, child_proof], - [&embedded_vk, &child_vk], - input, - )?, - self.full_node_one_child.get_verifier_data().clone(), - ) - } - } - } - CircuitInput::SinglePath(SinglePathInput { - left_child, - right_child, - node_info, - subtree_proof, - common, - }) => { - let left_child_exists = left_child.is_some(); - let right_child_exists = right_child.is_some(); - let left_child_data = left_child.unwrap_or_default(); - let right_child_data = right_child.unwrap_or_default(); - - match subtree_proof { - SubProof::Embedded(input_proof) => { - let (proof, vk) = input_proof.into(); - if !(left_child_exists || right_child_exists) { - // leaf node, so call full node circuit for leaf node - ensure!(!common.is_rows_tree_node, "providing single-path input for a rows tree node leaf, call universal circuit instead"); - let input = FullNodeIndexLeafCircuit { - min_query: common.min_query, - max_query: common.max_query, - }; - ( - self.circuit_set.generate_proof( - &self.full_node_leaf, - [proof], - [&vk], - input, - )?, - self.full_node_leaf.get_verifier_data().clone(), - ) - } else { - // the input proof refers to the embedded tree stored in the node - let input = EmbeddedTreeProvenSinglePathNodeCircuit { - left_child_min: left_child_data.min, - left_child_max: left_child_data.max, - left_child_value: left_child_data.value, - left_tree_hash: left_child_data.embedded_tree_hash, - left_grand_children: left_child_data.child_hashes, - right_child_min: right_child_data.min, - right_child_max: right_child_data.max, - right_child_value: right_child_data.value, - right_tree_hash: right_child_data.embedded_tree_hash, - right_grand_children: right_child_data.child_hashes, - left_child_exists, - right_child_exists, - is_rows_tree_node: common.is_rows_tree_node, - min_query: common.min_query, - max_query: common.max_query, - }; - ( - self.circuit_set.generate_proof( - &self.single_path_embedded_tree, - [proof], - [&vk], - input, - )?, - self.single_path_embedded_tree.get_verifier_data().clone(), - ) - } - } - SubProof::Child(ChildProof { - 
proof, - child_position, - }) => { - // the input proof refers to a child of the node - let (proof, vk) = proof.into(); - let is_left_child = child_position.to_flag(); - let input = ChildProvenSinglePathNodeCircuit { - value: node_info.value, - subtree_hash: node_info.embedded_tree_hash, - sibling_hash: if is_left_child { - node_info.child_hashes[1] // set the hash of the right child, since proven child is left - } else { - node_info.child_hashes[0] // set the hash of the left child, since proven child is right - }, - is_left_child, - unproven_min: node_info.min, - unproven_max: node_info.max, - is_rows_tree_node: common.is_rows_tree_node, - }; - ( - self.circuit_set.generate_proof( - &self.single_path_proven_child, - [proof], - [&vk], - input, - )?, - self.single_path_proven_child.get_verifier_data().clone(), - ) - } + CircuitInput::NonExistence(non_existence_circuit) => + ProofWithVK::serialize( + &( + self.circuit_set.generate_proof( + &self.non_existence_circuit, + [], + [], + non_existence_circuit, + )?, + self.non_existence_circuit + .circuit_data() + .verifier_only + .clone(), + ) + .into()), + CircuitInput::UniversalCircuit(universal_circuit_input) => + if let UniversalCircuitInput::QueryNoAgg(input) = universal_circuit_input { + serialize_proof(&self.universal_circuit.generate_proof(&input)?) + } else { + unreachable!("Universal circuit should only be used for queries with no aggregation operations") } - } - CircuitInput::NonExistence(NonExistenceInput { - node_info, - left_child_info, - right_child_info, - primary_index_value, - index_ids, - computational_hash, - placeholder_hash, - aggregation_ops, - is_rows_tree_node, - min_query, - max_query, - }) => { - // intermediate node - let left_child_exists = left_child_info.is_some(); - let right_child_exists = right_child_info.is_some(); - let left_child_data = left_child_info.unwrap_or_default(); - let right_child_data = right_child_info.unwrap_or_default(); - let input = NonExistenceInterNodeCircuit { - is_rows_tree_node, - left_child_exists, - right_child_exists, - min_query, - max_query, - value: node_info.value, - index_value: primary_index_value, - left_child_value: left_child_data.value, - left_child_min: left_child_data.min, - left_child_max: left_child_data.max, - right_child_value: right_child_data.value, - right_child_min: right_child_data.min, - right_child_max: right_child_data.max, - index_ids, - ops: aggregation_ops, - subtree_hash: node_info.embedded_tree_hash, - computational_hash, - placeholder_hash, - left_tree_hash: left_child_data.embedded_tree_hash, - left_grand_children: left_child_data.child_hashes, - right_tree_hash: right_child_data.embedded_tree_hash, - right_grand_children: right_child_data.child_hashes, - }; - ( - self.circuit_set.generate_proof( - &self.non_existence_intermediate, - [], - [], - input, - )?, - self.non_existence_intermediate.get_verifier_data().clone(), - ) - } - }); - - proof.serialize() + , + } } pub(crate) fn get_circuit_set(&self) -> &RecursiveCircuits { &self.circuit_set } -} - -#[cfg(test)] -mod tests { - use std::cmp::Ordering; - - use alloy::primitives::U256; - use itertools::Itertools; - use mp2_common::{proof::ProofWithVK, types::HashOutput, utils::Fieldable, F}; - use mp2_test::utils::{gen_random_field_hash, gen_random_u256}; - use plonky2::{ - field::types::{PrimeField64, Sample}, - plonk::config::GenericHashOut, - }; - use rand::{thread_rng, Rng}; - - use crate::query::{ - aggregation::{ - ChildPosition, NodeInfo, QueryBoundSource, QueryBounds, QueryHashNonExistenceCircuits, 
- SubProof, - }, - api::{CircuitInput, Parameters}, - computational_hash_ids::{ - AggregationOperation, ColumnIDs, Operation, PlaceholderIdentifier, - }, - public_inputs::PublicInputs, - universal_circuit::universal_circuit_inputs::{ - BasicOperation, ColumnCell, InputOperand, OutputItem, Placeholders, ResultStructure, - RowCells, - }, - }; - - #[test] - fn test_api() { - // Simple query for testing SELECT SUM(C1 + C3) FROM T WHERE C3 >= 5 AND C1 > 56 AND C1 <= 67 AND C2 > 34 AND C2 <= $1 - let rng = &mut thread_rng(); - const NUM_COLUMNS: usize = 3; - const MAX_NUM_COLUMNS: usize = 20; - const MAX_NUM_PREDICATE_OPS: usize = 20; - const MAX_NUM_RESULT_OPS: usize = 20; - const MAX_NUM_RESULTS: usize = 10; - - let column_ids = ColumnIDs::new( - F::rand().to_canonical_u64(), - F::rand().to_canonical_u64(), - (0..NUM_COLUMNS - 2) - .map(|_| F::rand().to_canonical_u64()) - .collect_vec(), - ); - - let primary_index_id: F = column_ids.primary; - let secondary_index_id: F = column_ids.secondary; - - let min_query_primary = 57; - let max_query_primary = 67; - let min_query_secondary = 35; - let max_query_secondary = 78; - // define Enum to specify whether to generate index values in range or not - enum IndexValueBounds { - InRange, // generate index value within query bounds - Smaller, // generate index value smaller than minimum query bound - Bigger, // generate inde value bigger than maximum query bound - } - // generate a new row with `NUM_COLUMNS` where value of secondary index is within the query bounds - let mut gen_row = |primary_index: usize, secondary_index: IndexValueBounds| { - (0..NUM_COLUMNS) - .map(|i| match i { - 0 => U256::from(primary_index), - 1 => match secondary_index { - IndexValueBounds::InRange => { - U256::from(rng.gen_range(min_query_secondary..max_query_secondary)) - } - IndexValueBounds::Smaller => { - U256::from(rng.gen_range(0..min_query_secondary)) - } - IndexValueBounds::Bigger => { - U256::from(rng.gen_range(0..min_query_secondary)) - } - }, - _ => gen_random_u256(rng), - }) - .collect_vec() - }; - - let predicate_operations = vec![BasicOperation { - first_operand: InputOperand::Column(2), - second_operand: Some(InputOperand::Constant(U256::from(5))), - op: Operation::GreaterThanOrEqOp, - }]; - let result_operations = vec![BasicOperation { - first_operand: InputOperand::Column(0), - second_operand: Some(InputOperand::Column(2)), - op: Operation::AddOp, - }]; - let aggregation_op_ids = vec![AggregationOperation::SumOp.to_id()]; - let output_items = vec![OutputItem::ComputedValue(0)]; - let results = ResultStructure::new_for_query_with_aggregation( - result_operations, - output_items, - aggregation_op_ids.clone(), - ) - .unwrap(); - let first_placeholder_id = PlaceholderIdentifier::Generic(0); - let placeholders = Placeholders::from(( - vec![(first_placeholder_id, U256::from(max_query_secondary))], - U256::from(min_query_primary), - U256::from(max_query_primary), - )); - let query_bounds = QueryBounds::new( - &placeholders, - Some(QueryBoundSource::Constant(U256::from(min_query_secondary))), - Some(QueryBoundSource::Placeholder(first_placeholder_id)), - ) - .unwrap(); - - let mut params = Parameters::< - MAX_NUM_COLUMNS, - MAX_NUM_PREDICATE_OPS, - MAX_NUM_RESULT_OPS, - MAX_NUM_RESULTS, - >::build(); - - // Test serialization of params - let serialized_params = bincode::serialize(¶ms).unwrap(); - // use deserialized params to generate proofs - params = bincode::deserialize(&serialized_params).unwrap(); - - type Input = CircuitInput< - MAX_NUM_COLUMNS, - 
MAX_NUM_PREDICATE_OPS, - MAX_NUM_RESULT_OPS, - MAX_NUM_RESULTS, - >; - - // test an index tree with all proven nodes: we assume to have index tree built as follows - // (node identified according to their sorting order): - // 4 - // 0 - // 2 - // 1 3 - - // build a vector of 5 rows with values of index columns within the query bounds. The entries in the - // vector are sorted according to primary index value - let column_values = (min_query_primary..max_query_primary) - .step_by((max_query_primary - min_query_primary) / 5) - .take(5) - .map(|index| gen_row(index, IndexValueBounds::InRange)) - .collect_vec(); - - // generate proof with universal for a row with the `values` provided as input. - // The flag `is_leaf` specifies whether the row is stored in a leaf node of a rows tree - // or not - let gen_universal_circuit_proofs = |values: &[U256], is_leaf: bool| { - let column_cells = values - .iter() - .zip(column_ids.to_vec().iter()) - .map(|(&value, &id)| ColumnCell::new(id.to_canonical_u64(), value)) - .collect_vec(); - let row_cells = RowCells::new( - column_cells[0].clone(), - column_cells[1].clone(), - column_cells[2..].to_vec(), - ); - let input = Input::new_universal_circuit( - &row_cells, - &predicate_operations, - &results, - &placeholders, - is_leaf, - &query_bounds, - ) - .unwrap(); - params.generate_proof(input).unwrap() - }; - - // generate base proofs with universal circuits for each node - let base_proofs = column_values - .iter() - .map(|values| gen_universal_circuit_proofs(values, true)) - .collect_vec(); - - // closure to extract the tree hash from a proof - let get_tree_hash_from_proof = |proof: &[u8]| { - let (proof, _) = ProofWithVK::deserialize(proof).unwrap().into(); - let pis = PublicInputs::::from_slice(&proof.public_inputs); - pis.tree_hash() - }; - - // closure to generate the proof for a leaf node of the index tree, corresponding to the node_index-th row - let gen_leaf_proof_for_node = |node_index: usize| { - let embedded_tree_hash = get_tree_hash_from_proof(&base_proofs[node_index]); - let node_info = NodeInfo::new( - &HashOutput::try_from(embedded_tree_hash.to_bytes()).unwrap(), - None, - None, - column_values[node_index][0], // primary index value for this row - column_values[node_index][0], - column_values[node_index][0], - ); - let tree_hash = node_info.compute_node_hash(primary_index_id); - let subtree_proof = - SubProof::new_embedded_tree_proof(base_proofs[node_index].clone()).unwrap(); - let input = Input::new_single_path( - subtree_proof, - None, - None, - node_info, - false, // index tree node - &query_bounds, - ) - .unwrap(); - let proof = params.generate_proof(input).unwrap(); - // check tree hash is correct - assert_eq!(tree_hash, get_tree_hash_from_proof(&proof)); - proof - }; - - // generate proof for node 1 of index tree above - let leaf_proof_left = gen_leaf_proof_for_node(1); - - // generate proof for node 3 of index tree above - let leaf_proof_right = gen_leaf_proof_for_node(3); - - // generate proof for node 2 of index tree above - let left_child_hash = get_tree_hash_from_proof(&leaf_proof_left); - let right_child_hash = get_tree_hash_from_proof(&leaf_proof_right); - let input = Input::new_full_node( - leaf_proof_left, - leaf_proof_right, - base_proofs[2].clone(), - false, - &query_bounds, - ) - .unwrap(); - let full_node_proof = params.generate_proof(input).unwrap(); - - // verify hash is correct - let full_node_info = NodeInfo::new( - &HashOutput::try_from(get_tree_hash_from_proof(&base_proofs[2]).to_bytes()).unwrap(), - 
Some(&HashOutput::try_from(left_child_hash.to_bytes()).unwrap()), - Some(&HashOutput::try_from(right_child_hash.to_bytes()).unwrap()), - column_values[2][0], // primary index value for that row - column_values[1][0], // primary index value for the min node in the left subtree - column_values[3][0], // primary index value for the max node in the right subtree - ); - let full_node_hash = get_tree_hash_from_proof(&full_node_proof); - assert_eq!( - full_node_hash, - full_node_info.compute_node_hash(primary_index_id), - ); - - // generate proof for node 0 of the index tree above - let input = Input::new_partial_node( - full_node_proof, - base_proofs[0].clone(), - None, // there is no left child - ChildPosition::Right, // proven child is the right child of node 0 - false, - &query_bounds, - ) - .unwrap(); - let one_child_node_proof = params.generate_proof(input).unwrap(); - // verify hash is correct - let one_child_node_info = NodeInfo::new( - &HashOutput::try_from(get_tree_hash_from_proof(&base_proofs[0]).to_bytes()).unwrap(), - None, - Some(&HashOutput::try_from(full_node_hash.to_bytes()).unwrap()), - column_values[0][0], - column_values[0][0], - column_values[3][0], - ); - let one_child_node_hash = get_tree_hash_from_proof(&one_child_node_proof); - assert_eq!( - one_child_node_hash, - one_child_node_info.compute_node_hash(primary_index_id) - ); - - // generate proof for root node - let input = Input::new_partial_node( - one_child_node_proof, - base_proofs[4].clone(), - None, // there is no right child - ChildPosition::Left, // proven child is the left child of root node - false, - &query_bounds, - ) - .unwrap(); - let (root_proof, _) = ProofWithVK::deserialize(¶ms.generate_proof(input).unwrap()) - .unwrap() - .into(); - // check some public inputs for root proof - let check_pis = |root_proof_pis: &[F], node_info: NodeInfo, column_values: &[Vec]| { - let pis = PublicInputs::::from_slice(root_proof_pis); - assert_eq!( - pis.tree_hash(), - node_info.compute_node_hash(primary_index_id), - ); - assert_eq!(pis.min_value(), node_info.min,); - assert_eq!(pis.max_value(), node_info.max,); - assert_eq!(pis.min_query_value(), query_bounds.min_query_primary()); - assert_eq!(pis.max_query_value(), query_bounds.max_query_primary()); - assert_eq!( - pis.index_ids().to_vec(), - vec![column_ids.primary, column_ids.secondary,], - ); - // compute output value: SUM(C1 + C3) for all the rows where C3 >= 5 - let (output, overflow, count) = - column_values - .iter() - .fold((U256::ZERO, false, 0u64), |acc, value| { - if value[2] >= U256::from(5) - && value[0] >= query_bounds.min_query_primary() - && value[0] <= query_bounds.max_query_primary() - && value[1] >= query_bounds.min_query_secondary().value - && value[1] <= query_bounds.max_query_secondary().value - { - let (sum, overflow) = value[0].overflowing_add(value[2]); - let new_overflow = acc.1 || overflow; - let (new_sum, overflow) = sum.overflowing_add(acc.0); - (new_sum, new_overflow || overflow, acc.2 + 1) - } else { - acc - } - }); - assert_eq!(pis.first_value_as_u256(), output,); - assert_eq!(pis.overflow_flag(), overflow,); - assert_eq!(pis.num_matching_rows(), count.to_field(),); - }; - - let root_node_info = NodeInfo::new( - &HashOutput::try_from(get_tree_hash_from_proof(&base_proofs[4]).to_bytes()).unwrap(), - Some(&HashOutput::try_from(one_child_node_hash.to_bytes()).unwrap()), - None, - column_values[4][0], - column_values[0][0], - column_values[4][0], - ); - - check_pis(&root_proof.public_inputs, root_node_info, &column_values); - - // build an index 
tree with a mix of proven and unproven nodes. The tree is built as follows: - // 0 - // 8 - // 3 9 - // 2 5 - // 1 4 6 - // 7 - // nodes 3,4,5,6 are in the range specified by the query for the primary index, while the other nodes - // are not - let column_values = [0, min_query_primary / 3, min_query_primary * 2 / 3] - .into_iter() // primary index values for nodes 0,1,2 - .chain( - (min_query_primary..max_query_primary) - .step_by((max_query_primary - min_query_primary) / 4) - .take(4), - ) // primary index values for nodes in the range - .chain([ - max_query_primary * 2, - max_query_primary * 3, - max_query_primary * 4, - ]) // primary index values for nodes 7,8, 9 - .map(|index| gen_row(index, IndexValueBounds::InRange)) - .collect_vec(); - - // generate base proofs with universal circuits for each node in the range - const START_NODE_IN_RANGE: usize = 3; - const LAST_NODE_IN_RANGE: usize = 6; - let base_proofs = column_values[START_NODE_IN_RANGE..=LAST_NODE_IN_RANGE] - .iter() - .map(|values| gen_universal_circuit_proofs(values, true)) - .collect_vec(); - - // generate proof for node 4 - let embedded_tree_hash = get_tree_hash_from_proof(&base_proofs[4 - START_NODE_IN_RANGE]); - let node_info = NodeInfo::new( - &HashOutput::try_from(embedded_tree_hash.to_bytes()).unwrap(), - None, - None, - column_values[4][0], - column_values[4][0], - column_values[4][0], - ); - let subtree_proof = - SubProof::new_embedded_tree_proof(base_proofs[4 - START_NODE_IN_RANGE].clone()) - .unwrap(); - let hash_4 = node_info.compute_node_hash(primary_index_id); - let input = - Input::new_single_path(subtree_proof, None, None, node_info, false, &query_bounds) - .unwrap(); - let proof_4 = params.generate_proof(input).unwrap(); - // check hash - assert_eq!(hash_4, get_tree_hash_from_proof(&proof_4),); - - // generate proof for node 6 - // compute node data for node 7, which is needed as input to generate the proof - let node_info_7 = NodeInfo::new( - // for the sake of this test, we can use random hash for the embedded tree stored in node 7, since it's not proven; - // in a non-test scenario, we would need to get the actual embedded hash of the node, otherwise the root hash of the - // tree computed in the proofs will be incorrect - &HashOutput::try_from(gen_random_field_hash::().to_bytes()).unwrap(), - None, - None, - column_values[7][0], - column_values[7][0], - column_values[7][0], - ); - let hash_7 = node_info_7.compute_node_hash(primary_index_id); - let embedded_tree_hash = get_tree_hash_from_proof(&base_proofs[6 - START_NODE_IN_RANGE]); - let node_info_6 = NodeInfo::new( - &HashOutput::try_from(embedded_tree_hash.to_bytes()).unwrap(), - None, - Some(&HashOutput::try_from(hash_7.to_bytes()).unwrap()), - column_values[6][0], - column_values[6][0], - column_values[7][0], - ); - let subtree_proof = - SubProof::new_embedded_tree_proof(base_proofs[6 - START_NODE_IN_RANGE].clone()) - .unwrap(); - let hash_6 = node_info_6.compute_node_hash(primary_index_id); - let input = Input::new_single_path( - subtree_proof, - None, - Some(node_info_7), - node_info_6, - false, - &query_bounds, - ) - .unwrap(); - let proof_6 = params.generate_proof(input).unwrap(); - // check hash - assert_eq!(hash_6, get_tree_hash_from_proof(&proof_6)); - - // generate proof for node 5 - let input = Input::new_full_node( - proof_4, - proof_6, - base_proofs[5 - START_NODE_IN_RANGE].clone(), - false, - &query_bounds, - ) - .unwrap(); - let proof_5 = params.generate_proof(input).unwrap(); - // check hash - let embedded_tree_hash = 
get_tree_hash_from_proof(&base_proofs[5 - START_NODE_IN_RANGE]); - let node_info_5 = NodeInfo::new( - &HashOutput::try_from(embedded_tree_hash.to_bytes()).unwrap(), - Some(&HashOutput::try_from(hash_4.to_bytes()).unwrap()), - Some(&HashOutput::try_from(hash_6.to_bytes()).unwrap()), - column_values[5][0], - column_values[4][0], - column_values[7][0], - ); - let hash_5 = node_info_5.compute_node_hash(primary_index_id); - assert_eq!(hash_5, get_tree_hash_from_proof(&proof_5),); - - // generate proof for node 3 - // compute node data for node 2, which is needed as input to generate the proof - let node_info_2 = NodeInfo::new( - // same as for node_info_7, we can use random hashes for the sake of this test - &HashOutput::try_from(gen_random_field_hash::().to_bytes()).unwrap(), - Some(&HashOutput::try_from(gen_random_field_hash::().to_bytes()).unwrap()), - None, - column_values[2][0], - column_values[1][0], - column_values[2][0], - ); - let hash_2 = node_info_2.compute_node_hash(primary_index_id); - let input = Input::new_partial_node( - proof_5, - base_proofs[3 - START_NODE_IN_RANGE].clone(), - Some(node_info_2), - ChildPosition::Right, // proven child is right child - false, - &query_bounds, - ) - .unwrap(); - let proof_3 = params.generate_proof(input).unwrap(); - // check hash - let embedded_tree_hash = get_tree_hash_from_proof(&base_proofs[3 - START_NODE_IN_RANGE]); - let node_info_3 = NodeInfo::new( - &HashOutput::try_from(embedded_tree_hash.to_bytes()).unwrap(), - Some(&HashOutput::try_from(hash_2.to_bytes()).unwrap()), - Some(&HashOutput::try_from(hash_5.to_bytes()).unwrap()), - column_values[3][0], - column_values[1][0], - column_values[7][0], - ); - let hash_3 = node_info_3.compute_node_hash(primary_index_id); - assert_eq!(hash_3, get_tree_hash_from_proof(&proof_3),); - - // generate proof for node 8 - // compute node_info_9, which is needed as input for the proof - let node_info_9 = NodeInfo::new( - // same as for node_info_2, we can use random hashes for the sake of this test - &HashOutput::try_from(gen_random_field_hash::().to_bytes()).unwrap(), - None, - None, - column_values[9][0], - column_values[9][0], - column_values[9][0], - ); - let hash_9 = node_info_9.compute_node_hash(primary_index_id); - let node_info_8 = NodeInfo::new( - // same as for node_info_2, we can use random hashes for the sake of this test - &HashOutput::try_from(gen_random_field_hash::().to_bytes()).unwrap(), - Some(&HashOutput::try_from(hash_3.to_bytes()).unwrap()), - Some(&HashOutput::try_from(hash_9.to_bytes()).unwrap()), - column_values[8][0], - column_values[1][0], - column_values[9][0], - ); - let hash_8 = node_info_8.compute_node_hash(primary_index_id); - let subtree_proof = SubProof::new_child_proof( - proof_3, - ChildPosition::Left, // subtree proof refers to the left child of the node - ) - .unwrap(); - let input = Input::new_single_path( - subtree_proof, - Some(node_info_3), - Some(node_info_9), - node_info_8, - false, - &query_bounds, - ) - .unwrap(); - let proof_8 = params.generate_proof(input).unwrap(); - // check hash - assert_eq!(get_tree_hash_from_proof(&proof_8), hash_8); - println!("generate proof for node 0"); - - // generate proof for node 0 (root) - let node_info_0 = NodeInfo::new( - // same as for node_info_1, we can use random hashes for the sake of this test - &HashOutput::try_from(gen_random_field_hash::().to_bytes()).unwrap(), - None, - Some(&HashOutput::try_from(hash_8.to_bytes()).unwrap()), - column_values[0][0], - column_values[0][0], - column_values[9][0], - ); - let subtree_proof 
= SubProof::new_child_proof( - proof_8, - ChildPosition::Right, // subtree proof refers to the right child of the node - ) - .unwrap(); - let input = Input::new_single_path( - subtree_proof, - None, - Some(node_info_8), - node_info_0, - false, - &query_bounds, - ) - .unwrap(); - let (root_proof, _) = ProofWithVK::deserialize(¶ms.generate_proof(input).unwrap()) - .unwrap() - .into(); - - // check some public inputs - check_pis(&root_proof.public_inputs, node_info_0, &column_values); - - // build an index tree with all nodes outside of the primary index range. The tree is built as follows: - // 2 - // 1 3 - // 0 - // where nodes 0 stores an index value smaller than `min_query_primary`, while nodes 1, 2, 3 store index values - // bigger than `max_query_primary` - let column_values = [min_query_primary / 2] - .into_iter() - .chain([ - max_query_primary * 2, - max_query_primary * 3, - max_query_primary * 4, - ]) - .map(|index| gen_row(index, IndexValueBounds::InRange)) - .collect_vec(); - - // generate proof for node 0 with non-existence circuit, since it is outside of the query bounds - let node_info_0 = NodeInfo::new( - // we can use a randomly generated hash for the subtree, for the sake of the test - &HashOutput::try_from(gen_random_field_hash::().to_bytes()).unwrap(), - None, - None, - column_values[0][0], - column_values[0][0], - column_values[0][0], - ); - let hash_0 = node_info_0.compute_node_hash(primary_index_id); - - // compute hashes associated to query, which are needed as inputs - let query_hashes = QueryHashNonExistenceCircuits::new::< - MAX_NUM_COLUMNS, - MAX_NUM_PREDICATE_OPS, - MAX_NUM_RESULT_OPS, - MAX_NUM_RESULTS, - >( - &column_ids, - &predicate_operations, - &results, - &placeholders, - &query_bounds, - false, - ) - .unwrap(); - let input = Input::new_non_existence_input( - node_info_0, - None, - None, - node_info_0.value, - &[ - column_ids.primary.to_canonical_u64(), - column_ids.secondary.to_canonical_u64(), - ], - &[AggregationOperation::SumOp], - query_hashes, - false, - &query_bounds, - &placeholders, - ) - .unwrap(); - let proof_0 = params.generate_proof(input).unwrap(); - // check hash - assert_eq!(hash_0, get_tree_hash_from_proof(&proof_0),); - - // get up to the root of the tree with proofs - // generate proof for node 1 - let node_info_1 = NodeInfo::new( - // we can use a random hash for the embedded tree - &HashOutput::try_from(gen_random_field_hash::().to_bytes()).unwrap(), - Some(&HashOutput::try_from(hash_0.to_bytes()).unwrap()), - None, - column_values[1][0], - column_values[0][0], - column_values[1][0], - ); - let hash_1 = node_info_1.compute_node_hash(primary_index_id); - let subtree_proof = SubProof::new_child_proof(proof_0, ChildPosition::Left).unwrap(); - let input = Input::new_single_path( - subtree_proof, - Some(node_info_0), - None, - node_info_1, - false, - &query_bounds, - ) - .unwrap(); - let proof_1 = params.generate_proof(input).unwrap(); - // check hash - assert_eq!(hash_1, get_tree_hash_from_proof(&proof_1),); - - // generate proof for root node - let node_info_2 = NodeInfo::new( - // we can use a random hash for the embedded tree - &HashOutput::try_from(gen_random_field_hash::().to_bytes()).unwrap(), - Some(&HashOutput::try_from(hash_1.to_bytes()).unwrap()), - None, - column_values[2][0], - column_values[0][0], - column_values[2][0], - ); - let node_info_3 = NodeInfo::new( - // we can use a random hash for the embedded tree - &HashOutput::try_from(gen_random_field_hash::().to_bytes()).unwrap(), - None, - None, - column_values[3][0], - 
            column_values[3][0],
-            column_values[3][0],
-        );
-        let subtree_proof = SubProof::new_child_proof(proof_1, ChildPosition::Left).unwrap();
-        let input = Input::new_single_path(
-            subtree_proof,
-            Some(node_info_1),
-            Some(node_info_3),
-            node_info_2,
-            false,
-            &query_bounds,
-        )
-        .unwrap();
-        let (root_proof, _) = ProofWithVK::deserialize(&params.generate_proof(input).unwrap())
-            .unwrap()
-            .into();
-
-        check_pis(&root_proof.public_inputs, node_info_2, &column_values);
-
-        // generate non-existence proof starting from intermediate node (i.e., node 1) rather than a leaf node
-        // generate proof with non-existence circuit for node 1
-
-        // compute hashes associated to query, which are needed as inputs
-        let query_hashes = QueryHashNonExistenceCircuits::new::<
-            MAX_NUM_COLUMNS,
-            MAX_NUM_PREDICATE_OPS,
-            MAX_NUM_RESULT_OPS,
-            MAX_NUM_RESULTS,
-        >(
-            &column_ids,
-            &predicate_operations,
-            &results,
-            &placeholders,
-            &query_bounds,
-            false,
-        )
-        .unwrap();
-        let input = Input::new_non_existence_input(
-            node_info_1,
-            Some(node_info_0), // node 0 is the left child
-            None,
-            node_info_1.value,
-            &[
-                column_ids.primary.to_canonical_u64(),
-                column_ids.secondary.to_canonical_u64(),
-            ],
-            &[AggregationOperation::SumOp],
-            query_hashes,
-            false,
-            &query_bounds,
-            &placeholders,
-        )
-        .unwrap();
-        let proof_1 = params.generate_proof(input).unwrap();
-        // check hash
-        assert_eq!(hash_1, get_tree_hash_from_proof(&proof_1),);
-
-        // generate proof for root node
-        let subtree_proof = SubProof::new_child_proof(proof_1, ChildPosition::Left).unwrap();
-        let input = Input::new_single_path(
-            subtree_proof,
-            Some(node_info_1),
-            Some(node_info_3),
-            node_info_2,
-            false,
-            &query_bounds,
-        )
-        .unwrap();
-        let (root_proof, _) = ProofWithVK::deserialize(&params.generate_proof(input).unwrap())
-            .unwrap()
-            .into();
-
-        check_pis(&root_proof.public_inputs, node_info_2, &column_values);
-
-        // generate a tree with rows tree with more than one node. We generate an index tree with 2 nodes A and B,
-        // both storing a primary index value within the query bounds.
-        // Node A stores a rows tree with all entries outside of query bounds for secondary index, while
-        // node B stores a rows tree with all entries within query bounds for secondary index.
-        // The tree is structured as follows:
-        //              B
-        //              4
-        //            3   5
-        //      A
-        //      1
-        //    0   2
-        let mut column_values = vec![
-            gen_row(min_query_primary, IndexValueBounds::Smaller),
-            gen_row(min_query_primary, IndexValueBounds::Smaller),
-            gen_row(min_query_primary, IndexValueBounds::Bigger),
-            gen_row(max_query_primary, IndexValueBounds::InRange),
-            gen_row(max_query_primary, IndexValueBounds::InRange),
-            gen_row(max_query_primary, IndexValueBounds::InRange),
-        ];
-        // sort column values according to primary/secondary index values
-        column_values.sort_by(|a, b| match a[0].cmp(&b[0]) {
-            Ordering::Less => Ordering::Less,
-            Ordering::Greater => Ordering::Greater,
-            Ordering::Equal => a[1].cmp(&b[1]),
-        });
-
-        // generate proof for node A rows tree
-        // generate non-existence proof for node 2, which is the smallest node higher than the maximum query bound, since
-        // node 1, which is the highest node smaller than the minimum query bound, has 2 children
-        // (see non-existence circuit docs to see why we don't generate non-existence proofs for nodes with 2 children)
-        let node_info_2 = NodeInfo::new(
-            // we can use a random hash for the embedded tree
-            &HashOutput::try_from(gen_random_field_hash::<F>().to_bytes()).unwrap(),
-            None,
-            None,
-            column_values[2][1],
-            column_values[2][1],
-            column_values[2][1],
-        );
-        let hash_2 = node_info_2.compute_node_hash(secondary_index_id);
-
-        // compute hashes associated to query, which are needed as inputs
-        let query_hashes = QueryHashNonExistenceCircuits::new::<
-            MAX_NUM_COLUMNS,
-            MAX_NUM_PREDICATE_OPS,
-            MAX_NUM_RESULT_OPS,
-            MAX_NUM_RESULTS,
-        >(
-            &column_ids,
-            &predicate_operations,
-            &results,
-            &placeholders,
-            &query_bounds,
-            true,
-        )
-        .unwrap();
-        let input = Input::new_non_existence_input(
-            node_info_2,
-            None,
-            None,
-            column_values[2][0], // we need to place the primary index value associated to this row
-            &[
-                column_ids.primary.to_canonical_u64(),
-                column_ids.secondary.to_canonical_u64(),
-            ],
-            &[AggregationOperation::SumOp],
-            query_hashes,
-            true,
-            &query_bounds,
-            &placeholders,
-        )
-        .unwrap();
-        let proof_2 = params.generate_proof(input).unwrap();
-        // check hash
-        assert_eq!(hash_2, get_tree_hash_from_proof(&proof_2),);
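// [Editor's sketch] Out-of-circuit illustration of the node-selection rule the
// comment above describes, assuming plain u64 secondary-index values and a
// hypothetical `has_two_children` flag; the real test works on `NodeInfo` of
// the rows tree instead. This is illustrative only, not part of the patch.
fn pick_non_existence_node(
    nodes: &[(u64, bool)], // (secondary index value, has_two_children), sorted by value
    min_secondary: u64,
    max_secondary: u64,
) -> Option<usize> {
    // highest node strictly below the query range
    let below = nodes.iter().rposition(|&(v, _)| v < min_secondary);
    // smallest node strictly above the query range
    let above = nodes.iter().position(|&(v, _)| v > max_secondary);
    // prefer whichever candidate has fewer than 2 children, since the
    // non-existence circuit cannot be run on a node with 2 children
    match (below, above) {
        (Some(b), _) if !nodes[b].1 => Some(b),
        (_, Some(a)) if !nodes[a].1 => Some(a),
        _ => None,
    }
}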
-
-        // generate proof for node 1 (root of rows tree for node A)
-        let node_info_1 = NodeInfo::new(
-            &HashOutput::try_from(gen_random_field_hash::<F>().to_bytes()).unwrap(),
-            None,
-            Some(&HashOutput::try_from(hash_2.to_bytes()).unwrap()),
-            column_values[1][1],
-            column_values[0][1],
-            column_values[2][1],
-        );
-        let node_info_0 = NodeInfo::new(
-            &HashOutput::try_from(gen_random_field_hash::<F>().to_bytes()).unwrap(),
-            None,
-            None,
-            column_values[0][1],
-            column_values[0][1],
-            column_values[0][1],
-        );
-        let hash_1 = node_info_1.compute_node_hash(secondary_index_id);
-        let subtree_proof = SubProof::new_child_proof(proof_2, ChildPosition::Right).unwrap();
-        let input = Input::new_single_path(
-            subtree_proof,
-            Some(node_info_0),
-            Some(node_info_2),
-            node_info_1,
-            true,
-            &query_bounds,
-        )
-        .unwrap();
-        let proof_1 = params.generate_proof(input).unwrap();
-        // check hash
-        assert_eq!(hash_1, get_tree_hash_from_proof(&proof_1),);
-
-        // generate proof for node A (leaf of index tree)
-        let node_info_a = NodeInfo::new(
-            &HashOutput::try_from(hash_1.to_bytes()).unwrap(),
-            None,
-            None,
-            column_values[0][0],
-            column_values[0][0],
-            column_values[0][0],
-        );
-        let hash_a = node_info_a.compute_node_hash(primary_index_id);
-        let subtree_proof = SubProof::new_embedded_tree_proof(proof_1).unwrap();
-        let input =
Input::new_single_path(subtree_proof, None, None, node_info_a, false, &query_bounds) - .unwrap(); - let proof_a = params.generate_proof(input).unwrap(); - // check hash - assert_eq!(hash_a, get_tree_hash_from_proof(&proof_a),); - - // generate proof for node B rows tree - // all the nodes are in the range, so we generate proofs for each of the nodes - // generate proof for nodes 3 and 5: they are leaf nodes in the rows tree, so we directly use the universal circuit - let [proof_3, proof_5] = [&column_values[3], &column_values[5]] - .map(|values| gen_universal_circuit_proofs(values, true)); - // node 4 is not a leaf in the rows tree, so instead we need to first generate a proof for the row results using - // the universal circuit, and then we generate the proof for the rows tree node - let row_proof = gen_universal_circuit_proofs(&column_values[4], false); - let hash_3 = get_tree_hash_from_proof(&proof_3); - let hash_5 = get_tree_hash_from_proof(&proof_5); - let embedded_tree_hash = get_tree_hash_from_proof(&row_proof); - let input = Input::new_full_node(proof_3, proof_5, row_proof, true, &query_bounds).unwrap(); - let proof_4 = params.generate_proof(input).unwrap(); - // check hash - let node_info_4 = NodeInfo::new( - &HashOutput::try_from(embedded_tree_hash.to_bytes()).unwrap(), - Some(&HashOutput::try_from(hash_3.to_bytes()).unwrap()), - Some(&HashOutput::try_from(hash_5.to_bytes()).unwrap()), - column_values[4][1], - column_values[3][1], - column_values[5][1], - ); - let hash_4 = node_info_4.compute_node_hash(secondary_index_id); - assert_eq!(hash_4, get_tree_hash_from_proof(&proof_4),); - - // generate proof for node B of the index tree (root node) - let node_info_root = NodeInfo::new( - &HashOutput::try_from(hash_4.to_bytes()).unwrap(), - Some(&HashOutput::try_from(hash_a.to_bytes()).unwrap()), - None, - column_values[4][0], - column_values[0][0], - column_values[5][0], - ); - let input = Input::new_partial_node( - proof_a, - proof_4, - None, - ChildPosition::Left, - false, - &query_bounds, - ) - .unwrap(); - let (root_proof, _) = ProofWithVK::deserialize(¶ms.generate_proof(input).unwrap()) - .unwrap() - .into(); - check_pis(&root_proof.public_inputs, node_info_root, &column_values); + pub(crate) fn get_universal_circuit(&self) -> &UniversalQueryCircuitParams< + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + OutputNoAggCircuit, + > { + &self.universal_circuit } } diff --git a/verifiable-db/src/query/batching/circuits/api.rs b/verifiable-db/src/query/batching/circuits/api.rs index 6ad3cd1fb..54832c99d 100644 --- a/verifiable-db/src/query/batching/circuits/api.rs +++ b/verifiable-db/src/query/batching/circuits/api.rs @@ -38,275 +38,6 @@ use super::{ non_existence::{NonExistenceCircuit, NonExistenceWires}, row_chunk_processing::{RowChunkProcessingCircuit, RowChunkProcessingWires}, }; -/// Data structure containing all the information needed to verify the membership of -/// a node in a tree and to compute info about its predecessor/successor -#[derive(Clone, Debug, Serialize, Deserialize, Default, PartialEq, Eq)] -pub struct TreePathInputs { - /// Info about the node - pub(crate) node_info: NodeInfo, - /// Info about the nodes in the path from the node up to the root of the tree; The `ChildPosition` refers to - /// the position of the previous node in the path as a child of the current node - pub(crate) path: Vec<(NodeInfo, ChildPosition)>, - /// Hash of the siblings of the nodes in path (except for the root) - pub(crate) siblings: Vec>, - /// Info 
about the children of the node
-    pub(crate) children: [Option<NodeInfo>; 2],
-}
-
-impl TreePathInputs {
-    /// Instantiate a new instance of `TreePathInputs` for a given node from the following input data:
-    /// - `node_info`: data about the given node
-    /// - `path`: data about the nodes in the path from the node up to the root of the tree;
-    ///   The `ChildPosition` refers to the position of the previous node in the path as a child of the current node
-    /// - `siblings`: hash of the siblings of the nodes in the path (except for the root)
-    /// - `children`: data about the children of the given node
-    pub fn new(
-        node_info: NodeInfo,
-        path: Vec<(NodeInfo, ChildPosition)>,
-        children: [Option<NodeInfo>; 2],
-    ) -> Self {
-        let siblings = path
-            .iter()
-            .map(|(node, child_pos)| {
-                let sibling_index = match *child_pos {
-                    ChildPosition::Left => 1,
-                    ChildPosition::Right => 0,
-                };
-                Some(HashOutput::from(node.child_hashes[sibling_index]))
-            })
-            .collect_vec();
-        Self {
-            node_info,
-            path,
-            siblings,
-            children,
-        }
-    }
-}
-
-#[derive(Clone, Debug, Serialize, Deserialize, Default, PartialEq, Eq)]
-/// Data structure containing the information about the paths in both the rows tree
-/// and the index tree for a node in a rows tree
-pub struct NodePath {
-    pub(crate) row_tree_path: TreePathInputs,
-    /// Info about the node of the index tree storing the rows tree containing the row
-    pub(crate) index_tree_path: TreePathInputs,
-}
-
-impl NodePath {
-    /// Instantiate a new instance of `NodePath` for a given proven row from the following input data:
-    /// - `row_path`: path from the node to the root of the rows tree storing the node
-    /// - `index_path` : path from the index tree node storing the rows tree containing the node, up to the
-    ///   root of the index tree
-    pub fn new(row_path: TreePathInputs, index_path: TreePathInputs) -> Self {
-        Self {
-            row_tree_path: row_path,
-            index_tree_path: index_path,
-        }
-    }
-}
-
-#[derive(Clone, Debug)]
-/// Data structure containing the inputs necessary to prove a query for a row
-/// of the DB table.
-pub struct RowInput {
-    pub(crate) cells: RowCells,
-    pub(crate) path: NodePath,
-}
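// [Editor's sketch] How the constructors in this (removed) module compose,
// assuming the row cells and the two tree paths were already collected for
// the row being proven (e.g. with `TreePathInputs::new` above); illustrative
// only, not part of the patch.
fn build_row_input(
    row_cells: &RowCells,
    row_path: TreePathInputs,
    index_path: TreePathInputs,
) -> RowInput {
    // pair the path in the rows tree with the path of the index tree node
    // storing that rows tree
    let path = NodePath::new(row_path, index_path);
    // `RowInput::new` clones both the cells and the path, so the caller keeps
    // ownership of its inputs
    RowInput::new(row_cells, &path)
}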
-
-impl RowInput {
-    /// Initialize `RowInput` from the set of cells of the given row and the path
-    /// in the tree of the node of the rows tree associated to the given row
-    pub fn new(cells: &RowCells, path: &NodePath) -> Self {
-        Self {
-            cells: cells.clone(),
-            path: path.clone(),
-        }
-    }
-}
-#[derive(Serialize, Deserialize)]
-#[allow(clippy::large_enum_variant)]
-pub enum CircuitInput<
-    const NUM_CHUNKS: usize,
-    const NUM_ROWS: usize,
-    const ROW_TREE_MAX_DEPTH: usize,
-    const INDEX_TREE_MAX_DEPTH: usize,
-    const MAX_NUM_COLUMNS: usize,
-    const MAX_NUM_PREDICATE_OPS: usize,
-    const MAX_NUM_RESULT_OPS: usize,
-    const MAX_NUM_RESULTS: usize,
-> where
-    [(); ROW_TREE_MAX_DEPTH - 1]:,
-    [(); INDEX_TREE_MAX_DEPTH - 1]:,
-{
-    RowChunkWithAggregation(
-        RowChunkProcessingCircuit<
-            NUM_ROWS,
-            ROW_TREE_MAX_DEPTH,
-            INDEX_TREE_MAX_DEPTH,
-            MAX_NUM_COLUMNS,
-            MAX_NUM_PREDICATE_OPS,
-            MAX_NUM_RESULT_OPS,
-            MAX_NUM_RESULTS,
-            OutputAggCircuit,
-        >,
-    ),
-    ChunkAggregation(ChunkAggregationInputs),
-    NonExistence(NonExistenceCircuit),
-}
-
-impl<
-        const NUM_CHUNKS: usize,
-        const NUM_ROWS: usize,
-        const ROW_TREE_MAX_DEPTH: usize,
-        const INDEX_TREE_MAX_DEPTH: usize,
-        const MAX_NUM_COLUMNS: usize,
-        const MAX_NUM_PREDICATE_OPS: usize,
-        const MAX_NUM_RESULT_OPS: usize,
-        const MAX_NUM_RESULTS: usize,
-    >
-    CircuitInput<
-        NUM_CHUNKS,
-        NUM_ROWS,
-        ROW_TREE_MAX_DEPTH,
-        INDEX_TREE_MAX_DEPTH,
-        MAX_NUM_COLUMNS,
-        MAX_NUM_PREDICATE_OPS,
-        MAX_NUM_RESULT_OPS,
-        MAX_NUM_RESULTS,
-    >
-where
-    [(); ROW_TREE_MAX_DEPTH - 1]:,
-    [(); INDEX_TREE_MAX_DEPTH - 1]:,
-    [(); MAX_NUM_RESULTS - 1]:,
-    [(); MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS]:,
-    [(); 2 * (MAX_NUM_PREDICATE_OPS + MAX_NUM_RESULT_OPS)]:,
-{
-    /// Construct the input necessary to prove a query over a chunk of rows provided as input.
-    /// It requires to provide at least 1 row; in case there are no rows to be proven, then
-    /// `Self::new_non_existence_input` should be used instead
-    pub fn new_row_chunks_input(
-        rows: &[RowInput],
-        predicate_operations: &[BasicOperation],
-        placeholders: &Placeholders,
-        query_bounds: &QueryBounds,
-        results: &ResultStructure,
-    ) -> Result<Self> {
-        ensure!(
-            !rows.is_empty(),
-            "there must be at least 1 row to be proven"
-        );
-        ensure!(
-            rows.len() <= NUM_ROWS,
-            format!(
-                "Found {} rows provided as input, maximum allowed is {NUM_ROWS}",
-                rows.len()
-            )
-        );
-        let column_ids = &rows[0].cells.column_ids();
-        ensure!(
-            rows.iter()
-                .all(|row| row.cells.column_ids().to_vec() == column_ids.to_vec()),
-            "Rows provided as input don't have the same column ids",
-        );
-        let row_inputs = rows
-            .iter()
-            .map(RowProcessingGadgetInputs::try_from)
-            .collect::<Result<Vec<_>>>()?;
-
-        Ok(Self::RowChunkWithAggregation(
-            RowChunkProcessingCircuit::new(
-                row_inputs,
-                column_ids,
-                predicate_operations,
-                placeholders,
-                query_bounds,
-                results,
-            )?,
-        ))
-    }
-
-    /// Construct the input necessary to aggregate 2 or more row chunks already proven.
- /// It requires at least 2 chunks to be aggregated - pub fn new_chunk_aggregation_input(chunks_proofs: &[Vec]) -> Result { - ensure!( - chunks_proofs.len() >= 2, - "At least 2 chunk proofs must be provided" - ); - // deserialize `chunk_proofs`` and pad to NUM_CHUNKS proofs by replicating the last proof in `chunk_proofs` - let last_proof = chunks_proofs.last().unwrap(); - let proofs = chunks_proofs - .iter() - .map(|p| ProofWithVK::deserialize(p)) - .chain(repeat_with(|| ProofWithVK::deserialize(last_proof))) - .take(NUM_CHUNKS) - .collect::>>()?; - - let num_proofs = chunks_proofs.len(); - - ensure!( - num_proofs <= NUM_CHUNKS, - format!("Found {num_proofs} proofs provided as input, maximum allowed is {NUM_CHUNKS}") - ); - - Ok(Self::ChunkAggregation(ChunkAggregationInputs { - chunk_proofs: proofs.try_into().unwrap(), - circuit: ChunkAggregationCircuit { - num_non_dummy_chunks: num_proofs, - }, - })) - } - - /// Construct the input to prove a query in case there are no rows with a primary index value - /// in the primary query range. The circuit employed to prove the non-existence of such a row - /// requires to provide a specific node of the index tree, as described in the docs - /// https://www.notion.so/lagrangelabs/Batching-Query-10628d1c65a880b1b151d4ac017fa445?pvs=4#10e28d1c65a880498f41cd1cad0c61c3 - pub fn new_non_existence_input( - index_node_path: TreePathInputs, - column_ids: &ColumnIDs, - predicate_operations: &[BasicOperation], - results: &ResultStructure, - placeholders: &Placeholders, - query_bounds: &QueryBounds, - ) -> Result { - let QueryHashNonExistenceCircuits { - computational_hash, - placeholder_hash, - } = QueryHashNonExistenceCircuits::new::< - MAX_NUM_COLUMNS, - MAX_NUM_PREDICATE_OPS, - MAX_NUM_RESULT_OPS, - MAX_NUM_RESULTS, - >( - column_ids, - predicate_operations, - results, - placeholders, - query_bounds, - false, - )?; - - let aggregation_operations = results - .aggregation_operations() - .into_iter() - .chain(repeat( - Identifiers::AggregationOperations(AggregationOperation::default()).to_field(), - )) - .take(MAX_NUM_RESULTS) - .collect_vec() - .try_into() - .unwrap(); - - Ok(Self::NonExistence(NonExistenceCircuit::new( - &index_node_path, - column_ids.primary, - aggregation_operations, - computational_hash, - placeholder_hash, - query_bounds, - )?)) - } -} #[derive(Debug, Serialize, Deserialize)] pub(crate) struct Parameters< diff --git a/verifiable-db/src/query/batching/circuits/chunk_aggregation.rs b/verifiable-db/src/query/batching/circuits/chunk_aggregation.rs index 354959fa6..90a9f0d47 100644 --- a/verifiable-db/src/query/batching/circuits/chunk_aggregation.rs +++ b/verifiable-db/src/query/batching/circuits/chunk_aggregation.rs @@ -22,12 +22,10 @@ use plonky2::{ use recursion_framework::circuit_builder::CircuitLogicWires; use serde::{Deserialize, Serialize}; -use crate::query::batching::{ - public_inputs::PublicInputs, row_chunk::aggregate_chunks::aggregate_chunks, +use crate::query::{ + batching::row_chunk::aggregate_chunks::aggregate_chunks, pi_len, public_inputs::PublicInputs }; -use super::api::num_io; - #[derive(Clone, Debug, Serialize, Deserialize)] pub struct ChunkAggregationWires { #[serde( @@ -156,7 +154,7 @@ where type Inputs = ChunkAggregationCircuit; - const NUM_PUBLIC_INPUTS: usize = num_io::(); + const NUM_PUBLIC_INPUTS: usize = pi_len::(); fn circuit_logic( builder: &mut CircuitBuilder, @@ -201,7 +199,7 @@ mod tests { use crate::{ query::{ aggregation::tests::aggregate_output_values, - batching::public_inputs::PublicInputs, + 
public_inputs::PublicInputs, computational_hash_ids::{AggregationOperation, Identifiers}, universal_circuit::universal_query_gadget::OutputValues, }, diff --git a/verifiable-db/src/query/batching/circuits/mod.rs b/verifiable-db/src/query/batching/circuits/mod.rs index 28a3215e6..a6302b080 100644 --- a/verifiable-db/src/query/batching/circuits/mod.rs +++ b/verifiable-db/src/query/batching/circuits/mod.rs @@ -1,4 +1,3 @@ -pub(crate) mod api; pub(crate) mod chunk_aggregation; pub(crate) mod non_existence; pub(crate) mod row_chunk_processing; @@ -20,7 +19,7 @@ mod tests { use crate::query::{ aggregation::{NodeInfo, QueryBounds}, - batching::public_inputs::tests::gen_values_in_range, + public_inputs::tests::gen_values_in_range, computational_hash_ids::AggregationOperation, merkle_path::tests::build_node, universal_circuit::{ diff --git a/verifiable-db/src/query/batching/circuits/non_existence.rs b/verifiable-db/src/query/batching/circuits/non_existence.rs index 5bbef4eb9..e77139bfd 100644 --- a/verifiable-db/src/query/batching/circuits/non_existence.rs +++ b/verifiable-db/src/query/batching/circuits/non_existence.rs @@ -22,21 +22,13 @@ use recursion_framework::circuit_builder::CircuitLogicWires; use serde::{Deserialize, Serialize}; use crate::query::{ - aggregation::{output_computation::compute_dummy_output_targets, QueryBounds}, - batching::{ - public_inputs::PublicInputs, - row_chunk::{BoundaryRowDataTarget, BoundaryRowNodeInfoTarget}, - }, - merkle_path::{ + aggregation::{output_computation::compute_dummy_output_targets, QueryBounds}, api::TreePathInputs, batching::row_chunk::{BoundaryRowDataTarget, BoundaryRowNodeInfoTarget}, merkle_path::{ MerklePathWithNeighborsGadget, MerklePathWithNeighborsTargetInputs, NeighborInfoTarget, - }, - universal_circuit::{ + }, pi_len, public_inputs::PublicInputs, universal_circuit::{ ComputationalHash, ComputationalHashTarget, PlaceholderHash, PlaceholderHashTarget, - }, + } }; -use super::api::{num_io, TreePathInputs}; - #[derive(Clone, Debug, Serialize, Deserialize)] pub struct NonExistenceWires where @@ -254,7 +246,7 @@ where type Inputs = NonExistenceCircuit; - const NUM_PUBLIC_INPUTS: usize = num_io::(); + const NUM_PUBLIC_INPUTS: usize = pi_len::(); fn circuit_logic( builder: &mut CircuitBuilder, @@ -292,11 +284,9 @@ mod tests { aggregation::{ output_computation::tests::compute_dummy_output_values, ChildPosition, QueryBounds, }, - batching::{ - circuits::api::TreePathInputs, - public_inputs::{tests::gen_values_in_range, PublicInputs}, - row_chunk::tests::{BoundaryRowData, BoundaryRowNodeInfo}, - }, + api::TreePathInputs, + batching::row_chunk::tests::{BoundaryRowData, BoundaryRowNodeInfo}, + public_inputs::{tests::gen_values_in_range, PublicInputs}, merkle_path::tests::{generate_test_tree, NeighborInfo}, universal_circuit::universal_circuit_inputs::Placeholders, }, diff --git a/verifiable-db/src/query/batching/circuits/row_chunk_processing.rs b/verifiable-db/src/query/batching/circuits/row_chunk_processing.rs index ad2309af4..a752ddbe2 100644 --- a/verifiable-db/src/query/batching/circuits/row_chunk_processing.rs +++ b/verifiable-db/src/query/batching/circuits/row_chunk_processing.rs @@ -10,35 +10,28 @@ use recursion_framework::circuit_builder::CircuitLogicWires; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use crate::query::{ - aggregation::QueryBounds, - batching::{ - public_inputs::PublicInputs, - row_chunk, + aggregation::QueryBounds, batching::row_chunk:: + { row_process_gadget::{RowProcessingGadgetInputWires, 
RowProcessingGadgetInputs}, - }, - computational_hash_ids::ColumnIDs, - universal_circuit::{ + aggregate_chunks::aggregate_chunks, RowChunkDataTarget, + }, + computational_hash_ids::ColumnIDs, pi_len, public_inputs::PublicInputs, universal_circuit::{ universal_circuit_inputs::{BasicOperation, Placeholders, ResultStructure}, universal_query_gadget::{ OutputComponent, UniversalQueryHashInputWires, UniversalQueryHashInputs, }, - }, + } }; use mp2_common::{ public_inputs::PublicInputCommon, serialization::{deserialize_long_array, serialize_long_array}, - u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256}, utils::ToTargets, D, F, }; -use self::row_chunk::{aggregate_chunks::aggregate_chunks, RowChunkDataTarget}; - use anyhow::{ensure, Result}; -use super::api::num_io; - #[derive(Clone, Debug, Serialize, Deserialize)] pub struct RowChunkProcessingWires< const NUM_ROWS: usize, @@ -67,8 +60,6 @@ pub struct RowChunkProcessingWires< MAX_NUM_RESULTS, T, >, - min_query_primary: UInt256Target, - max_query_primary: UInt256Target, } #[derive(Clone, Debug, Serialize, Deserialize)] @@ -193,15 +184,11 @@ where T, > { let query_input_wires = UniversalQueryHashInputs::build(b); - let [min_query_primary, max_query_primary] = b.add_virtual_u256_arr_unsafe(); // unsafe should be ok since - // we are exposing these values as public inputs let first_row_wires = RowProcessingGadgetInputs::build( b, &query_input_wires.input_wires, &query_input_wires.min_secondary, &query_input_wires.max_secondary, - &min_query_primary, - &max_query_primary, ); // enforce first row is non-dummy b.assert_one( @@ -222,8 +209,6 @@ where &query_input_wires.input_wires, &query_input_wires.min_secondary, &query_input_wires.max_secondary, - &min_query_primary, - &max_query_primary, ); row_inputs.push(RowProcessingGadgetInputWires::from(&row_wires)); let is_second_non_dummy = row_wires.value_wires.input_wires.is_non_dummy_row; @@ -232,7 +217,7 @@ where b, &chunk, ¤t_chunk, - (&min_query_primary, &max_query_primary), + (&query_input_wires.input_wires.min_query_primary, &query_input_wires.input_wires.max_query_primary), ( &query_input_wires.min_secondary, &query_input_wires.max_secondary, @@ -258,8 +243,8 @@ where &query_input_wires.agg_ops_ids, &row_chunk.left_boundary_row.to_targets(), &row_chunk.right_boundary_row.to_targets(), - &min_query_primary.to_targets(), - &max_query_primary.to_targets(), + &query_input_wires.input_wires.min_query_primary.to_targets(), + &query_input_wires.input_wires.max_query_primary.to_targets(), &query_input_wires.min_secondary.to_targets(), &query_input_wires.max_secondary.to_targets(), &[overflow.target], @@ -271,8 +256,6 @@ where RowChunkProcessingWires { row_inputs: row_inputs.try_into().unwrap(), universal_query_inputs: query_input_wires.input_wires, - min_query_primary, - max_query_primary, } } @@ -296,12 +279,6 @@ where .for_each(|(value, target)| value.assign(pw, target)); self.universal_query_inputs .assign(pw, &wires.universal_query_inputs); - [ - (self.min_query_primary, &wires.min_query_primary), - (self.max_query_primary, &wires.max_query_primary), - ] - .into_iter() - .for_each(|(value, target)| pw.set_u256_target(target, value)); } /// This method returns the ids of the placeholders employed to compute the placeholder hash, @@ -354,7 +331,7 @@ where T, >; - const NUM_PUBLIC_INPUTS: usize = num_io::(); + const NUM_PUBLIC_INPUTS: usize = pi_len::(); fn circuit_logic( builder: &mut CircuitBuilder, @@ -405,10 +382,12 @@ mod tests { row_chunk_processing::RowChunkProcessingCircuit, 
tests::{build_test_tree, compute_output_values_for_row}, }, - public_inputs::PublicInputs, - row_chunk::tests::{BoundaryRowData, BoundaryRowNodeInfo}, - row_process_gadget::RowProcessingGadgetInputs, + row_chunk::{ + tests::{BoundaryRowData, BoundaryRowNodeInfo}, + row_process_gadget::RowProcessingGadgetInputs + }, }, + public_inputs::PublicInputs, computational_hash_ids::{ AggregationOperation, ColumnIDs, Identifiers, Operation, PlaceholderIdentifier, }, diff --git a/verifiable-db/src/query/batching/mod.rs b/verifiable-db/src/query/batching/mod.rs index bc4c2852f..015763f6b 100644 --- a/verifiable-db/src/query/batching/mod.rs +++ b/verifiable-db/src/query/batching/mod.rs @@ -1,6 +1,2 @@ pub(crate) mod circuits; -pub mod public_inputs; pub(crate) mod row_chunk; -mod row_process_gadget; - -pub use circuits::api::*; diff --git a/verifiable-db/src/query/batching/public_inputs.rs b/verifiable-db/src/query/batching/public_inputs.rs deleted file mode 100644 index fd0472e85..000000000 --- a/verifiable-db/src/query/batching/public_inputs.rs +++ /dev/null @@ -1,695 +0,0 @@ -use std::iter::once; - -use alloy::primitives::U256; -use itertools::Itertools; -use mp2_common::{ - public_inputs::{PublicInputCommon, PublicInputRange}, - types::CBuilder, - u256::UInt256Target, - utils::{FromFields, FromTargets, TryIntoBool}, - F, -}; -use plonky2::{ - hash::hash_types::{HashOut, HashOutTarget}, - iop::target::{BoolTarget, Target}, -}; -use plonky2_ecgfp5::{curve::curve::WeierstrassPoint, gadgets::curve::CurveTarget}; - -use crate::query::{ - aggregation::output_computation::compute_dummy_output_targets, - universal_circuit::universal_query_gadget::{ - CurveOrU256Target, OutputValues, OutputValuesTarget, UniversalQueryOutputWires, - }, -}; - -use super::row_chunk::{BoundaryRowDataTarget, RowChunkDataTarget}; - -/// Query circuits public inputs -pub enum QueryPublicInputs { - /// `H`: Hash of the tree - TreeHash, - /// `V`: Set of `S` values representing the cumulative results of the query, where`S` is a parameter - /// specifying the maximum number of cumulative results we support; - /// the first value could be either a `u256` or a `CurveTarget`, depending on the query, and so we always - /// represent this value with `CURVE_TARGET_LEN` elements; all the other `S-1` values are always `u256` - OutputValues, - /// `count`: `F` Number of matching records in the query - NumMatching, - /// `ops` : `[F; S]` Set of identifiers of the aggregation operations for each of the `S` items found in `V` - /// (like "SUM", "MIN", "MAX", "COUNT" operations) - OpIds, - /// Data associated to the left boundary row of the row chunk being proven - LeftBoundaryRow, - /// Data associated to the right boundary row of the row chunk being proven - RightBoundaryRow, - /// `MIN_primary`: `u256` Lower bound of the range of primary indexed column values specified in the query - MinPrimary, - /// `MAX_primary`: `u256` Upper bound of the range of primary indexed column values specified in the query - MaxPrimary, - /// `MIN_primary`: `u256` Lower bound of the range of secondary indexed column values specified in the query - MinSecondary, - /// `MAX_secondary`: `u256` Upper bound of the range of secondary indexed column values specified in the query - MaxSecondary, - /// `overflow` : `bool` Flag specifying whether an overflow error has occurred in arithmetic - Overflow, - /// `C`: computational hash - ComputationalHash, - /// `H_p` : placeholder hash - PlaceholderHash, -} - -#[derive(Clone, Debug)] -pub struct PublicInputs<'a, T, const S: 
usize> { - h: &'a [T], - v: &'a [T], - ops: &'a [T], - count: &'a T, - left_row: &'a [T], - right_row: &'a [T], - min_p: &'a [T], - max_p: &'a [T], - min_s: &'a [T], - max_s: &'a [T], - overflow: &'a T, - ch: &'a [T], - ph: &'a [T], -} - -const NUM_PUBLIC_INPUTS: usize = QueryPublicInputs::PlaceholderHash as usize + 1; - -impl<'a, T: Clone, const S: usize> PublicInputs<'a, T, S> { - const PI_RANGES: [PublicInputRange; NUM_PUBLIC_INPUTS] = [ - Self::to_range(QueryPublicInputs::TreeHash), - Self::to_range(QueryPublicInputs::OutputValues), - Self::to_range(QueryPublicInputs::NumMatching), - Self::to_range(QueryPublicInputs::OpIds), - Self::to_range(QueryPublicInputs::LeftBoundaryRow), - Self::to_range(QueryPublicInputs::RightBoundaryRow), - Self::to_range(QueryPublicInputs::MinPrimary), - Self::to_range(QueryPublicInputs::MaxPrimary), - Self::to_range(QueryPublicInputs::MinSecondary), - Self::to_range(QueryPublicInputs::MaxSecondary), - Self::to_range(QueryPublicInputs::Overflow), - Self::to_range(QueryPublicInputs::ComputationalHash), - Self::to_range(QueryPublicInputs::PlaceholderHash), - ]; - - const SIZES: [usize; NUM_PUBLIC_INPUTS] = [ - // Tree hash - HashOutTarget::NUM_TARGETS, - // Output values - CurveTarget::NUM_TARGETS + UInt256Target::NUM_TARGETS * (S - 1), - // Number of matching records - 1, - // Operation identifiers - S, - // Left boundary row - BoundaryRowDataTarget::NUM_TARGETS, - // Right boundary row - BoundaryRowDataTarget::NUM_TARGETS, - // Min primary index - UInt256Target::NUM_TARGETS, - // Max primary index - UInt256Target::NUM_TARGETS, - // Min secondary index - UInt256Target::NUM_TARGETS, - // Max secondary index - UInt256Target::NUM_TARGETS, - // Overflow flag - 1, - // Computational hash - HashOutTarget::NUM_TARGETS, - // Placeholder hash - HashOutTarget::NUM_TARGETS, - ]; - - pub const fn to_range(query_pi: QueryPublicInputs) -> PublicInputRange { - let mut i = 0; - let mut offset = 0; - let pi_pos = query_pi as usize; - while i < pi_pos { - offset += Self::SIZES[i]; - i += 1; - } - offset..offset + Self::SIZES[pi_pos] - } - - pub(crate) const fn total_len() -> usize { - Self::to_range(QueryPublicInputs::PlaceholderHash).end - } - - pub(crate) fn to_hash_raw(&self) -> &[T] { - self.h - } - - pub(crate) fn to_values_raw(&self) -> &[T] { - self.v - } - - pub(crate) fn to_count_raw(&self) -> &T { - self.count - } - - pub(crate) fn to_ops_raw(&self) -> &[T] { - self.ops - } - - pub(crate) fn to_left_row_raw(&self) -> &[T] { - self.left_row - } - - pub(crate) fn to_right_row_raw(&self) -> &[T] { - self.right_row - } - - pub(crate) fn to_min_primary_raw(&self) -> &[T] { - self.min_p - } - - pub(crate) fn to_max_primary_raw(&self) -> &[T] { - self.max_p - } - - pub(crate) fn to_min_secondary_raw(&self) -> &[T] { - self.min_s - } - - pub(crate) fn to_max_secondary_raw(&self) -> &[T] { - self.max_s - } - - pub(crate) fn to_overflow_raw(&self) -> &T { - self.overflow - } - - pub(crate) fn to_computational_hash_raw(&self) -> &[T] { - self.ch - } - - pub(crate) fn to_placeholder_hash_raw(&self) -> &[T] { - self.ph - } - - pub fn from_slice(input: &'a [T]) -> Self { - assert!( - input.len() >= Self::total_len(), - "input slice too short to build query public inputs, must be at least {} elements", - Self::total_len() - ); - Self { - h: &input[Self::PI_RANGES[0].clone()], - v: &input[Self::PI_RANGES[1].clone()], - count: &input[Self::PI_RANGES[2].clone()][0], - ops: &input[Self::PI_RANGES[3].clone()], - left_row: &input[Self::PI_RANGES[4].clone()], - right_row: 
&input[Self::PI_RANGES[5].clone()], - min_p: &input[Self::PI_RANGES[6].clone()], - max_p: &input[Self::PI_RANGES[7].clone()], - min_s: &input[Self::PI_RANGES[8].clone()], - max_s: &input[Self::PI_RANGES[9].clone()], - overflow: &input[Self::PI_RANGES[10].clone()][0], - ch: &input[Self::PI_RANGES[11].clone()], - ph: &input[Self::PI_RANGES[12].clone()], - } - } - - #[allow(clippy::too_many_arguments)] - pub fn new( - h: &'a [T], - v: &'a [T], - count: &'a [T], - ops: &'a [T], - left_row: &'a [T], - right_row: &'a [T], - min_p: &'a [T], - max_p: &'a [T], - min_s: &'a [T], - max_s: &'a [T], - overflow: &'a [T], - ch: &'a [T], - ph: &'a [T], - ) -> Self { - Self { - h, - v, - count: &count[0], - ops, - left_row, - right_row, - min_p, - max_p, - min_s, - max_s, - overflow: &overflow[0], - ch, - ph, - } - } - - pub fn to_vec(&self) -> Vec { - self.h - .iter() - .chain(self.v.iter()) - .chain(once(self.count)) - .chain(self.ops.iter()) - .chain(self.left_row.iter()) - .chain(self.right_row.iter()) - .chain(self.min_p.iter()) - .chain(self.max_p.iter()) - .chain(self.min_s.iter()) - .chain(self.max_s.iter()) - .chain(once(self.overflow)) - .chain(self.ch.iter()) - .chain(self.ph.iter()) - .cloned() - .collect_vec() - } -} - -impl PublicInputCommon for PublicInputs<'_, Target, S> { - const RANGES: &'static [PublicInputRange] = &Self::PI_RANGES; - - fn register_args(&self, cb: &mut CBuilder) { - cb.register_public_inputs(self.h); - cb.register_public_inputs(self.v); - cb.register_public_input(*self.count); - cb.register_public_inputs(self.ops); - cb.register_public_inputs(self.left_row); - cb.register_public_inputs(self.right_row); - cb.register_public_inputs(self.min_p); - cb.register_public_inputs(self.max_p); - cb.register_public_inputs(self.min_s); - cb.register_public_inputs(self.max_s); - cb.register_public_input(*self.overflow); - cb.register_public_inputs(self.ch); - cb.register_public_inputs(self.ph); - } -} - -impl PublicInputs<'_, Target, S> { - pub fn tree_hash_target(&self) -> HashOutTarget { - HashOutTarget::try_from(self.to_hash_raw()).unwrap() // safe to unwrap as we know the slice has correct length - } - /// Return the first output value as a `CurveTarget` - pub fn first_value_as_curve_target(&self) -> CurveTarget { - let targets = self.to_values_raw(); - CurveOrU256Target::from_targets(targets).as_curve_target() - } - - /// Return the first output value as a `UInt256Target` - pub fn first_value_as_u256_target(&self) -> UInt256Target { - let targets = self.to_values_raw(); - CurveOrU256Target::from_targets(targets).as_u256_target() - } - - /// Return the `UInt256` targets for the last `S-1` values - pub fn values_target(&self) -> [UInt256Target; S - 1] { - OutputValuesTarget::from_targets(self.to_values_raw()).other_outputs - } - - /// Return the value as a `UInt256Target` at the specified index - pub fn value_target_at_index(&self, i: usize) -> UInt256Target - where - [(); S - 1]:, - { - OutputValuesTarget::from_targets(self.to_values_raw()).value_target_at_index(i) - } - - pub fn num_matching_rows_target(&self) -> Target { - *self.to_count_raw() - } - - pub fn operation_ids_target(&self) -> [Target; S] { - self.to_ops_raw().try_into().unwrap() - } - - pub(crate) fn to_row_chunk_target(&self) -> RowChunkDataTarget - where - [(); S - 1]:, - { - RowChunkDataTarget:: { - left_boundary_row: self.left_boundary_row_target(), - right_boundary_row: self.right_boundary_row_target(), - chunk_outputs: UniversalQueryOutputWires { - tree_hash: self.tree_hash_target(), - values: 
OutputValuesTarget::from_targets(self.to_values_raw()), - count: self.num_matching_rows_target(), - num_overflows: self.overflow_flag_target().target, - }, - } - } - - /// Build an instance of `RowChunkDataTarget` from `self`; if `is_non_dummy_chunk` is - /// `false`, then build an instance of `RowChunkDataTarget` for a dummy chunk - pub(crate) fn to_dummy_row_chunk_target( - &self, - b: &mut CBuilder, - is_non_dummy_chunk: BoolTarget, - ) -> RowChunkDataTarget - where - [(); S - 1]:, - { - let dummy_values = compute_dummy_output_targets(b, &self.operation_ids_target()); - let output_values = self - .to_values_raw() - .iter() - .zip_eq(&dummy_values) - .map(|(&value, &dummy_value)| b.select(is_non_dummy_chunk, value, dummy_value)) - .collect_vec(); - - RowChunkDataTarget:: { - left_boundary_row: self.left_boundary_row_target(), - right_boundary_row: self.right_boundary_row_target(), - chunk_outputs: UniversalQueryOutputWires { - tree_hash: self.tree_hash_target(), - values: OutputValuesTarget::from_targets(&output_values), - // `count` is zeroed if chunk is dummy - count: b.mul(self.num_matching_rows_target(), is_non_dummy_chunk.target), - num_overflows: self.overflow_flag_target().target, - }, - } - } - - pub(crate) fn left_boundary_row_target(&self) -> BoundaryRowDataTarget { - BoundaryRowDataTarget::from_targets(self.to_left_row_raw()) - } - - pub(crate) fn right_boundary_row_target(&self) -> BoundaryRowDataTarget { - BoundaryRowDataTarget::from_targets(self.to_right_row_raw()) - } - - pub fn min_primary_target(&self) -> UInt256Target { - UInt256Target::from_targets(self.to_min_primary_raw()) - } - - pub fn max_primary_target(&self) -> UInt256Target { - UInt256Target::from_targets(self.to_max_primary_raw()) - } - - pub fn min_secondary_target(&self) -> UInt256Target { - UInt256Target::from_targets(self.to_min_secondary_raw()) - } - - pub fn max_secondary_target(&self) -> UInt256Target { - UInt256Target::from_targets(self.to_max_secondary_raw()) - } - - pub fn overflow_flag_target(&self) -> BoolTarget { - BoolTarget::new_unsafe(*self.to_overflow_raw()) - } - - pub fn computational_hash_target(&self) -> HashOutTarget { - HashOutTarget::try_from(self.to_computational_hash_raw()).unwrap() // safe to unwrap as we know the slice has correct length - } - - pub fn placeholder_hash_target(&self) -> HashOutTarget { - HashOutTarget::try_from(self.to_placeholder_hash_raw()).unwrap() // safe to unwrap as we know the slice has correct length - } -} - -impl PublicInputs<'_, F, S> -where - [(); S - 1]:, -{ - pub fn tree_hash(&self) -> HashOut { - HashOut::try_from(self.to_hash_raw()).unwrap() // safe to unwrap as we know the slice has correct length - } - - pub fn first_value_as_curve_point(&self) -> WeierstrassPoint { - OutputValues::::from_fields(self.to_values_raw()).first_value_as_curve_point() - } - - pub fn first_value_as_u256(&self) -> U256 { - OutputValues::::from_fields(self.to_values_raw()).first_value_as_u256() - } - - pub fn values(&self) -> [U256; S - 1] { - OutputValues::::from_fields(self.to_values_raw()).other_outputs - } - - /// Return the value as a UInt256 at the specified index - pub fn value_at_index(&self, i: usize) -> U256 - where - [(); S - 1]:, - { - OutputValues::::from_fields(self.to_values_raw()).value_at_index(i) - } - - pub fn num_matching_rows(&self) -> F { - *self.to_count_raw() - } - - pub fn operation_ids(&self) -> [F; S] { - self.to_ops_raw().try_into().unwrap() - } - - pub fn min_primary(&self) -> U256 { - U256::from_fields(self.to_min_primary_raw()) - } - - pub fn 
max_primary(&self) -> U256 { - U256::from_fields(self.to_max_primary_raw()) - } - - pub fn min_secondary(&self) -> U256 { - U256::from_fields(self.to_min_secondary_raw()) - } - - pub fn max_secondary(&self) -> U256 { - U256::from_fields(self.to_max_secondary_raw()) - } - - pub fn overflow_flag(&self) -> bool { - (*self.to_overflow_raw()) - .try_into_bool() - .expect("overflow flag public input different from 0 or 1") - } - - pub fn computational_hash(&self) -> HashOut { - HashOut::try_from(self.to_computational_hash_raw()).unwrap() // safe to unwrap as we know the slice has correct length - } - - pub fn placeholder_hash(&self) -> HashOut { - HashOut::try_from(self.to_placeholder_hash_raw()).unwrap() // safe to unwrap as we know the slice has correct length - } -} - -#[cfg(test)] -pub(crate) mod tests { - use std::array; - - use alloy::primitives::U256; - use itertools::Itertools; - use mp2_common::{array::ToField, public_inputs::PublicInputCommon, utils::ToFields, C, D, F}; - use mp2_test::{ - circuit::{run_circuit, UserCircuit}, - utils::{gen_random_field_hash, gen_random_u256, random_vector}, - }; - use plonky2::{ - field::types::{Field, Sample}, - iop::{ - target::Target, - witness::{PartialWitness, WitnessWrite}, - }, - plonk::circuit_builder::CircuitBuilder, - }; - use plonky2_ecgfp5::curve::curve::Point; - use rand::{thread_rng, Rng}; - - use crate::query::{ - aggregation::{QueryBoundSource, QueryBounds}, - batching::{public_inputs::QueryPublicInputs, row_chunk::tests::BoundaryRowData}, - computational_hash_ids::{AggregationOperation, Identifiers}, - universal_circuit::universal_circuit_inputs::Placeholders, - }; - - use super::{OutputValues, PublicInputs}; - - /// Generate a set of values in a given range ensuring that the i+1-th generated value is - /// bigger than the i-th generated value - pub(crate) fn gen_values_in_range( - rng: &mut R, - lower: U256, - upper: U256, - ) -> [U256; N] { - assert!(upper >= lower, "{upper} is smaller than {lower}"); - let mut prev_value = lower; - array::from_fn(|_| { - let range = (upper - prev_value).checked_add(U256::from(1)); - let gen_value = match range { - Some(range) => prev_value + gen_random_u256(rng) % range, - None => gen_random_u256(rng), - }; - prev_value = gen_value; - gen_value - }) - } - - impl PublicInputs<'_, F, S> { - pub(crate) fn sample_from_ops(ops: &[F; S]) -> [Vec; NUM_INPUTS] - where - [(); S - 1]:, - { - let rng = &mut thread_rng(); - - let tree_hash = gen_random_field_hash(); - let computational_hash = gen_random_field_hash(); - let placeholder_hash = gen_random_field_hash(); - let [min_primary, max_primary] = gen_values_in_range(rng, U256::ZERO, U256::MAX); - let [min_secondary, max_secondary] = gen_values_in_range(rng, U256::ZERO, U256::MAX); - - let query_bounds = { - let placeholders = Placeholders::new_empty(min_primary, max_primary); - QueryBounds::new( - &placeholders, - Some(QueryBoundSource::Constant(min_secondary)), - Some(QueryBoundSource::Constant(max_secondary)), - ) - .unwrap() - }; - - let is_first_op_id = - ops[0] == Identifiers::AggregationOperations(AggregationOperation::IdOp).to_field(); - - let mut previous_row: Option = None; - array::from_fn(|_| { - // generate output values - let output_values = if is_first_op_id { - // generate random curve point - OutputValues::::new_outputs_no_aggregation(&Point::sample(rng)) - } else { - let values = (0..S).map(|_| gen_random_u256(rng)).collect_vec(); - OutputValues::::new_aggregation_outputs(&values) - }; - // generate random count and overflow flag - let 
count = F::rand(); - let overflow = F::from_bool(rng.gen()); - // generate boundary rows - let left_boundary_row = if let Some(row) = &previous_row { - row.sample_consecutive_row(rng, &query_bounds) - } else { - BoundaryRowData::sample(rng, &query_bounds) - }; - let right_boundary_row = BoundaryRowData::sample(rng, &query_bounds); - assert!( - left_boundary_row.index_node_info.predecessor_info.value >= min_primary - && left_boundary_row.index_node_info.predecessor_info.value <= max_primary - ); - assert!( - left_boundary_row.index_node_info.successor_info.value >= min_primary - && left_boundary_row.index_node_info.successor_info.value <= max_primary - ); - assert!( - right_boundary_row.index_node_info.predecessor_info.value >= min_primary - && right_boundary_row.index_node_info.predecessor_info.value <= max_primary - ); - assert!( - right_boundary_row.index_node_info.successor_info.value >= min_primary - && right_boundary_row.index_node_info.successor_info.value <= max_primary - ); - previous_row = Some(right_boundary_row.clone()); - - PublicInputs::::new( - &tree_hash.to_fields(), - &output_values.to_fields(), - &[count], - ops, - &left_boundary_row.to_fields(), - &right_boundary_row.to_fields(), - &min_primary.to_fields(), - &max_primary.to_fields(), - &min_secondary.to_fields(), - &max_secondary.to_fields(), - &[overflow], - &computational_hash.to_fields(), - &placeholder_hash.to_fields(), - ) - .to_vec() - }) - } - } - - const S: usize = 10; - #[derive(Clone, Debug)] - struct TestPublicInputs<'a> { - pis: &'a [F], - } - - impl UserCircuit for TestPublicInputs<'_> { - type Wires = Vec; - - fn build(c: &mut CircuitBuilder) -> Self::Wires { - let targets = c.add_virtual_target_arr::<{ PublicInputs::::total_len() }>(); - let pi_targets = PublicInputs::::from_slice(targets.as_slice()); - pi_targets.register_args(c); - pi_targets.to_vec() - } - - fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { - pw.set_target_arr(wires, self.pis) - } - } - - #[test] - fn test_batching_query_public_inputs() { - let pis_raw: Vec = random_vector::(PublicInputs::::total_len()).to_fields(); - let pis = PublicInputs::::from_slice(pis_raw.as_slice()); - // check public inputs are constructed correctly - assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::TreeHash)], - pis.to_hash_raw(), - ); - assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::OutputValues)], - pis.to_values_raw(), - ); - assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::NumMatching)], - &[*pis.to_count_raw()], - ); - assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::OpIds)], - pis.to_ops_raw(), - ); - assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::LeftBoundaryRow)], - pis.to_left_row_raw(), - ); - assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::RightBoundaryRow)], - pis.to_right_row_raw(), - ); - assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::MinPrimary)], - pis.to_min_primary_raw(), - ); - assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::MaxPrimary)], - pis.to_max_primary_raw(), - ); - assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::MinSecondary)], - pis.to_min_secondary_raw(), - ); - assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::MaxSecondary)], - pis.to_max_secondary_raw(), - ); - assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::Overflow)], - &[*pis.to_overflow_raw()], - ); - assert_eq!( - 
&pis_raw[PublicInputs::::to_range(QueryPublicInputs::ComputationalHash)],
-            pis.to_computational_hash_raw(),
-        );
-        assert_eq!(
-            &pis_raw[PublicInputs::::to_range(QueryPublicInputs::PlaceholderHash)],
-            pis.to_placeholder_hash_raw(),
-        );
-        // use public inputs in circuit
-        let test_circuit = TestPublicInputs { pis: &pis_raw };
-        let proof = run_circuit::(test_circuit);
-        assert_eq!(proof.public_inputs, pis_raw);
-    }
-}
diff --git a/verifiable-db/src/query/batching/row_chunk/aggregate_chunks.rs b/verifiable-db/src/query/batching/row_chunk/aggregate_chunks.rs
index 077a7c895..e8d09fa3a 100644
--- a/verifiable-db/src/query/batching/row_chunk/aggregate_chunks.rs
+++ b/verifiable-db/src/query/batching/row_chunk/aggregate_chunks.rs
@@ -129,17 +129,17 @@ mod tests {
             tests::{BoundaryRowData, BoundaryRowNodeInfo, RowChunkData},
             BoundaryRowDataTarget, BoundaryRowNodeInfoTarget, RowChunkDataTarget,
         },
+        public_inputs::PublicInputs,
         computational_hash_ids::{AggregationOperation, Identifiers},
         merkle_path::{
             tests::{build_node, generate_test_tree, NeighborInfo},
             MerklePathWithNeighborsGadget, MerklePathWithNeighborsTargetInputs,
         },
-        public_inputs::PublicInputs,
         universal_circuit::universal_query_gadget::{
             OutputValues, OutputValuesTarget, UniversalQueryOutputWires,
         },
     },
-    test_utils::{random_aggregation_operations, random_aggregation_public_inputs},
+    test_utils::random_aggregation_operations,
 };

 use super::aggregate_chunks;
@@ -445,7 +445,7 @@ mod tests {
         let root = index_node.compute_node_hash(primary_index_id);

         // generate the output values associated to each chunk
-        let inputs = random_aggregation_public_inputs::<2, MAX_NUM_RESULTS>(&ops);
+        let inputs = PublicInputs::::sample_from_ops::<2>(&ops);

         let [(first_chunk_count, first_chunk_outputs, fist_chunk_num_overflows), (second_chunk_count, second_chunk_outputs, second_chunk_num_overflows)] =
             inputs
diff --git a/verifiable-db/src/query/batching/row_chunk/mod.rs b/verifiable-db/src/query/batching/row_chunk/mod.rs
index b87c09a64..9a3628e83 100644
--- a/verifiable-db/src/query/batching/row_chunk/mod.rs
+++ b/verifiable-db/src/query/batching/row_chunk/mod.rs
@@ -23,6 +23,8 @@ use crate::query::{
 pub(crate) mod aggregate_chunks;
 /// This module contains gadgets to enforce whether 2 rows are consecutive
 pub(crate) mod consecutive_rows;
+/// This module contains a gadget to prove a single row of the DB
+pub(crate) mod row_process_gadget;

 /// Data structure containing the wires representing the data related to the node of
 /// the row/index tree containing a row that is on the boundary of a row chunk.
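For orientation, the boundary-row data referenced in the doc comment above can be pictured with the following rough sketch, inferred from how the test helpers use these structures elsewhere in this patch (e.g. `left_boundary_row.index_node_info.predecessor_info.value`); field lists are abridged and the `row_node_info` field name is an assumption:

    use alloy::primitives::U256;

    // Rough shape of the boundary-row test data (abridged; partially assumed).
    struct NeighborInfo {
        value: U256, // value stored in the predecessor/successor node
        // ... hash and existence data for the neighbor
    }

    struct BoundaryRowNodeInfo {
        predecessor_info: NeighborInfo,
        successor_info: NeighborInfo,
        // ... data identifying the node itself (hash, merkle path)
    }

    struct BoundaryRowData {
        row_node_info: BoundaryRowNodeInfo,   // boundary node in the rows tree (assumed name)
        index_node_info: BoundaryRowNodeInfo, // node in the index tree (used in the tests below)
    }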
@@ -222,9 +224,10 @@ pub(crate) mod tests { use crate::query::{ aggregation::QueryBounds, - batching::{public_inputs::tests::gen_values_in_range, row_chunk::BoundaryRowDataTarget}, + batching::row_chunk::BoundaryRowDataTarget, merkle_path::{tests::NeighborInfo, NeighborInfoTarget}, universal_circuit::universal_query_gadget::OutputValues, + public_inputs::tests::gen_values_in_range, }; use super::BoundaryRowNodeInfoTarget; diff --git a/verifiable-db/src/query/batching/row_process_gadget.rs b/verifiable-db/src/query/batching/row_chunk/row_process_gadget.rs similarity index 97% rename from verifiable-db/src/query/batching/row_process_gadget.rs rename to verifiable-db/src/query/batching/row_chunk/row_process_gadget.rs index af3fd6824..b821f50df 100644 --- a/verifiable-db/src/query/batching/row_process_gadget.rs +++ b/verifiable-db/src/query/batching/row_chunk/row_process_gadget.rs @@ -14,12 +14,10 @@ use crate::query::{ UniversalQueryValueInputs, UniversalQueryValueWires, }, }, + api::RowInput, }; -use super::{ - circuits::api::RowInput, - row_chunk::{BoundaryRowDataTarget, BoundaryRowNodeInfoTarget, RowChunkDataTarget}, -}; +use super::{BoundaryRowDataTarget, BoundaryRowNodeInfoTarget, RowChunkDataTarget}; #[derive(Clone, Debug, Serialize, Deserialize)] pub(crate) struct RowProcessingGadgetInputWires< @@ -221,8 +219,6 @@ where >, min_query_secondary: &UInt256Target, max_query_secondary: &UInt256Target, - min_query_primary: &UInt256Target, - max_query_primary: &UInt256Target, ) -> RowProcessingGadgetWires< ROW_TREE_MAX_DEPTH, INDEX_TREE_MAX_DEPTH, @@ -235,8 +231,6 @@ where hash_input_wires, min_query_secondary, max_query_secondary, - Some(min_query_primary), - Some(max_query_primary), &zero, ); let [primary_index_id, secondary_index_id] = diff --git a/verifiable-db/src/query/mod.rs b/verifiable-db/src/query/mod.rs index 849a94d69..cee9a91b1 100644 --- a/verifiable-db/src/query/mod.rs +++ b/verifiable-db/src/query/mod.rs @@ -1,4 +1,4 @@ -use mp2_common::F; +use plonky2::iop::target::Target; use public_inputs::PublicInputs; pub mod aggregation; @@ -10,5 +10,5 @@ pub mod public_inputs; pub mod universal_circuit; pub const fn pi_len() -> usize { - PublicInputs::::total_len() + PublicInputs::::total_len() } diff --git a/verifiable-db/src/query/public_inputs.rs b/verifiable-db/src/query/public_inputs.rs index 80c8ae45f..d64ef98b1 100644 --- a/verifiable-db/src/query/public_inputs.rs +++ b/verifiable-db/src/query/public_inputs.rs @@ -4,21 +4,26 @@ use alloy::primitives::U256; use itertools::Itertools; use mp2_common::{ public_inputs::{PublicInputCommon, PublicInputRange}, - types::{CBuilder, CURVE_TARGET_LEN}, - u256::{UInt256Target, NUM_LIMBS}, + types::CBuilder, + u256::UInt256Target, utils::{FromFields, FromTargets, TryIntoBool}, F, }; use plonky2::{ - hash::hash_types::{HashOut, HashOutTarget, NUM_HASH_OUT_ELTS}, + hash::hash_types::{HashOut, HashOutTarget}, iop::target::{BoolTarget, Target}, }; use plonky2_ecgfp5::{curve::curve::WeierstrassPoint, gadgets::curve::CurveTarget}; -use super::universal_circuit::universal_query_gadget::{ - CurveOrU256Target, OutputValues, OutputValuesTarget, +use crate::query::{ + aggregation::output_computation::compute_dummy_output_targets, + universal_circuit::universal_query_gadget::{ + CurveOrU256Target, OutputValues, OutputValuesTarget, UniversalQueryOutputWires, + }, }; +use super::batching::row_chunk::{BoundaryRowDataTarget, RowChunkDataTarget}; + /// Query circuits public inputs pub enum QueryPublicInputs { /// `H`: Hash of the tree @@ -33,22 +38,56 @@ pub 
enum QueryPublicInputs { /// `ops` : `[F; S]` Set of identifiers of the aggregation operations for each of the `S` items found in `V` /// (like "SUM", "MIN", "MAX", "COUNT" operations) OpIds, - /// `I` : `u256` value of the indexed column for the given node (meaningful only for rows tree nodes) - IndexValue, - /// `min` : `u256` Minimum value of the indexed column among all the records stored in the subtree rooted - /// in the current node; values of secondary indexed column are employed for rows tree nodes, - /// while values of primary indexed column are employed for index tree nodes - MinValue, - /// `max`` : Maximum value of the indexed column among all the records stored in the subtree rooted - /// in the current node; values of secondary indexed column are employed for rows tree nodes, - /// while values of primary indexed column are employed for index tree nodes - MaxValue, - /// `index_ids`` : `[2]F` Identifiers of indexed columns - IndexIds, - /// `MIN_I`: `u256` Lower bound of the range of indexed column values specified in the query - MinQuery, - /// `MAX_I`: `u256` Upper bound of the range of indexed column values specified in the query - MaxQuery, + /// Data associated to the left boundary row of the row chunk being proven + LeftBoundaryRow, + /// Data associated to the right boundary row of the row chunk being proven + RightBoundaryRow, + /// `MIN_primary`: `u256` Lower bound of the range of primary indexed column values specified in the query + MinPrimary, + /// `MAX_primary`: `u256` Upper bound of the range of primary indexed column values specified in the query + MaxPrimary, + /// `MIN_secondary`: `u256` Lower bound of the range of secondary indexed column values specified in the query + MinSecondary, + /// `MAX_secondary`: `u256` Upper bound of the range of secondary indexed column values specified in the query + MaxSecondary, + /// `overflow` : `bool` Flag specifying whether an overflow error has occurred in arithmetic + Overflow, + /// `C`: computational hash + ComputationalHash, + /// `H_p` : placeholder hash + PlaceholderHash, +} + +/// Public inputs for the universal query circuit. 
They are mostly the same as `QueryPublicInputs`; the only
+/// difference is that the query range on the secondary index is replaced by the values of the indexed columns
+/// for the row being proven
+pub enum QueryPublicInputsUniversalCircuit {
+    /// `H`: Hash of the tree
+    TreeHash,
+    /// `V`: Set of `S` values representing the cumulative results of the query, where `S` is a parameter
+    /// specifying the maximum number of cumulative results we support;
+    /// the first value could be either a `u256` or a `CurveTarget`, depending on the query, and so we always
+    /// represent this value with `CURVE_TARGET_LEN` elements; all the other `S-1` values are always `u256`
+    OutputValues,
+    /// `count`: `F` Number of matching records in the query
+    NumMatching,
+    /// `ops` : `[F; S]` Set of identifiers of the aggregation operations for each of the `S` items found in `V`
+    /// (like "SUM", "MIN", "MAX", "COUNT" operations)
+    OpIds,
+    /// Data associated to the left boundary row of the row chunk being proven; it is a dummy value in the case
+    /// of the universal query circuit, and is employed only to re-use the same public input layout
+    LeftBoundaryRow,
+    /// Data associated to the right boundary row of the row chunk being proven; it is a dummy value in the case
+    /// of the universal query circuit, and is employed only to re-use the same public input layout
+    RightBoundaryRow,
+    /// `MIN_primary`: `u256` Lower bound of the range of primary indexed column values specified in the query
+    MinPrimary,
+    /// `MAX_primary`: `u256` Upper bound of the range of primary indexed column values specified in the query
+    MaxPrimary,
+    /// Value of secondary indexed column for the row being proven
+    SecondaryIndexValue,
+    /// Value of primary indexed column for the row being proven
+    PrimaryIndexValue,
     /// `overflow` : `bool` Flag specifying whether an overflow error has occurred in arithmetic
     Overflow,
     /// `C`: computational hash
@@ -57,18 +96,52 @@ pub enum QueryPublicInputs {
     PlaceholderHash,
 }

+impl From<QueryPublicInputsUniversalCircuit> for QueryPublicInputs {
+    fn from(value: QueryPublicInputsUniversalCircuit) -> Self {
+        match value {
+            QueryPublicInputsUniversalCircuit::TreeHash => QueryPublicInputs::TreeHash,
+            QueryPublicInputsUniversalCircuit::OutputValues => QueryPublicInputs::OutputValues,
+            QueryPublicInputsUniversalCircuit::NumMatching => QueryPublicInputs::NumMatching,
+            QueryPublicInputsUniversalCircuit::OpIds => QueryPublicInputs::OpIds,
+            QueryPublicInputsUniversalCircuit::LeftBoundaryRow => QueryPublicInputs::LeftBoundaryRow,
+            QueryPublicInputsUniversalCircuit::RightBoundaryRow => QueryPublicInputs::RightBoundaryRow,
+            QueryPublicInputsUniversalCircuit::MinPrimary => QueryPublicInputs::MinPrimary,
+            QueryPublicInputsUniversalCircuit::MaxPrimary => QueryPublicInputs::MaxPrimary,
+            QueryPublicInputsUniversalCircuit::SecondaryIndexValue => QueryPublicInputs::MinSecondary,
+            QueryPublicInputsUniversalCircuit::PrimaryIndexValue => QueryPublicInputs::MaxSecondary,
+            QueryPublicInputsUniversalCircuit::Overflow => QueryPublicInputs::Overflow,
+            QueryPublicInputsUniversalCircuit::ComputationalHash => QueryPublicInputs::ComputationalHash,
+            QueryPublicInputsUniversalCircuit::PlaceholderHash => QueryPublicInputs::PlaceholderHash,
+        }
+    }
+}
+/// Public inputs for generic query circuits
+pub type PublicInputs<'a, T, const S: usize> = PublicInputsFactory<'a, T, S, false>;
+/// Public inputs for universal query circuit
+pub type PublicInputsUniversalCircuit<'a, T, const S: usize> = PublicInputsFactory<'a, T, S, true>;
+
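The `UNIVERSAL_CIRCUIT` const generic above lets one raw public input layout serve two semantically different views. A condensed, self-contained sketch of the same pattern (simplified slot types and illustrative names, not the actual definitions):

    // One raw layout, two semantic views selected at the type level.
    struct PisFactory<'a, const UNIVERSAL: bool> {
        min_s: &'a [u64], // MIN_secondary, or the secondary index value
        max_s: &'a [u64], // MAX_secondary, or the primary index value
    }

    type Pis<'a> = PisFactory<'a, false>;
    type PisUniversal<'a> = PisFactory<'a, true>;

    impl<'a> Pis<'a> {
        // Generic circuits read these slots as the secondary-index query bounds.
        fn min_secondary(&self) -> &'a [u64] { self.min_s }
        fn max_secondary(&self) -> &'a [u64] { self.max_s }
    }

    impl<'a> PisUniversal<'a> {
        // The universal circuit re-uses the same slots for the index values.
        fn secondary_index_value(&self) -> &'a [u64] { self.min_s }
        fn primary_index_value(&self) -> &'a [u64] { self.max_s }
    }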
+/// This is the data structure employed both for the public inputs of generic query circuits
+/// and for the public inputs of the universal circuit. Since the two sets of public inputs
+/// are identical, except for the semantics of two U256 elements, they can be represented by
+/// the same data structure. The `UNIVERSAL_CIRCUIT` const generic is employed to
+/// define two type aliases: one for the public inputs of generic query circuits, and one for
+/// the public inputs of the universal query circuit. The methods common to both types of
+/// public inputs are implemented for this data structure, while the methods that
+/// are specific to each public input type are implemented for the corresponding alias.
+/// In this way, the methods implemented for each type alias define the correct semantics
+/// of each of the items in both types of public inputs.
 #[derive(Clone, Debug)]
-pub struct PublicInputs<'a, T, const S: usize> {
+pub struct PublicInputsFactory<'a, T, const S: usize, const UNIVERSAL_CIRCUIT: bool> {
     h: &'a [T],
     v: &'a [T],
     ops: &'a [T],
     count: &'a T,
-    i: &'a [T],
-    min: &'a [T],
-    max: &'a [T],
-    ids: &'a [T],
-    min_q: &'a [T],
-    max_q: &'a [T],
+    left_row: &'a [T],
+    right_row: &'a [T],
+    min_p: &'a [T],
+    max_p: &'a [T],
+    min_s: &'a [T],
+    max_s: &'a [T],
     overflow: &'a T,
     ch: &'a [T],
     ph: &'a [T],
@@ -76,53 +149,58 @@ pub struct PublicInputs<'a, T, const S: usize> {

 const NUM_PUBLIC_INPUTS: usize = QueryPublicInputs::PlaceholderHash as usize + 1;

-impl<'a, T: Clone, const S: usize> PublicInputs<'a, T, S> {
+impl<
+    'a,
+    T: Clone,
+    const S: usize,
+    const UNIVERSAL_CIRCUIT: bool,
+> PublicInputsFactory<'a, T, S, UNIVERSAL_CIRCUIT> {
     const PI_RANGES: [PublicInputRange; NUM_PUBLIC_INPUTS] = [
-        Self::to_range(QueryPublicInputs::TreeHash),
-        Self::to_range(QueryPublicInputs::OutputValues),
-        Self::to_range(QueryPublicInputs::NumMatching),
-        Self::to_range(QueryPublicInputs::OpIds),
-        Self::to_range(QueryPublicInputs::IndexValue),
-        Self::to_range(QueryPublicInputs::MinValue),
-        Self::to_range(QueryPublicInputs::MaxValue),
-        Self::to_range(QueryPublicInputs::IndexIds),
-        Self::to_range(QueryPublicInputs::MinQuery),
-        Self::to_range(QueryPublicInputs::MaxQuery),
-        Self::to_range(QueryPublicInputs::Overflow),
-        Self::to_range(QueryPublicInputs::ComputationalHash),
-        Self::to_range(QueryPublicInputs::PlaceholderHash),
+        Self::to_range_internal(QueryPublicInputs::TreeHash),
+        Self::to_range_internal(QueryPublicInputs::OutputValues),
+        Self::to_range_internal(QueryPublicInputs::NumMatching),
+        Self::to_range_internal(QueryPublicInputs::OpIds),
+        Self::to_range_internal(QueryPublicInputs::LeftBoundaryRow),
+        Self::to_range_internal(QueryPublicInputs::RightBoundaryRow),
+        Self::to_range_internal(QueryPublicInputs::MinPrimary),
+        Self::to_range_internal(QueryPublicInputs::MaxPrimary),
+        Self::to_range_internal(QueryPublicInputs::MinSecondary),
+        Self::to_range_internal(QueryPublicInputs::MaxSecondary),
+        Self::to_range_internal(QueryPublicInputs::Overflow),
+        Self::to_range_internal(QueryPublicInputs::ComputationalHash),
+        Self::to_range_internal(QueryPublicInputs::PlaceholderHash),
     ];

     const SIZES: [usize; NUM_PUBLIC_INPUTS] = [
         // Tree hash
-        NUM_HASH_OUT_ELTS,
+        HashOutTarget::NUM_TARGETS,
         // Output values
-        CURVE_TARGET_LEN + NUM_LIMBS * (S - 1),
+        CurveTarget::NUM_TARGETS + UInt256Target::NUM_TARGETS * (S - 1),
         // Number of matching records
         1,
         // Operation identifiers
         S,
-        // Index column value
-        NUM_LIMBS,
-        // Minimum indexed column value
-        NUM_LIMBS,
-        // Maximum indexed column value
-        NUM_LIMBS,
-        // Indexed column IDs
-        2,
-        // Lower bound for indexed column
specified in query - NUM_LIMBS, - // Upper bound for indexed column specified in query - NUM_LIMBS, + // Left boundary row + BoundaryRowDataTarget::NUM_TARGETS, + // Right boundary row + BoundaryRowDataTarget::NUM_TARGETS, + // Min primary index + UInt256Target::NUM_TARGETS, + // Max primary index + UInt256Target::NUM_TARGETS, + // Min secondary index + UInt256Target::NUM_TARGETS, + // Max secondary index + UInt256Target::NUM_TARGETS, // Overflow flag 1, // Computational hash - NUM_HASH_OUT_ELTS, + HashOutTarget::NUM_TARGETS, // Placeholder hash - NUM_HASH_OUT_ELTS, + HashOutTarget::NUM_TARGETS, ]; - pub const fn to_range(query_pi: QueryPublicInputs) -> PublicInputRange { + const fn to_range_internal(query_pi: QueryPublicInputs) -> PublicInputRange { let mut i = 0; let mut offset = 0; let pi_pos = query_pi as usize; @@ -133,8 +211,13 @@ impl<'a, T: Clone, const S: usize> PublicInputs<'a, T, S> { offset..offset + Self::SIZES[pi_pos] } + pub fn to_range>(query_pi: Q) -> PublicInputRange + { + Self::to_range_internal(query_pi.into()) + } + pub(crate) const fn total_len() -> usize { - Self::to_range(QueryPublicInputs::PlaceholderHash).end + Self::to_range_internal(QueryPublicInputs::PlaceholderHash).end } pub(crate) fn to_hash_raw(&self) -> &[T] { @@ -153,28 +236,28 @@ impl<'a, T: Clone, const S: usize> PublicInputs<'a, T, S> { self.ops } - pub(crate) fn to_index_value_raw(&self) -> &[T] { - self.i + pub(crate) fn to_left_row_raw(&self) -> &[T] { + self.left_row } - pub(crate) fn to_min_value_raw(&self) -> &[T] { - self.min + pub(crate) fn to_right_row_raw(&self) -> &[T] { + self.right_row } - pub(crate) fn to_max_value_raw(&self) -> &[T] { - self.max + pub(crate) fn to_min_primary_raw(&self) -> &[T] { + self.min_p } - pub(crate) fn to_index_ids_raw(&self) -> &[T] { - self.ids + pub(crate) fn to_max_primary_raw(&self) -> &[T] { + self.max_p } - pub(crate) fn to_min_query_raw(&self) -> &[T] { - self.min_q + pub(crate) fn to_min_secondary_raw(&self) -> &[T] { + self.min_s } - pub(crate) fn to_max_query_raw(&self) -> &[T] { - self.max_q + pub(crate) fn to_max_secondary_raw(&self) -> &[T] { + self.max_s } pub(crate) fn to_overflow_raw(&self) -> &T { @@ -200,29 +283,30 @@ impl<'a, T: Clone, const S: usize> PublicInputs<'a, T, S> { v: &input[Self::PI_RANGES[1].clone()], count: &input[Self::PI_RANGES[2].clone()][0], ops: &input[Self::PI_RANGES[3].clone()], - i: &input[Self::PI_RANGES[4].clone()], - min: &input[Self::PI_RANGES[5].clone()], - max: &input[Self::PI_RANGES[6].clone()], - ids: &input[Self::PI_RANGES[7].clone()], - min_q: &input[Self::PI_RANGES[8].clone()], - max_q: &input[Self::PI_RANGES[9].clone()], + left_row: &input[Self::PI_RANGES[4].clone()], + right_row: &input[Self::PI_RANGES[5].clone()], + min_p: &input[Self::PI_RANGES[6].clone()], + max_p: &input[Self::PI_RANGES[7].clone()], + min_s: &input[Self::PI_RANGES[8].clone()], + max_s: &input[Self::PI_RANGES[9].clone()], overflow: &input[Self::PI_RANGES[10].clone()][0], ch: &input[Self::PI_RANGES[11].clone()], ph: &input[Self::PI_RANGES[12].clone()], } } + #[allow(clippy::too_many_arguments)] pub fn new( h: &'a [T], v: &'a [T], count: &'a [T], ops: &'a [T], - i: &'a [T], - min: &'a [T], - max: &'a [T], - ids: &'a [T], - min_q: &'a [T], - max_q: &'a [T], + left_row: &'a [T], + right_row: &'a [T], + min_p: &'a [T], + max_p: &'a [T], + min_s: &'a [T], + max_s: &'a [T], overflow: &'a [T], ch: &'a [T], ph: &'a [T], @@ -232,12 +316,12 @@ impl<'a, T: Clone, const S: usize> PublicInputs<'a, T, S> { v, count: &count[0], ops, - i, - min, - max, - 
ids, - min_q, - max_q, + left_row, + right_row, + min_p, + max_p, + min_s, + max_s, overflow: &overflow[0], ch, ph, @@ -250,12 +334,12 @@ impl<'a, T: Clone, const S: usize> PublicInputs<'a, T, S> { .chain(self.v.iter()) .chain(once(self.count)) .chain(self.ops.iter()) - .chain(self.i.iter()) - .chain(self.min.iter()) - .chain(self.max.iter()) - .chain(self.ids.iter()) - .chain(self.min_q.iter()) - .chain(self.max_q.iter()) + .chain(self.left_row.iter()) + .chain(self.right_row.iter()) + .chain(self.min_p.iter()) + .chain(self.max_p.iter()) + .chain(self.min_s.iter()) + .chain(self.max_s.iter()) .chain(once(self.overflow)) .chain(self.ch.iter()) .chain(self.ph.iter()) @@ -264,7 +348,7 @@ impl<'a, T: Clone, const S: usize> PublicInputs<'a, T, S> { } } -impl PublicInputCommon for PublicInputs<'_, Target, S> { +impl PublicInputCommon for PublicInputsFactory<'_, Target, S, UNIVERSAL_CIRCUIT> { const RANGES: &'static [PublicInputRange] = &Self::PI_RANGES; fn register_args(&self, cb: &mut CBuilder) { @@ -272,19 +356,19 @@ impl PublicInputCommon for PublicInputs<'_, Target, S> { cb.register_public_inputs(self.v); cb.register_public_input(*self.count); cb.register_public_inputs(self.ops); - cb.register_public_inputs(self.i); - cb.register_public_inputs(self.min); - cb.register_public_inputs(self.max); - cb.register_public_inputs(self.ids); - cb.register_public_inputs(self.min_q); - cb.register_public_inputs(self.max_q); + cb.register_public_inputs(self.left_row); + cb.register_public_inputs(self.right_row); + cb.register_public_inputs(self.min_p); + cb.register_public_inputs(self.max_p); + cb.register_public_inputs(self.min_s); + cb.register_public_inputs(self.max_s); cb.register_public_input(*self.overflow); cb.register_public_inputs(self.ch); cb.register_public_inputs(self.ph); } } -impl PublicInputs<'_, Target, S> { +impl PublicInputsFactory<'_, Target, S, UNIVERSAL_CIRCUIT> { pub fn tree_hash_target(&self) -> HashOutTarget { HashOutTarget::try_from(self.to_hash_raw()).unwrap() // safe to unwrap as we know the slice has correct length } @@ -321,44 +405,107 @@ impl PublicInputs<'_, Target, S> { self.to_ops_raw().try_into().unwrap() } - pub fn index_value_target(&self) -> UInt256Target { - UInt256Target::from_targets(self.to_index_value_raw()) + pub fn min_primary_target(&self) -> UInt256Target { + UInt256Target::from_targets(self.to_min_primary_raw()) } - pub fn min_value_target(&self) -> UInt256Target { - UInt256Target::from_targets(self.to_min_value_raw()) + pub fn max_primary_target(&self) -> UInt256Target { + UInt256Target::from_targets(self.to_max_primary_raw()) } - pub fn max_value_target(&self) -> UInt256Target { - UInt256Target::from_targets(self.to_max_value_raw()) + pub fn overflow_flag_target(&self) -> BoolTarget { + BoolTarget::new_unsafe(*self.to_overflow_raw()) } - pub fn index_ids_target(&self) -> [Target; 2] { - self.to_index_ids_raw().try_into().unwrap() + pub fn computational_hash_target(&self) -> HashOutTarget { + HashOutTarget::try_from(self.to_computational_hash_raw()).unwrap() // safe to unwrap as we know the slice has correct length } - pub fn min_query_target(&self) -> UInt256Target { - UInt256Target::from_targets(self.to_min_query_raw()) + pub fn placeholder_hash_target(&self) -> HashOutTarget { + HashOutTarget::try_from(self.to_placeholder_hash_raw()).unwrap() // safe to unwrap as we know the slice has correct length } +} - pub fn max_query_target(&self) -> UInt256Target { - UInt256Target::from_targets(self.to_max_query_raw()) +impl PublicInputs<'_, Target, S> { + 
pub(crate) fn left_boundary_row_target(&self) -> BoundaryRowDataTarget { + BoundaryRowDataTarget::from_targets(self.to_left_row_raw()) } - pub fn overflow_flag_target(&self) -> BoolTarget { - BoolTarget::new_unsafe(*self.to_overflow_raw()) + pub(crate) fn right_boundary_row_target(&self) -> BoundaryRowDataTarget { + BoundaryRowDataTarget::from_targets(self.to_right_row_raw()) } - pub fn computational_hash_target(&self) -> HashOutTarget { - HashOutTarget::try_from(self.to_computational_hash_raw()).unwrap() // safe to unwrap as we know the slice has correct length + pub(crate) fn to_row_chunk_target(&self) -> RowChunkDataTarget + where + [(); S - 1]:, + { + RowChunkDataTarget:: { + left_boundary_row: self.left_boundary_row_target(), + right_boundary_row: self.right_boundary_row_target(), + chunk_outputs: UniversalQueryOutputWires { + tree_hash: self.tree_hash_target(), + values: OutputValuesTarget::from_targets(self.to_values_raw()), + count: self.num_matching_rows_target(), + num_overflows: self.overflow_flag_target().target, + }, + } } - pub fn placeholder_hash_target(&self) -> HashOutTarget { - HashOutTarget::try_from(self.to_placeholder_hash_raw()).unwrap() // safe to unwrap as we know the slice has correct length + /// Build an instance of `RowChunkDataTarget` from `self`; if `is_non_dummy_chunk` is + /// `false`, then build an instance of `RowChunkDataTarget` for a dummy chunk + pub(crate) fn to_dummy_row_chunk_target( + &self, + b: &mut CBuilder, + is_non_dummy_chunk: BoolTarget, + ) -> RowChunkDataTarget + where + [(); S - 1]:, + { + let dummy_values = compute_dummy_output_targets(b, &self.operation_ids_target()); + let output_values = self + .to_values_raw() + .iter() + .zip_eq(&dummy_values) + .map(|(&value, &dummy_value)| b.select(is_non_dummy_chunk, value, dummy_value)) + .collect_vec(); + + RowChunkDataTarget:: { + left_boundary_row: self.left_boundary_row_target(), + right_boundary_row: self.right_boundary_row_target(), + chunk_outputs: UniversalQueryOutputWires { + tree_hash: self.tree_hash_target(), + values: OutputValuesTarget::from_targets(&output_values), + // `count` is zeroed if chunk is dummy + count: b.mul(self.num_matching_rows_target(), is_non_dummy_chunk.target), + num_overflows: self.overflow_flag_target().target, + }, + } + } + + pub fn min_secondary_target(&self) -> UInt256Target { + UInt256Target::from_targets(self.to_min_secondary_raw()) + } + + pub fn max_secondary_target(&self) -> UInt256Target { + UInt256Target::from_targets(self.to_max_secondary_raw()) + } +} + +impl PublicInputsUniversalCircuit<'_, Target, S> { + pub fn secondary_index_value_target(&self) -> UInt256Target { + // secondary index value is found in `self.min_s` for + // `PublicInputsUniversalCircuit` + UInt256Target::from_targets(self.min_s) + } + + pub fn primary_index_value_target(&self) -> UInt256Target { + // primary index value is found in `self.max_s` for + // `PublicInputsUniversalCircuit` + UInt256Target::from_targets(self.max_s) } } -impl PublicInputs<'_, F, S> +impl PublicInputsFactory<'_, F, S, UNIVERSAL_CIRCUIT> where [(); S - 1]:, { @@ -394,28 +541,12 @@ where self.to_ops_raw().try_into().unwrap() } - pub fn index_value(&self) -> U256 { - U256::from_fields(self.to_index_value_raw()) - } - - pub fn min_value(&self) -> U256 { - U256::from_fields(self.to_min_value_raw()) + pub fn min_primary(&self) -> U256 { + U256::from_fields(self.to_min_primary_raw()) } - pub fn max_value(&self) -> U256 { - U256::from_fields(self.to_max_value_raw()) - } - - pub fn index_ids(&self) -> [F; 2] { - 
self.to_index_ids_raw().try_into().unwrap()
-    }
-
-    pub fn min_query_value(&self) -> U256 {
-        U256::from_fields(self.to_min_query_raw())
-    }
-
-    pub fn max_query_value(&self) -> U256 {
-        U256::from_fields(self.to_max_query_raw())
+    pub fn max_primary(&self) -> U256 {
+        U256::from_fields(self.to_max_primary_raw())
     }

     pub fn overflow_flag(&self) -> bool {
@@ -433,25 +564,164 @@ where
     }
 }

+impl PublicInputs<'_, F, S> {
+    pub fn min_secondary(&self) -> U256 {
+        U256::from_fields(self.to_min_secondary_raw())
+    }
+
+    pub fn max_secondary(&self) -> U256 {
+        U256::from_fields(self.to_max_secondary_raw())
+    }
+}
+
+impl PublicInputsUniversalCircuit<'_, F, S> {
+    pub fn secondary_index_value(&self) -> U256 {
+        // secondary index value is found in `self.min_s` for
+        // `PublicInputsUniversalCircuit`
+        U256::from_fields(self.min_s)
+    }
+
+    pub fn primary_index_value(&self) -> U256 {
+        // primary index value is found in `self.max_s` for
+        // `PublicInputsUniversalCircuit`
+        U256::from_fields(self.max_s)
+    }
+}
+
 #[cfg(test)]
-mod tests {
+pub(crate) mod tests {
+    use std::array;

-    use mp2_common::{public_inputs::PublicInputCommon, utils::ToFields, C, D, F};
+    use alloy::primitives::U256;
+    use itertools::Itertools;
+    use mp2_common::{array::ToField, public_inputs::PublicInputCommon, utils::ToFields, C, D, F};
     use mp2_test::{
         circuit::{run_circuit, UserCircuit},
-        utils::random_vector,
+        utils::{gen_random_field_hash, gen_random_u256, random_vector},
     };
     use plonky2::{
+        field::types::{Field, Sample},
         iop::{
             target::Target,
             witness::{PartialWitness, WitnessWrite},
         },
         plonk::circuit_builder::CircuitBuilder,
     };
+    use plonky2_ecgfp5::curve::curve::Point;
+    use rand::{thread_rng, Rng};
+
+    use crate::query::{
+        aggregation::{QueryBoundSource, QueryBounds},
+        batching::row_chunk::tests::BoundaryRowData,
+        computational_hash_ids::{AggregationOperation, Identifiers},
+        universal_circuit::universal_circuit_inputs::Placeholders,
+    };

-    use crate::query::public_inputs::QueryPublicInputs;
-
-    use super::PublicInputs;
+    use super::{OutputValues, PublicInputsFactory, PublicInputs, QueryPublicInputs};
+
+    /// Generate a set of values in a given range, ensuring that the i+1-th generated value is
+    /// greater than or equal to the i-th generated value
+    pub(crate) fn gen_values_in_range(
+        rng: &mut R,
+        lower: U256,
+        upper: U256,
+    ) -> [U256; N] {
+        assert!(upper >= lower, "{upper} is smaller than {lower}");
+        let mut prev_value = lower;
+        array::from_fn(|_| {
+            let range = (upper - prev_value).checked_add(U256::from(1));
+            let gen_value = match range {
+                Some(range) => prev_value + gen_random_u256(rng) % range,
+                None => gen_random_u256(rng),
+            };
+            prev_value = gen_value;
+            gen_value
+        })
+    }
+
+    impl PublicInputsFactory<'_, F, S, UNIVERSAL_CIRCUIT> {
+        pub(crate) fn sample_from_ops(ops: &[F; S]) -> [Vec; NUM_INPUTS]
+        where
+            [(); S - 1]:,
+        {
+            let rng = &mut thread_rng();
+
+            let tree_hash = gen_random_field_hash();
+            let computational_hash = gen_random_field_hash();
+            let placeholder_hash = gen_random_field_hash();
+            let [min_primary, max_primary] = gen_values_in_range(rng, U256::ZERO, U256::MAX);
+            let [min_secondary, max_secondary] = gen_values_in_range(rng, U256::ZERO, U256::MAX);
+
+            let query_bounds = {
+                let placeholders = Placeholders::new_empty(min_primary, max_primary);
+                QueryBounds::new(
+                    &placeholders,
+                    Some(QueryBoundSource::Constant(min_secondary)),
+                    Some(QueryBoundSource::Constant(max_secondary)),
+                )
+                .unwrap()
+            };
+
+            let is_first_op_id =
+                ops[0] ==
Identifiers::AggregationOperations(AggregationOperation::IdOp).to_field(); + + let mut previous_row: Option = None; + array::from_fn(|_| { + // generate output values + let output_values = if is_first_op_id { + // generate random curve point + OutputValues::::new_outputs_no_aggregation(&Point::sample(rng)) + } else { + let values = (0..S).map(|_| gen_random_u256(rng)).collect_vec(); + OutputValues::::new_aggregation_outputs(&values) + }; + // generate random count and overflow flag + let count = F::from_canonical_u32(rng.gen()); + let overflow = F::from_bool(rng.gen()); + // generate boundary rows + let left_boundary_row = if let Some(row) = &previous_row { + row.sample_consecutive_row(rng, &query_bounds) + } else { + BoundaryRowData::sample(rng, &query_bounds) + }; + let right_boundary_row = BoundaryRowData::sample(rng, &query_bounds); + assert!( + left_boundary_row.index_node_info.predecessor_info.value >= min_primary + && left_boundary_row.index_node_info.predecessor_info.value <= max_primary + ); + assert!( + left_boundary_row.index_node_info.successor_info.value >= min_primary + && left_boundary_row.index_node_info.successor_info.value <= max_primary + ); + assert!( + right_boundary_row.index_node_info.predecessor_info.value >= min_primary + && right_boundary_row.index_node_info.predecessor_info.value <= max_primary + ); + assert!( + right_boundary_row.index_node_info.successor_info.value >= min_primary + && right_boundary_row.index_node_info.successor_info.value <= max_primary + ); + previous_row = Some(right_boundary_row.clone()); + + PublicInputs::::new( + &tree_hash.to_fields(), + &output_values.to_fields(), + &[count], + ops, + &left_boundary_row.to_fields(), + &right_boundary_row.to_fields(), + &min_primary.to_fields(), + &max_primary.to_fields(), + &min_secondary.to_fields(), + &max_secondary.to_fields(), + &[overflow], + &computational_hash.to_fields(), + &placeholder_hash.to_fields(), + ) + .to_vec() + }) + } + } const S: usize = 10; #[derive(Clone, Debug)] @@ -475,7 +745,7 @@ mod tests { } #[test] - fn test_query_public_inputs() { + fn test_batching_query_public_inputs() { let pis_raw: Vec = random_vector::(PublicInputs::::total_len()).to_fields(); let pis = PublicInputs::::from_slice(pis_raw.as_slice()); // check public inputs are constructed correctly @@ -496,28 +766,28 @@ mod tests { pis.to_ops_raw(), ); assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::IndexValue)], - pis.to_index_value_raw(), + &pis_raw[PublicInputs::::to_range(QueryPublicInputs::LeftBoundaryRow)], + pis.to_left_row_raw(), ); assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::MinValue)], - pis.to_min_value_raw(), + &pis_raw[PublicInputs::::to_range(QueryPublicInputs::RightBoundaryRow)], + pis.to_right_row_raw(), ); assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::MaxValue)], - pis.to_max_value_raw(), + &pis_raw[PublicInputs::::to_range(QueryPublicInputs::MinPrimary)], + pis.to_min_primary_raw(), ); assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::MinQuery)], - pis.to_min_query_raw(), + &pis_raw[PublicInputs::::to_range(QueryPublicInputs::MaxPrimary)], + pis.to_max_primary_raw(), ); assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::MaxQuery)], - pis.to_max_query_raw(), + &pis_raw[PublicInputs::::to_range(QueryPublicInputs::MinSecondary)], + pis.to_min_secondary_raw(), ); assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::IndexIds)], - pis.to_index_ids_raw(), + 
&pis_raw[PublicInputs::::to_range(QueryPublicInputs::MaxSecondary)], + pis.to_max_secondary_raw(), ); assert_eq!( &pis_raw[PublicInputs::::to_range(QueryPublicInputs::Overflow)], diff --git a/verifiable-db/src/query/universal_circuit/universal_query_circuit.rs b/verifiable-db/src/query/universal_circuit/universal_query_circuit.rs index c84d47f35..46fbd3f89 100644 --- a/verifiable-db/src/query/universal_circuit/universal_query_circuit.rs +++ b/verifiable-db/src/query/universal_circuit/universal_query_circuit.rs @@ -1,26 +1,18 @@ use std::iter::once; use crate::query::{ - aggregation::QueryBounds, computational_hash_ids::PlaceholderIdentifier, pi_len, - public_inputs::PublicInputs, + aggregation::QueryBounds, public_inputs::PublicInputsUniversalCircuit, batching::row_chunk::BoundaryRowDataTarget, computational_hash_ids::{Output, PlaceholderIdentifier}, pi_len }; use anyhow::Result; use itertools::Itertools; use mp2_common::{ - array::ToField, - poseidon::{empty_poseidon_hash, HashPermutation}, - public_inputs::PublicInputCommon, - serialization::{deserialize, serialize}, - utils::{HashBuilder, ToFields, ToTargets}, - CHasher, D, F, + array::ToField, poseidon::{empty_poseidon_hash, HashPermutation}, public_inputs::PublicInputCommon, serialization::{deserialize, serialize}, types::CBuilder, utils::{FromTargets, HashBuilder, ToFields, ToTargets}, CHasher, C, D, F }; use plonky2::{ - hash::hashing::hash_n_to_hash_no_pad, - iop::{ + field::types::Field, hash::hashing::hash_n_to_hash_no_pad, iop::{ target::{BoolTarget, Target}, witness::{PartialWitness, WitnessWrite}, - }, - plonk::{circuit_builder::CircuitBuilder, proof::ProofWithPublicInputsTarget}, + }, plonk::{circuit_builder::CircuitBuilder, circuit_data::{CircuitConfig, CircuitData}, proof::{ProofWithPublicInputs, ProofWithPublicInputsTarget}} }; use recursion_framework::circuit_builder::CircuitLogicWires; use serde::{de::DeserializeOwned, Deserialize, Serialize}; @@ -115,6 +107,7 @@ where is_leaf: bool, query_bounds: &QueryBounds, results: &ResultStructure, + is_dummy_row: bool, ) -> Result { let hash_gadget_inputs = UniversalQueryHashInputs::new( &row_cells.column_ids(), @@ -124,7 +117,7 @@ where results, )?; - let value_gadget_inputs = UniversalQueryValueInputs::new(row_cells, false)?; + let value_gadget_inputs = UniversalQueryValueInputs::new(row_cells, is_dummy_row)?; Ok(Self { is_leaf, @@ -148,8 +141,6 @@ where &hash_wires.input_wires, &hash_wires.min_secondary, &hash_wires.max_secondary, - None, - None, &hash_wires.num_bound_overflows, ); let is_leaf = b.add_virtual_bool_target_safe(); @@ -158,13 +149,6 @@ where // min and max for secondary indexed column let node_min = &value_wires.input_wires.column_values[1]; let node_max = node_min; - // value of the primary indexed column - let index_value = &value_wires.input_wires.column_values[0]; - // column ids for primary and seconday indexed columns - let (primary_index_id, second_index_id) = ( - &hash_wires.input_wires.column_extraction_wires.column_ids[0], - &hash_wires.input_wires.column_extraction_wires.column_ids[1], - ); // compute hash of the node in case the current row is stored in a leaf of the rows tree let empty_hash = b.constant_hash(*empty_poseidon_hash()); let leaf_hash_inputs = empty_hash @@ -173,7 +157,7 @@ where .chain(empty_hash.elements.iter()) .chain(node_min.to_targets().iter()) .chain(node_max.to_targets().iter()) - .chain(once(second_index_id)) + .chain(once(&hash_wires.input_wires.column_extraction_wires.column_ids[1])) .chain(node_min.to_targets().iter()) 
.chain(value_wires.output_wires.tree_hash.elements.iter()) .cloned() @@ -186,17 +170,22 @@ where let output_values_targets = value_wires.output_wires.values.to_targets(); - PublicInputs::::new( + // compute dummy left boundary and right boundary rows to be exposed as public inputs; + // they are ignored by the circuits processing this proof, so it's ok to use dummy + // values + let dummy_boundary_row_targets = b.constants(&vec![F::ZERO; BoundaryRowDataTarget::NUM_TARGETS]); + let primary_index_value = &value_wires.input_wires.column_values[0]; + PublicInputsUniversalCircuit::::new( &tree_hash.to_targets(), &output_values_targets, &[value_wires.output_wires.count], hash_wires.agg_ops_ids.as_slice(), - &index_value.to_targets(), + &dummy_boundary_row_targets, + &dummy_boundary_row_targets, + &hash_wires.input_wires.min_query_primary.to_targets(), + &hash_wires.input_wires.max_query_primary.to_targets(), &node_min.to_targets(), - &node_max.to_targets(), - &[*primary_index_id, *second_index_id], - &hash_wires.min_secondary.to_targets(), - &hash_wires.max_secondary.to_targets(), + &primary_index_value.to_targets(), &[overflow.target], &hash_wires.computational_hash.to_targets(), &hash_wires.placeholder_hash.to_targets(), @@ -321,6 +310,54 @@ where } } +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct UniversalQueryCircuitParams< + const MAX_NUM_COLUMNS: usize, + const MAX_NUM_PREDICATE_OPS: usize, + const MAX_NUM_RESULT_OPS: usize, + const MAX_NUM_RESULTS: usize, + T: OutputComponent + Serialize, +> { + #[serde(serialize_with="serialize", deserialize_with="deserialize")] + pub(crate) data: CircuitData, + wires: UniversalQueryCircuitWires, +} + +impl< + const MAX_NUM_COLUMNS: usize, + const MAX_NUM_PREDICATE_OPS: usize, + const MAX_NUM_RESULT_OPS: usize, + const MAX_NUM_RESULTS: usize, + T: OutputComponent + Serialize + DeserializeOwned, +> UniversalQueryCircuitParams +where + [(); MAX_NUM_RESULTS - 1]:, + [(); MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS]:, +{ + pub(crate) fn build(config: CircuitConfig) -> Self { + let mut builder = CBuilder::new(config); + let wires = UniversalQueryCircuitInputs::build(&mut builder); + let data = builder.build(); + Self { + data, + wires, + } + } + + pub(crate) fn generate_proof(&self, input: &UniversalQueryCircuitInputs< + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + T, + > + ) -> Result> { + let mut pw = PartialWitness::::new(); + input.assign(&mut pw, &self.wires); + self.data.prove(pw) + } +} + #[derive(Clone, Debug, Serialize, Deserialize)] /// Inputs for the 2 variant of universal query circuit pub enum UniversalCircuitInput< @@ -365,8 +402,8 @@ where [(); MAX_NUM_RESULTS - 1]:, [(); MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS]:, { - /// Provide input values for universal circuit variant for queries with aggregation operations - pub(crate) fn new_query_with_agg( + /// Provide input values for universal circuit variant for queries without aggregation operations + pub(crate) fn new_query_no_agg( column_cells: &RowCells, predicate_operations: &[BasicOperation], placeholders: &Placeholders, @@ -374,7 +411,7 @@ where query_bounds: &QueryBounds, results: &ResultStructure, ) -> Result { - Ok(UniversalCircuitInput::QueryWithAgg( + Ok(UniversalCircuitInput::QueryNoAgg( UniversalQueryCircuitInputs::new( column_cells, predicate_operations, @@ -382,28 +419,58 @@ where is_leaf, query_bounds, results, + false, )?, )) } - /// Provide input values for universal circuit variant for queries without aggregation operations - pub(crate) 
fn new_query_no_agg( - column_cells: &RowCells, + + pub(crate) fn ids_for_placeholder_hash( predicate_operations: &[BasicOperation], + results: &ResultStructure, placeholders: &Placeholders, - is_leaf: bool, query_bounds: &QueryBounds, - results: &ResultStructure, - ) -> Result { - Ok(UniversalCircuitInput::QueryNoAgg( - UniversalQueryCircuitInputs::new( - column_cells, - predicate_operations, - placeholders, - is_leaf, - query_bounds, - results, - )?, - )) + ) -> Result<[PlaceholderId; 2 * (MAX_NUM_PREDICATE_OPS + MAX_NUM_RESULT_OPS)]> { + let row_cells = &RowCells::default(); + Ok(match results.output_variant { + Output::Aggregation => { + let circuit = UniversalQueryCircuitInputs::< + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + AggOutputCircuit, + >::new( + row_cells, + predicate_operations, + placeholders, + false, // doesn't matter for placeholder hash computation + query_bounds, + results, + false, // doesn't matter for placeholder hash computation + )?; + circuit.ids_for_placeholder_hash() + } + Output::NoAggregation => { + let circuit = UniversalQueryCircuitInputs::< + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + NoAggOutputCircuit, + >::new( + row_cells, + predicate_operations, + placeholders, + false, // doesn't matter for placeholder hash computation + query_bounds, + results, + false, // doesn't matter for placeholder hash computation + )?; + circuit.ids_for_placeholder_hash() + } + } + .try_into() + .unwrap()) } } @@ -414,12 +481,7 @@ mod tests { use alloy::primitives::U256; use itertools::Itertools; use mp2_common::{ - array::ToField, - group_hashing::map_to_curve_point, - poseidon::empty_poseidon_hash, - proof::ProofWithVK, - utils::{FromFields, ToFields, TryIntoBool}, - C, D, F, + array::ToField, default_config, group_hashing::map_to_curve_point, poseidon::empty_poseidon_hash, utils::{FromFields, ToFields, TryIntoBool}, C, D, F }; use mp2_test::{ cells_tree::{compute_cells_tree_hash, TestCell}, @@ -438,24 +500,21 @@ mod tests { use crate::query::{ aggregation::{QueryBoundSource, QueryBounds}, - api::{CircuitInput, Parameters}, computational_hash_ids::{ AggregationOperation, ColumnIDs, HashPermutation, Identifiers, Operation, PlaceholderIdentifier, }, - public_inputs::PublicInputs, + public_inputs::PublicInputsUniversalCircuit, universal_circuit::{ - universal_circuit_inputs::{ + output_no_aggregation::Circuit as OutputNoAggCircuit, output_with_aggregation::Circuit as OutputAggCircuit, universal_circuit_inputs::{ BasicOperation, ColumnCell, InputOperand, OutputItem, PlaceholderId, Placeholders, ResultStructure, RowCells, - }, - universal_query_circuit::placeholder_hash, - ComputationalHash, + }, universal_query_circuit::{placeholder_hash, UniversalQueryCircuitParams}, ComputationalHash }, }; use super::{ - OutputComponent, UniversalCircuitInput, UniversalQueryCircuitInputs, + OutputComponent, UniversalQueryCircuitInputs, UniversalQueryCircuitWires, }; @@ -495,7 +554,7 @@ mod tests { } // test the following query: - // SELECT AVG(C1+C2/(C2*C3)), SUM(C1+C2), MIN(C1+$1), MAX(C4-2), AVG(C5) FROM T WHERE (C5 > 5 AND C1*C3 <= C4+C5 OR C3 == $2) AND C2 >= 75 AND C2 < $3 + // SELECT AVG(C1+C2/(C2*C3)), SUM(C1+C2), MIN(C1+$1), MAX(C4-2), AVG(C5) FROM T WHERE (C5 > 5 AND C1*C3 <= C4+C5 OR C3 == $2) AND C2 >= 75 AND C2 < $3 AND C1 >= 42 AND C1 < 56 async fn query_with_aggregation(build_parameters: bool) { init_logging(); const NUM_ACTUAL_COLUMNS: usize = 5; @@ -504,17 +563,28 @@ mod tests { const 
MAX_NUM_RESULT_OPS: usize = 30;
         const MAX_NUM_RESULTS: usize = 10;
         let rng = &mut thread_rng();
-        let min_query = U256::from(75);
-        let max_query = U256::from(98);
+        let min_query_primary = U256::from(42);
+        let max_query_primary = U256::from(55);
+        let min_query_secondary = U256::from(75);
+        let max_query_secondary = U256::from(98);
         let column_values = (0..NUM_ACTUAL_COLUMNS)
             .map(|i| {
-                if i == 1 {
-                    // ensure that second column value is in the range specified by the query:
-                    // we sample a random u256 in range [0, max_query - min_query) and then we
-                    // add min_query
-                    gen_random_u256(rng).div_rem(max_query - min_query).1 + min_query
-                } else {
-                    gen_random_u256(rng)
+                match i {
+                    0 => {
+                        // ensure that primary index column value is in the range specified by the query:
+                        // we sample a random u256 in range [0, max_query_primary - min_query_primary + 1)
+                        // and then we add min_query_primary
+                        gen_random_u256(rng).div_rem(max_query_primary - min_query_primary + U256::from(1)).1 + min_query_primary
+                    },
+                    1 => {
+                        // ensure that second column value is in the range specified by the query:
+                        // we sample a random u256 in range [0, max_query_secondary - min_query_secondary + 1)
+                        // and then we add min_query_secondary
+                        gen_random_u256(rng).div_rem(max_query_secondary - min_query_secondary + U256::from(1)).1 + min_query_secondary
+                    },
+                    _ => {
+                        gen_random_u256(rng)
+                    },
                 }
             })
             .collect_vec();
@@ -533,15 +603,15 @@ mod tests {
         let first_placeholder_id = PlaceholderId::Generic(0);
         let second_placeholder_id = PlaceholderIdentifier::Generic(1);
         let mut placeholders = Placeholders::new_empty(
-            U256::default(),
-            U256::default(), // dummy values
+            min_query_primary,
+            max_query_primary,
         );
         [first_placeholder_id, second_placeholder_id]
             .iter()
             .for_each(|id| placeholders.insert(*id, gen_random_u256(rng)));
         // 3-rd placeholder is the max query bound
         let third_placeholder_id = PlaceholderId::Generic(2);
-        placeholders.insert(third_placeholder_id, max_query);
+        placeholders.insert(third_placeholder_id, max_query_secondary);

         // build predicate operations
         let mut predicate_operations = vec![];
@@ -695,7 +765,7 @@ mod tests {

         let query_bounds = QueryBounds::new(
             &placeholders,
-            Some(QueryBoundSource::Constant(min_query)),
+            Some(QueryBoundSource::Constant(min_query_secondary)),
             Some(
                 QueryBoundSource::Operation(BasicOperation {
                     first_operand: InputOperand::Placeholder(third_placeholder_id),
@@ -707,21 +777,21 @@ mod tests {
             ),
         )
         .unwrap();
-        let min_query_value = query_bounds.min_query_secondary().value;
-        let max_query_value = query_bounds.max_query_secondary().value;

-        let input = CircuitInput::<
+        let circuit = UniversalQueryCircuitInputs::<
             MAX_NUM_COLUMNS,
             MAX_NUM_PREDICATE_OPS,
             MAX_NUM_RESULT_OPS,
             MAX_NUM_RESULTS,
-        >::new_universal_circuit(
+            OutputAggCircuit,
+        >::new(
             &row_cells,
             &predicate_operations,
-            &results,
             &placeholders,
             is_leaf,
             &query_bounds,
+            &results,
+            false,
         )
         .unwrap();

@@ -783,14 +853,6 @@ mod tests {
             })
             .collect_vec();

-        let circuit = if let CircuitInput::UniversalCircuit(UniversalCircuitInput::QueryWithAgg(
-            c,
-        )) = &input
-        {
-            c
-        } else {
-            unreachable!()
-        };
         let placeholder_hash_ids = circuit.ids_for_placeholder_hash();
         let placeholder_hash =
             placeholder_hash(&placeholder_hash_ids, &placeholders, &query_bounds).unwrap();
@@ -813,17 +875,17 @@ mod tests {
             .into(),
         );
         let proof = if build_parameters {
-            let params = Parameters::build();
+            let params = UniversalQueryCircuitParams::build(
+                default_config()
+            );
             params
-                .generate_proof(input)
-                .and_then(|p| ProofWithVK::deserialize(&p))
-                .map(|p| p.proof().clone())
+                .generate_proof(&circuit)
                 .unwrap()
         } else {
             run_circuit::(circuit.clone())
        };
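The in-range sampling used for the index columns at the top of this test (and again in the next test) is the usual modulo trick: reduce a random U256 by the inclusive range width, then shift by the lower bound. In isolation it looks like the sketch below; the `sample_in_range` name is illustrative, while `gen_random_u256` and `div_rem` are the helpers the tests already use:

    use alloy::primitives::U256;
    use mp2_test::utils::gen_random_u256;
    use rand::Rng;

    // Near-uniform draw from the inclusive range [lo, hi].
    // Assumes the width hi - lo + 1 does not overflow, i.e. not (lo == 0 && hi == U256::MAX).
    fn sample_in_range<R: Rng>(rng: &mut R, lo: U256, hi: U256) -> U256 {
        let width = hi - lo + U256::from(1);
        // div_rem returns (quotient, remainder); the remainder lies in [0, width)
        gen_random_u256(rng).div_rem(width).1 + lo
    }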
-        let pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&proof.public_inputs);
+        let pi = PublicInputsUniversalCircuit::<_, MAX_NUM_RESULTS>::from_slice(&proof.public_inputs);
         assert_eq!(tree_hash, pi.tree_hash());
         assert_eq!(output_values[0], pi.first_value_as_u256());
         assert_eq!(output_values[1..], pi.values()[..output_values.len() - 1]);
@@ -832,12 +894,10 @@ mod tests {
             predicate_value,
             pi.num_matching_rows().try_into_bool().unwrap()
         );
-        assert_eq!(column_values[0], pi.index_value());
-        assert_eq!(column_values[1], pi.min_value());
-        assert_eq!(column_values[1], pi.max_value());
-        assert_eq!([column_ids[0], column_ids[1]], pi.index_ids());
-        assert_eq!(min_query_value, pi.min_query_value());
-        assert_eq!(max_query_value, pi.max_query_value());
+        assert_eq!(min_query_primary, pi.min_primary());
+        assert_eq!(max_query_primary, pi.max_primary());
+        assert_eq!(column_cells[1].value, pi.secondary_index_value());
+        assert_eq!(column_cells[0].value, pi.primary_index_value());
         assert_eq!(placeholder_hash, pi.placeholder_hash());
         assert_eq!(computational_hash, pi.computational_hash());
         assert_eq!(predicate_err || result_err, pi.overflow_flag());
@@ -854,7 +914,7 @@ mod tests {
     }

     // test the following query:
-    // SELECT C1 < C2/45, C3*C4, C7, (C5-C6)%C1, C3*C4 - $1 FROM T WHERE ((NOT C5 != 42) OR C1*C7 <= C4/C6+C5 XOR C3 < $2) AND C2 >= $3 AND C2 < 44
+    // SELECT C1 < C2/45, C3*C4, C7, (C5-C6)%C1, C3*C4 - $1 FROM T WHERE ((NOT C5 != 42) OR C1*C7 <= C4/C6+C5 XOR C3 < $2) AND C2 >= $3 AND C2 < 44 AND C1 > 13 AND C1 <= 17
     async fn query_without_aggregation(single_result: bool, build_parameters: bool) {
         init_logging();
         const NUM_ACTUAL_COLUMNS: usize = 7;
@@ -863,20 +923,28 @@ mod tests {
         const MAX_NUM_RESULT_OPS: usize = 30;
         const MAX_NUM_RESULTS: usize = 10;
         let rng = &mut thread_rng();
-        let min_query = U256::from(43);
-        let max_query = U256::from(43);
+        let min_query_primary = U256::from(14);
+        let max_query_primary = U256::from(17);
+        let min_query_secondary = U256::from(43);
+        let max_query_secondary = U256::from(43);
         let column_values = (0..NUM_ACTUAL_COLUMNS)
             .map(|i| {
-                if i == 1 {
-                    // ensure that second column value is in the range specified by the query:
-                    // we sample a random u256 in range [0, max_query - min_query + 1) and then we
-                    // add min_query
-                    gen_random_u256(rng)
-                        .div_rem(max_query - min_query + U256::from(1))
-                        .1
-                        + min_query
-                } else {
-                    gen_random_u256(rng)
+                match i {
+                    0 => {
+                        // ensure that primary index column value is in the range specified by the query:
+                        // we sample a random u256 in range [0, max_query_primary - min_query_primary + 1)
+                        // and then we add min_query_primary
+                        gen_random_u256(rng).div_rem(max_query_primary - min_query_primary + U256::from(1)).1 + min_query_primary
+                    },
+                    1 => {
+                        // ensure that second column value is in the range specified by the query:
+                        // we sample a random u256 in range [0, max_query_secondary - min_query_secondary + 1)
+                        // and then we add min_query_secondary
+                        gen_random_u256(rng).div_rem(max_query_secondary - min_query_secondary + U256::from(1)).1 + min_query_secondary
+                    },
+                    _ => {
+                        gen_random_u256(rng)
+                    },
                 }
             })
             .collect_vec();
@@ -895,15 +963,15 @@ mod tests {
         let first_placeholder_id = PlaceholderId::Generic(0);
         let second_placeholder_id = PlaceholderIdentifier::Generic(1);
         let mut placeholders = Placeholders::new_empty(
-            U256::default(),
-            U256::default(), // dummy values
+            min_query_primary,
+            max_query_primary,
         );
         [first_placeholder_id, second_placeholder_id]
             .iter()
             .for_each(|id|
placeholders.insert(*id, gen_random_u256(rng))); // 3-rd placeholder is the min query bound let third_placeholder_id = PlaceholderId::Generic(2); - placeholders.insert(third_placeholder_id, min_query); + placeholders.insert(third_placeholder_id, min_query_secondary); // build predicate operations let mut predicate_operations = vec![]; @@ -1087,21 +1155,23 @@ mod tests { let query_bounds = QueryBounds::new( &placeholders, Some(QueryBoundSource::Placeholder(third_placeholder_id)), - Some(QueryBoundSource::Constant(max_query)), + Some(QueryBoundSource::Constant(max_query_secondary)), ) .unwrap(); - let input = CircuitInput::< + let circuit = UniversalQueryCircuitInputs::< MAX_NUM_COLUMNS, MAX_NUM_PREDICATE_OPS, MAX_NUM_RESULT_OPS, MAX_NUM_RESULTS, - >::new_universal_circuit( + OutputNoAggCircuit, + >::new( &row_cells, &predicate_operations, - &results, &placeholders, is_leaf, &query_bounds, + &results, + false, ) .unwrap(); @@ -1178,12 +1248,6 @@ mod tests { Point::NEUTRAL }; - let circuit = - if let CircuitInput::UniversalCircuit(UniversalCircuitInput::QueryNoAgg(c)) = &input { - c - } else { - unreachable!() - }; let placeholder_hash_ids = circuit.ids_for_placeholder_hash(); let placeholder_hash = placeholder_hash(&placeholder_hash_ids, &placeholders, &query_bounds).unwrap(); @@ -1207,17 +1271,17 @@ mod tests { ); let proof = if build_parameters { - let params = Parameters::build(); + let params = UniversalQueryCircuitParams::build( + default_config() + ); params - .generate_proof(input) - .and_then(|p| ProofWithVK::deserialize(&p)) - .map(|p| p.proof().clone()) + .generate_proof(&circuit) .unwrap() } else { run_circuit::(circuit.clone()) }; - let pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&proof.public_inputs); + let pi = PublicInputsUniversalCircuit::<_, MAX_NUM_RESULTS>::from_slice(&proof.public_inputs); assert_eq!(tree_hash, pi.tree_hash()); assert_eq!(output_acc.to_weierstrass(), pi.first_value_as_curve_point()); // The other MAX_NUM_RESULTS -1 output values are dummy ones, as in queries @@ -1240,12 +1304,10 @@ mod tests { predicate_value, pi.num_matching_rows().try_into_bool().unwrap() ); - assert_eq!(column_values[0], pi.index_value()); - assert_eq!(column_values[1], pi.min_value()); - assert_eq!(column_values[1], pi.max_value()); - assert_eq!([column_ids[0], column_ids[1]], pi.index_ids()); - assert_eq!(min_query, pi.min_query_value()); - assert_eq!(max_query, pi.max_query_value()); + assert_eq!(min_query_primary, pi.min_primary()); + assert_eq!(max_query_primary, pi.max_primary()); + assert_eq!(column_cells[1].value, pi.secondary_index_value()); + assert_eq!(column_cells[0].value, pi.primary_index_value()); assert_eq!(placeholder_hash, pi.placeholder_hash()); assert_eq!(computational_hash, pi.computational_hash()); assert_eq!(predicate_err || result_err, pi.overflow_flag()); diff --git a/verifiable-db/src/query/universal_circuit/universal_query_gadget.rs b/verifiable-db/src/query/universal_circuit/universal_query_gadget.rs index 4b2208d91..0fd2a9988 100644 --- a/verifiable-db/src/query/universal_circuit/universal_query_gadget.rs +++ b/verifiable-db/src/query/universal_circuit/universal_query_gadget.rs @@ -493,10 +493,14 @@ pub(crate) struct UniversalQueryHashInputWires< > { /// Input wires for column extraction component pub(crate) column_extraction_wires: ColumnExtractionInputWires, + /// Lower bound of the range for the primary index specified in the query + pub(crate) min_query_primary: UInt256Target, + /// Upper bound of the range for the primary index specified in the 
query + pub(crate) max_query_primary: UInt256Target, /// Lower bound of the range for the secondary index specified in the query - min_query: QueryBoundTargetInputs, + min_query_secondary: QueryBoundTargetInputs, /// Upper bound of the range for the secondary index specified in the query - max_query: QueryBoundTargetInputs, + max_query_secondary: QueryBoundTargetInputs, /// Input wires for the `MAX_NUM_PREDICATE_OPS` basic operation components necessary /// to evaluate the filtering predicate #[serde( @@ -548,8 +552,10 @@ pub(crate) struct UniversalQueryHashInputs< T: OutputComponent, > { column_extraction_inputs: ColumnExtractionInputs, - min_query: QueryBound, - max_query: QueryBound, + min_query_primary: U256, + max_query_primary: U256, + min_query_secondary: QueryBound, + max_query_secondary: QueryBound, #[serde( serialize_with = "serialize_long_array", deserialize_with = "deserialize_long_array" @@ -661,8 +667,10 @@ where Ok(Self { column_extraction_inputs, - min_query, - max_query, + min_query_primary: query_bounds.min_query_primary(), + max_query_primary: query_bounds.max_query_primary(), + min_query_secondary: min_query, + max_query_secondary: max_query, filtering_predicate_inputs: predicate_ops_inputs, result_values_inputs: result_ops_inputs, output_component_inputs, @@ -679,8 +687,9 @@ where T, > { let column_extraction_wires = ColumnExtractionInputs::build_hash(b); - let min_query = QueryBoundTarget::new(b); - let max_query = QueryBoundTarget::new(b); + let [min_query_primary, max_query_primary] = b.add_virtual_u256_arr_unsafe(); + let min_query_secondary = QueryBoundTarget::new(b); + let max_query_secondary = QueryBoundTarget::new(b); let mut input_hash = column_extraction_wires.column_hash.to_vec(); // Payload to compute the placeholder hash public input let mut placeholder_hash_payload = vec![]; @@ -742,27 +751,29 @@ where let placeholder_hash = b.hash_n_to_hash_no_pad::(placeholder_hash_payload); let placeholder_hash = QueryBoundTarget::add_query_bounds_to_placeholder_hash( b, - &min_query, - &max_query, + &min_query_secondary, + &max_query_secondary, &placeholder_hash, ); // add query bounds to computational hash let computational_hash = QueryBoundTarget::add_query_bounds_to_computational_hash( b, - &min_query, - &max_query, + &min_query_secondary, + &max_query_secondary, &output_component_wires.computational_hash(), ); - let min_secondary = min_query.get_bound_value().clone(); - let max_secondary = max_query.get_bound_value().clone(); + let min_secondary = min_query_secondary.get_bound_value().clone(); + let max_secondary = max_query_secondary.get_bound_value().clone(); let num_bound_overflows = - QueryBoundTarget::num_overflows_for_query_bound_operations(b, &min_query, &max_query); + QueryBoundTarget::num_overflows_for_query_bound_operations(b, &min_query_secondary, &max_query_secondary); UniversalQueryHashWires { input_wires: UniversalQueryHashInputWires { column_extraction_wires: column_extraction_wires.input_wires, - min_query: min_query.into(), - max_query: max_query.into(), + min_query_primary, + max_query_primary, + min_query_secondary: min_query_secondary.into(), + max_query_secondary: max_query_secondary.into(), filtering_predicate_ops: filtering_predicate_wires.try_into().unwrap(), result_value_ops: result_value_wires.try_into().unwrap(), output_component_wires: output_component_wires.input_wires(), @@ -793,8 +804,10 @@ where ) { self.column_extraction_inputs .assign(pw, &wires.column_extraction_wires); - wires.min_query.assign(pw, &self.min_query); - 
wires.max_query.assign(pw, &self.max_query);
+        pw.set_u256_target(&wires.min_query_primary, self.min_query_primary);
+        pw.set_u256_target(&wires.max_query_primary, self.max_query_primary);
+        wires.min_query_secondary.assign(pw, &self.min_query_secondary);
+        wires.max_query_secondary.assign(pw, &self.max_query_secondary);
         self.filtering_predicate_inputs
             .iter()
             .chain(self.result_values_inputs.iter())
@@ -1284,41 +1297,25 @@ where
         >,
         min_secondary: &UInt256Target,
         max_secondary: &UInt256Target,
-        min_primary: Option<&UInt256Target>, // Option since we don't need this in universal query circuit
-        max_primary: Option<&UInt256Target>,
         num_overflows: &Target,
     ) -> UniversalQueryValueWires {
         let column_values = ColumnExtractionInputs::build_column_values(b);
-        // check that min_primary and max_primary are either both Some or both None
-        assert_eq!(min_primary.is_some(), max_primary.is_some());
         let _true = b._true();
-        // allocate dummy row flag only if we aren't in universal circuit, i.e., if min_primary.is_some() is true
+        // allocate the flag specifying whether the current row is a dummy one;
+        // dummy rows may now occur also in the universal circuit, so the flag
+        // is always allocated
-        let is_non_dummy_row = if min_primary.is_some() {
-            b.add_virtual_bool_target_safe()
-        } else {
-            _true
-        };
+        let is_non_dummy_row = b.add_virtual_bool_target_safe();
         let ColumnExtractionValueWires { tree_hash } = ColumnExtractionInputs::build_tree_hash(
             b,
             &column_values,
             &hash_input_wires.column_extraction_wires,
         );
-        // if we have min_primary and max_primary bounds, enforce that the value of
-        // primary index for the current row is in the range given by these bounds
-        match (min_primary, max_primary) {
-            (Some(min), Some(max)) => {
-                let index_value = &column_values[0];
-                let less_than_max = b.is_less_or_equal_than_u256(index_value, max);
-                let greater_than_min = b.is_less_or_equal_than_u256(min, index_value);
-                b.connect(less_than_max.target, _true.target);
-                b.connect(greater_than_min.target, _true.target);
-            }
-            (None, None) => (),
-            _ => unreachable!(
-                "min_primary and max_primary should be either both Some(_) or both None"
-            ),
-        }
+        // Enforce that the value of the primary index for the current row lies
+        // within the query bounds on the primary index
+        let index_value = &column_values[0];
+        let less_than_max = b.is_less_or_equal_than_u256(index_value, &hash_input_wires.max_query_primary);
+        let greater_than_min = b.is_less_or_equal_than_u256(&hash_input_wires.min_query_primary, index_value);
+        b.connect(less_than_max.target, _true.target);
+        b.connect(greater_than_min.target, _true.target);
+
         // min and max for secondary indexed column
         let node_min = &column_values[1];
diff --git a/verifiable-db/src/results_tree/binding/binding_results.rs b/verifiable-db/src/results_tree/binding/binding_results.rs
index 9431af03c..1bbb2e41f 100644
--- a/verifiable-db/src/results_tree/binding/binding_results.rs
+++ b/verifiable-db/src/results_tree/binding/binding_results.rs
@@ -3,12 +3,12 @@
 use crate::{
     query::{
         computational_hash_ids::{AggregationOperation, ResultIdentifier},
-        public_inputs::PublicInputs as QueryProofPI,
         universal_circuit::ComputationalHashTarget,
     },
     results_tree::{
         binding::public_inputs::PublicInputs,
         construction::public_inputs::PublicInputs as ResultsConstructionProofPI,
+        old_public_inputs::PublicInputs as QueryProofPI,
     },
 };
 use mp2_common::{
@@ -99,12 +99,11 @@ impl BindingResultsCircuit {
 mod tests {
     use super::*;
     use crate::{
-        query::pi_len as query_pi_len,
-        results_tree::construction::{
+        results_tree::{construction::{
             public_inputs::ResultsConstructionPublicInputs,
             tests::{pi_len, random_results_construction_public_inputs},
-        },
-        test_utils::{random_aggregation_operations, random_aggregation_public_inputs},
+        }, tests::random_aggregation_public_inputs},
+        test_utils::random_aggregation_operations,
     };
     use itertools::Itertools;
     use mp2_common::{poseidon::H, utils::ToFields, C, D, F};
@@ -117,7 +116,7 @@ mod tests {

     const S: usize = 20;

-    const QUERY_PI_LEN: usize = query_pi_len::();
+    const QUERY_PI_LEN: usize = QueryProofPI::::total_len();
     const RESULTS_CONSTRUCTION_PI_LEN: usize = pi_len::();

     #[derive(Clone, Debug)]
diff --git a/verifiable-db/src/results_tree/mod.rs b/verifiable-db/src/results_tree/mod.rs
index 53396f41a..62b718052 100644
--- a/verifiable-db/src/results_tree/mod.rs
+++ b/verifiable-db/src/results_tree/mod.rs
@@ -1,2 +1,78 @@
 pub(crate) mod binding;
 pub(crate) mod construction;
+/// Old query public inputs, moved here because the circuits in this module still expect
+/// these public inputs for now
+pub(crate) mod old_public_inputs;
+
+#[cfg(test)]
+pub(crate) mod tests {
+    use std::array;
+
+    use mp2_common::{array::ToField, types::CURVE_TARGET_LEN, utils::ToFields, F};
+    use plonky2::{field::types::{Field, Sample}, hash::hash_types::NUM_HASH_OUT_ELTS};
+    use plonky2_ecgfp5::curve::curve::Point;
+    use rand::{thread_rng, Rng};
+
+    use crate::query::computational_hash_ids::{AggregationOperation, Identifiers};
+
+    use super::old_public_inputs::{PublicInputs, QueryPublicInputs};
+
+    /// Generate `N` proof public-input slices with the specified operations for testing.
+    /// Each returned slice can be parsed with the
+    /// `PublicInputs::from_slice` function.
+    pub fn random_aggregation_public_inputs(
+        ops: &[F; S],
+    ) -> [Vec; N] {
+        let [ops_range, overflow_range, index_ids_range, c_hash_range, p_hash_range] = [
+            QueryPublicInputs::OpIds,
+            QueryPublicInputs::Overflow,
+            QueryPublicInputs::IndexIds,
+            QueryPublicInputs::ComputationalHash,
+            QueryPublicInputs::PlaceholderHash,
+        ]
+        .map(PublicInputs::::to_range);
+
+        let first_value_start = PublicInputs::::to_range(QueryPublicInputs::OutputValues).start;
+        let is_first_op_id =
+            ops[0] == Identifiers::AggregationOperations(AggregationOperation::IdOp).to_field();
+
+        // Generate the index ids, computational hash and placeholder hash;
+        // they should be the same across a series of public inputs.
+        let mut rng = thread_rng();
+        let index_ids = (0..2).map(|_| rng.gen()).collect::>().to_fields();
+        let [computational_hash, placeholder_hash]: [Vec<_>; 2] = array::from_fn(|_| {
+            (0..NUM_HASH_OUT_ELTS)
+                .map(|_| rng.gen())
+                .collect::>()
+                .to_fields()
+        });
+
+        array::from_fn(|_| {
+            let mut pi = (0..PublicInputs::::total_len())
+                .map(|_| rng.gen())
+                .collect::>()
+                .to_fields();
+
+            // Copy the specified operations to the proofs.
+            pi[ops_range.clone()].copy_from_slice(ops);
+
+            // Set the overflow flag to a random boolean.
+            let overflow = F::from_bool(rng.gen());
+            pi[overflow_range.clone()].copy_from_slice(&[overflow]);
+
+            // Set the index ids, computational hash and placeholder hash.
+            pi[index_ids_range.clone()].copy_from_slice(&index_ids);
+            pi[c_hash_range.clone()].copy_from_slice(&computational_hash);
+            pi[p_hash_range.clone()].copy_from_slice(&placeholder_hash);
+
+            // If the first operation is ID, set the value to a random point.
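+            // When the first operation is the identity aggregation (i.e. the
+            // query has no aggregation), the first output slot encodes a curve
+            // point rather than a u256, spanning `CURVE_TARGET_LEN` field
+            // elements, so a random Weierstrass point is sampled here to keep
+            // the test data well-formed for `first_value_as_curve_point`.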
+            if is_first_op_id {
+                let first_value = Point::sample(&mut rng).to_weierstrass().to_fields();
+                pi[first_value_start..first_value_start + CURVE_TARGET_LEN]
+                    .copy_from_slice(&first_value);
+            }
+
+            pi
+        })
+    }
+}
diff --git a/verifiable-db/src/results_tree/old_public_inputs.rs b/verifiable-db/src/results_tree/old_public_inputs.rs
new file mode 100644
index 000000000..7f6d07b00
--- /dev/null
+++ b/verifiable-db/src/results_tree/old_public_inputs.rs
@@ -0,0 +1,539 @@
+use std::iter::once;
+
+use alloy::primitives::U256;
+use itertools::Itertools;
+use mp2_common::{
+    public_inputs::{PublicInputCommon, PublicInputRange},
+    types::{CBuilder, CURVE_TARGET_LEN},
+    u256::{UInt256Target, NUM_LIMBS},
+    utils::{FromFields, FromTargets, TryIntoBool},
+    F,
+};
+use plonky2::{
+    hash::hash_types::{HashOut, HashOutTarget, NUM_HASH_OUT_ELTS},
+    iop::target::{BoolTarget, Target},
+};
+use plonky2_ecgfp5::{curve::curve::WeierstrassPoint, gadgets::curve::CurveTarget};
+
+use crate::query::universal_circuit::universal_query_gadget::{
+    CurveOrU256Target, OutputValues, OutputValuesTarget,
+};
+
+/// Query circuits public inputs
+pub enum QueryPublicInputs {
+    /// `H`: Hash of the tree
+    TreeHash,
+    /// `V`: Set of `S` values representing the cumulative results of the query, where `S` is a parameter
+    /// specifying the maximum number of cumulative results we support;
+    /// the first value could be either a `u256` or a `CurveTarget`, depending on the query, and so we always
+    /// represent this value with `CURVE_TARGET_LEN` elements; all the other `S-1` values are always `u256`
+    OutputValues,
+    /// `count`: `F` Number of matching records in the query
+    NumMatching,
+    /// `ops` : `[F; S]` Set of identifiers of the aggregation operations for each of the `S` items found in `V`
+    /// (like "SUM", "MIN", "MAX", "COUNT" operations)
+    OpIds,
+    /// `I` : `u256` value of the indexed column for the given node (meaningful only for rows tree nodes)
+    IndexValue,
+    /// `min` : `u256` Minimum value of the indexed column among all the records stored in the subtree rooted
+    /// in the current node; values of secondary indexed column are employed for rows tree nodes,
+    /// while values of primary indexed column are employed for index tree nodes
+    MinValue,
+    /// `max` : `u256` Maximum value of the indexed column among all the records stored in the subtree rooted
+    /// in the current node; values of secondary indexed column are employed for rows tree nodes,
+    /// while values of primary indexed column are employed for index tree nodes
+    MaxValue,
+    /// `index_ids` : `[2]F` Identifiers of indexed columns
+    IndexIds,
+    /// `MIN_I`: `u256` Lower bound of the range of indexed column values specified in the query
+    MinQuery,
+    /// `MAX_I`: `u256` Upper bound of the range of indexed column values specified in the query
+    MaxQuery,
+    /// `overflow` : `bool` Flag specifying whether an overflow error has occurred in arithmetic
+    Overflow,
+    /// `C`: computational hash
+    ComputationalHash,
+    /// `H_p` : placeholder hash
+    PlaceholderHash,
+}
+
+#[derive(Clone, Debug)]
+pub struct PublicInputs<'a, T, const S: usize> {
+    h: &'a [T],
+    v: &'a [T],
+    ops: &'a [T],
+    count: &'a T,
+    i: &'a [T],
+    min: &'a [T],
+    max: &'a [T],
+    ids: &'a [T],
+    min_q: &'a [T],
+    max_q: &'a [T],
+    overflow: &'a T,
+    ch: &'a [T],
+    ph: &'a [T],
+}
+
+const NUM_PUBLIC_INPUTS: usize = QueryPublicInputs::PlaceholderHash as usize + 1;
+
+impl<'a, T: Clone, const S: usize> PublicInputs<'a, T, S> {
+    const PI_RANGES: [PublicInputRange; NUM_PUBLIC_INPUTS] = [
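+        // These ranges must stay in the same order as the `QueryPublicInputs`
+        // variants: `to_range` derives each offset by summing the `SIZES`
+        // entries of all preceding variants, e.g. `TreeHash` starts at 0 and
+        // `OutputValues` starts right after it, at `NUM_HASH_OUT_ELTS`.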
Self::to_range(QueryPublicInputs::TreeHash), + Self::to_range(QueryPublicInputs::OutputValues), + Self::to_range(QueryPublicInputs::NumMatching), + Self::to_range(QueryPublicInputs::OpIds), + Self::to_range(QueryPublicInputs::IndexValue), + Self::to_range(QueryPublicInputs::MinValue), + Self::to_range(QueryPublicInputs::MaxValue), + Self::to_range(QueryPublicInputs::IndexIds), + Self::to_range(QueryPublicInputs::MinQuery), + Self::to_range(QueryPublicInputs::MaxQuery), + Self::to_range(QueryPublicInputs::Overflow), + Self::to_range(QueryPublicInputs::ComputationalHash), + Self::to_range(QueryPublicInputs::PlaceholderHash), + ]; + + const SIZES: [usize; NUM_PUBLIC_INPUTS] = [ + // Tree hash + NUM_HASH_OUT_ELTS, + // Output values + CURVE_TARGET_LEN + NUM_LIMBS * (S - 1), + // Number of matching records + 1, + // Operation identifiers + S, + // Index column value + NUM_LIMBS, + // Minimum indexed column value + NUM_LIMBS, + // Maximum indexed column value + NUM_LIMBS, + // Indexed column IDs + 2, + // Lower bound for indexed column specified in query + NUM_LIMBS, + // Upper bound for indexed column specified in query + NUM_LIMBS, + // Overflow flag + 1, + // Computational hash + NUM_HASH_OUT_ELTS, + // Placeholder hash + NUM_HASH_OUT_ELTS, + ]; + + pub const fn to_range(query_pi: QueryPublicInputs) -> PublicInputRange { + let mut i = 0; + let mut offset = 0; + let pi_pos = query_pi as usize; + while i < pi_pos { + offset += Self::SIZES[i]; + i += 1; + } + offset..offset + Self::SIZES[pi_pos] + } + + pub(crate) const fn total_len() -> usize { + Self::to_range(QueryPublicInputs::PlaceholderHash).end + } + + pub(crate) fn to_hash_raw(&self) -> &[T] { + self.h + } + + pub(crate) fn to_values_raw(&self) -> &[T] { + self.v + } + + pub(crate) fn to_count_raw(&self) -> &T { + self.count + } + + pub(crate) fn to_ops_raw(&self) -> &[T] { + self.ops + } + + pub(crate) fn to_index_value_raw(&self) -> &[T] { + self.i + } + + pub(crate) fn to_min_value_raw(&self) -> &[T] { + self.min + } + + pub(crate) fn to_max_value_raw(&self) -> &[T] { + self.max + } + + pub(crate) fn to_index_ids_raw(&self) -> &[T] { + self.ids + } + + pub(crate) fn to_min_query_raw(&self) -> &[T] { + self.min_q + } + + pub(crate) fn to_max_query_raw(&self) -> &[T] { + self.max_q + } + + pub(crate) fn to_overflow_raw(&self) -> &T { + self.overflow + } + + pub(crate) fn to_computational_hash_raw(&self) -> &[T] { + self.ch + } + + pub(crate) fn to_placeholder_hash_raw(&self) -> &[T] { + self.ph + } + + pub fn from_slice(input: &'a [T]) -> Self { + assert!( + input.len() >= Self::total_len(), + "input slice too short to build query public inputs, must be at least {} elements", + Self::total_len() + ); + Self { + h: &input[Self::PI_RANGES[0].clone()], + v: &input[Self::PI_RANGES[1].clone()], + count: &input[Self::PI_RANGES[2].clone()][0], + ops: &input[Self::PI_RANGES[3].clone()], + i: &input[Self::PI_RANGES[4].clone()], + min: &input[Self::PI_RANGES[5].clone()], + max: &input[Self::PI_RANGES[6].clone()], + ids: &input[Self::PI_RANGES[7].clone()], + min_q: &input[Self::PI_RANGES[8].clone()], + max_q: &input[Self::PI_RANGES[9].clone()], + overflow: &input[Self::PI_RANGES[10].clone()][0], + ch: &input[Self::PI_RANGES[11].clone()], + ph: &input[Self::PI_RANGES[12].clone()], + } + } + #[allow(clippy::too_many_arguments)] + pub fn new( + h: &'a [T], + v: &'a [T], + count: &'a [T], + ops: &'a [T], + i: &'a [T], + min: &'a [T], + max: &'a [T], + ids: &'a [T], + min_q: &'a [T], + max_q: &'a [T], + overflow: &'a [T], + ch: &'a [T], + ph: &'a [T], 
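+        // Each slice is expected to have the length recorded in `Self::SIZES`
+        // for the corresponding input; only the first element of `count` and
+        // `overflow` is retained below, as both are single-element inputs.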
+ ) -> Self { + Self { + h, + v, + count: &count[0], + ops, + i, + min, + max, + ids, + min_q, + max_q, + overflow: &overflow[0], + ch, + ph, + } + } + + pub fn to_vec(&self) -> Vec { + self.h + .iter() + .chain(self.v.iter()) + .chain(once(self.count)) + .chain(self.ops.iter()) + .chain(self.i.iter()) + .chain(self.min.iter()) + .chain(self.max.iter()) + .chain(self.ids.iter()) + .chain(self.min_q.iter()) + .chain(self.max_q.iter()) + .chain(once(self.overflow)) + .chain(self.ch.iter()) + .chain(self.ph.iter()) + .cloned() + .collect_vec() + } +} + +impl PublicInputCommon for PublicInputs<'_, Target, S> { + const RANGES: &'static [PublicInputRange] = &Self::PI_RANGES; + + fn register_args(&self, cb: &mut CBuilder) { + cb.register_public_inputs(self.h); + cb.register_public_inputs(self.v); + cb.register_public_input(*self.count); + cb.register_public_inputs(self.ops); + cb.register_public_inputs(self.i); + cb.register_public_inputs(self.min); + cb.register_public_inputs(self.max); + cb.register_public_inputs(self.ids); + cb.register_public_inputs(self.min_q); + cb.register_public_inputs(self.max_q); + cb.register_public_input(*self.overflow); + cb.register_public_inputs(self.ch); + cb.register_public_inputs(self.ph); + } +} + +impl PublicInputs<'_, Target, S> { + pub fn tree_hash_target(&self) -> HashOutTarget { + HashOutTarget::try_from(self.to_hash_raw()).unwrap() // safe to unwrap as we know the slice has correct length + } + /// Return the first output value as a `CurveTarget` + pub fn first_value_as_curve_target(&self) -> CurveTarget { + let targets = self.to_values_raw(); + CurveOrU256Target::from_targets(targets).as_curve_target() + } + + /// Return the first output value as a `UInt256Target` + pub fn first_value_as_u256_target(&self) -> UInt256Target { + let targets = self.to_values_raw(); + CurveOrU256Target::from_targets(targets).as_u256_target() + } + + /// Return the `UInt256` targets for the last `S-1` values + pub fn values_target(&self) -> [UInt256Target; S - 1] { + OutputValuesTarget::from_targets(self.to_values_raw()).other_outputs + } + + /// Return the value as a `UInt256Target` at the specified index + pub fn value_target_at_index(&self, i: usize) -> UInt256Target + where + [(); S - 1]:, + { + OutputValuesTarget::from_targets(self.to_values_raw()).value_target_at_index(i) + } + + pub fn num_matching_rows_target(&self) -> Target { + *self.to_count_raw() + } + + pub fn operation_ids_target(&self) -> [Target; S] { + self.to_ops_raw().try_into().unwrap() + } + + pub fn index_value_target(&self) -> UInt256Target { + UInt256Target::from_targets(self.to_index_value_raw()) + } + + pub fn min_value_target(&self) -> UInt256Target { + UInt256Target::from_targets(self.to_min_value_raw()) + } + + pub fn max_value_target(&self) -> UInt256Target { + UInt256Target::from_targets(self.to_max_value_raw()) + } + + pub fn index_ids_target(&self) -> [Target; 2] { + self.to_index_ids_raw().try_into().unwrap() + } + + pub fn min_query_target(&self) -> UInt256Target { + UInt256Target::from_targets(self.to_min_query_raw()) + } + + pub fn max_query_target(&self) -> UInt256Target { + UInt256Target::from_targets(self.to_max_query_raw()) + } + + pub fn overflow_flag_target(&self) -> BoolTarget { + BoolTarget::new_unsafe(*self.to_overflow_raw()) + } + + pub fn computational_hash_target(&self) -> HashOutTarget { + HashOutTarget::try_from(self.to_computational_hash_raw()).unwrap() // safe to unwrap as we know the slice has correct length + } + + pub fn placeholder_hash_target(&self) -> HashOutTarget { 
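+        // The unwrap below cannot fail, since the `PlaceholderHash` range spans
+        // exactly `NUM_HASH_OUT_ELTS` elements by construction of `SIZES`.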
+ HashOutTarget::try_from(self.to_placeholder_hash_raw()).unwrap() // safe to unwrap as we know the slice has correct length + } +} + +impl PublicInputs<'_, F, S> +where + [(); S - 1]:, +{ + pub fn tree_hash(&self) -> HashOut { + HashOut::try_from(self.to_hash_raw()).unwrap() // safe to unwrap as we know the slice has correct length + } + + pub fn first_value_as_curve_point(&self) -> WeierstrassPoint { + OutputValues::::from_fields(self.to_values_raw()).first_value_as_curve_point() + } + + pub fn first_value_as_u256(&self) -> U256 { + OutputValues::::from_fields(self.to_values_raw()).first_value_as_u256() + } + + pub fn values(&self) -> [U256; S - 1] { + OutputValues::::from_fields(self.to_values_raw()).other_outputs + } + + /// Return the value as a UInt256 at the specified index + pub fn value_at_index(&self, i: usize) -> U256 + where + [(); S - 1]:, + { + OutputValues::::from_fields(self.to_values_raw()).value_at_index(i) + } + + pub fn num_matching_rows(&self) -> F { + *self.to_count_raw() + } + + pub fn operation_ids(&self) -> [F; S] { + self.to_ops_raw().try_into().unwrap() + } + + pub fn index_value(&self) -> U256 { + U256::from_fields(self.to_index_value_raw()) + } + + pub fn min_value(&self) -> U256 { + U256::from_fields(self.to_min_value_raw()) + } + + pub fn max_value(&self) -> U256 { + U256::from_fields(self.to_max_value_raw()) + } + + pub fn index_ids(&self) -> [F; 2] { + self.to_index_ids_raw().try_into().unwrap() + } + + pub fn min_query_value(&self) -> U256 { + U256::from_fields(self.to_min_query_raw()) + } + + pub fn max_query_value(&self) -> U256 { + U256::from_fields(self.to_max_query_raw()) + } + + pub fn overflow_flag(&self) -> bool { + (*self.to_overflow_raw()) + .try_into_bool() + .expect("overflow flag public input different from 0 or 1") + } + + pub fn computational_hash(&self) -> HashOut { + HashOut::try_from(self.to_computational_hash_raw()).unwrap() // safe to unwrap as we know the slice has correct length + } + + pub fn placeholder_hash(&self) -> HashOut { + HashOut::try_from(self.to_placeholder_hash_raw()).unwrap() // safe to unwrap as we know the slice has correct length + } +} + +#[cfg(test)] +mod tests { + + use mp2_common::{public_inputs::PublicInputCommon, utils::ToFields, C, D, F}; + use mp2_test::{ + circuit::{run_circuit, UserCircuit}, + utils::random_vector, + }; + use plonky2::{ + iop::{ + target::Target, + witness::{PartialWitness, WitnessWrite}, + }, + plonk::circuit_builder::CircuitBuilder, + }; + + use super::QueryPublicInputs; + + use super::PublicInputs; + + const S: usize = 10; + #[derive(Clone, Debug)] + struct TestPublicInputs<'a> { + pis: &'a [F], + } + + impl UserCircuit for TestPublicInputs<'_> { + type Wires = Vec; + + fn build(c: &mut CircuitBuilder) -> Self::Wires { + let targets = c.add_virtual_target_arr::<{ PublicInputs::::total_len() }>(); + let pi_targets = PublicInputs::::from_slice(targets.as_slice()); + pi_targets.register_args(c); + pi_targets.to_vec() + } + + fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { + pw.set_target_arr(wires, self.pis) + } + } + + #[test] + fn test_query_public_inputs() { + let pis_raw: Vec = random_vector::(PublicInputs::::total_len()).to_fields(); + let pis = PublicInputs::::from_slice(pis_raw.as_slice()); + // check public inputs are constructed correctly + assert_eq!( + &pis_raw[PublicInputs::::to_range(QueryPublicInputs::TreeHash)], + pis.to_hash_raw(), + ); + assert_eq!( + &pis_raw[PublicInputs::::to_range(QueryPublicInputs::OutputValues)], + pis.to_values_raw(), + ); + 
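+        // `NumMatching` and `Overflow` are single-element ranges, so they are
+        // compared as one-element slices around the raw accessors.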
assert_eq!( + &pis_raw[PublicInputs::::to_range(QueryPublicInputs::NumMatching)], + &[*pis.to_count_raw()], + ); + assert_eq!( + &pis_raw[PublicInputs::::to_range(QueryPublicInputs::OpIds)], + pis.to_ops_raw(), + ); + assert_eq!( + &pis_raw[PublicInputs::::to_range(QueryPublicInputs::IndexValue)], + pis.to_index_value_raw(), + ); + assert_eq!( + &pis_raw[PublicInputs::::to_range(QueryPublicInputs::MinValue)], + pis.to_min_value_raw(), + ); + assert_eq!( + &pis_raw[PublicInputs::::to_range(QueryPublicInputs::MaxValue)], + pis.to_max_value_raw(), + ); + assert_eq!( + &pis_raw[PublicInputs::::to_range(QueryPublicInputs::MinQuery)], + pis.to_min_query_raw(), + ); + assert_eq!( + &pis_raw[PublicInputs::::to_range(QueryPublicInputs::MaxQuery)], + pis.to_max_query_raw(), + ); + assert_eq!( + &pis_raw[PublicInputs::::to_range(QueryPublicInputs::IndexIds)], + pis.to_index_ids_raw(), + ); + assert_eq!( + &pis_raw[PublicInputs::::to_range(QueryPublicInputs::Overflow)], + &[*pis.to_overflow_raw()], + ); + assert_eq!( + &pis_raw[PublicInputs::::to_range(QueryPublicInputs::ComputationalHash)], + pis.to_computational_hash_raw(), + ); + assert_eq!( + &pis_raw[PublicInputs::::to_range(QueryPublicInputs::PlaceholderHash)], + pis.to_placeholder_hash_raw(), + ); + // use public inputs in circuit + let test_circuit = TestPublicInputs { pis: &pis_raw }; + let proof = run_circuit::(test_circuit); + assert_eq!(proof.public_inputs, pis_raw); + } +} diff --git a/verifiable-db/src/revelation/api.rs b/verifiable-db/src/revelation/api.rs index 60c2a5677..04cbd27fb 100644 --- a/verifiable-db/src/revelation/api.rs +++ b/verifiable-db/src/revelation/api.rs @@ -12,7 +12,7 @@ use mp2_common::{ C, D, F, }; use plonky2::plonk::{ - circuit_data::VerifierOnlyCircuitData, config::Hasher, proof::ProofWithPublicInputs, + circuit_data::{VerifierCircuitData, VerifierOnlyCircuitData}, config::Hasher, proof::ProofWithPublicInputs, }; use recursion_framework::{ circuit_builder::{CircuitWithUniversalVerifier, CircuitWithUniversalVerifierBuilder}, @@ -24,14 +24,9 @@ use serde::{Deserialize, Serialize}; use crate::{ query::{ - self, - aggregation::QueryBounds, - api::{CircuitInput as QueryCircuitInput, Parameters as QueryParams}, - computational_hash_ids::ColumnIDs, - pi_len as query_pi_len, - universal_circuit::universal_circuit_inputs::{ + aggregation::QueryBounds, computational_hash_ids::ColumnIDs, pi_len as query_pi_len, universal_circuit::{output_no_aggregation::Circuit as OutputNoAggCircuit, universal_circuit_inputs::{ BasicOperation, Placeholders, ResultStructure, - }, + }, universal_query_circuit::{UniversalCircuitInput, UniversalQueryCircuitParams}} }, revelation::{ placeholders_check::CheckPlaceholderGadget, @@ -43,13 +38,12 @@ use crate::{ }; use super::{ - num_query_io, num_query_io_no_results_tree, pi_len, + pi_len, revelation_unproven_offset::{ - RecursiveCircuitInputs as RecursiveCircuitInputsUnporvenOffset, - RevelationCircuit as RevelationCircuitUnprovenOffset, RowPath, + CircuitBuilderParams, RecursiveCircuitInputs as RecursiveCircuitInputsUnporvenOffset, RevelationCircuit as RevelationCircuitUnprovenOffset, RowPath }, revelation_without_results_tree::{ - CircuitBuilderParams, RecursiveCircuitInputs, RecursiveCircuitWires, + CircuitBuilderParams as CircuitBuilderParamsNoResultsTree, RecursiveCircuitInputs, RecursiveCircuitWires, RevelationWithoutResultsTreeCircuit, }, }; @@ -150,8 +144,7 @@ pub struct Parameters< [(); INDEX_TREE_MAX_DEPTH - 1]:, [(); MAX_NUM_ITEMS_PER_OUTPUT * MAX_NUM_OUTPUTS]:, [(); 2 * 
(MAX_NUM_PREDICATE_OPS + MAX_NUM_RESULT_OPS)]:,
-    [(); num_query_io::()]:,
-    [(); num_query_io_no_results_tree::()]:,
+    [(); query_pi_len::()]:,
 {
     revelation_no_results_tree: CircuitWithUniversalVerifier<
         F,
@@ -222,7 +215,7 @@ pub enum CircuitInput<
         >,
     },
     UnprovenOffset {
-        row_proofs: Vec,
+        row_proofs: Vec>,
         preprocessing_proof: ProofWithPublicInputs,
         revelation_circuit: RevelationCircuitUnprovenOffset<
             ROW_TREE_MAX_DEPTH,
             INDEX_TREE_MAX_DEPTH,
             MAX_NUM_OUTPUTS,
             MAX_NUM_ITEMS_PER_OUTPUT,
             MAX_NUM_PLACEHOLDERS,
             { 2 * (MAX_NUM_PREDICATE_OPS + MAX_NUM_RESULT_OPS) },
         >,
         dummy_row_proof_input: Option<
-            QueryCircuitInput<
+            UniversalCircuitInput<
                 MAX_NUM_COLUMNS,
                 MAX_NUM_PREDICATE_OPS,
                 MAX_NUM_RESULT_OPS,
@@ -268,7 +261,6 @@ where
     [(); INDEX_TREE_MAX_DEPTH - 1]:,
     [(); MAX_NUM_ITEMS_PER_OUTPUT * MAX_NUM_OUTPUTS]:,
     [(); MAX_NUM_ITEMS_PER_OUTPUT - 1]:,
-    [(); query_pi_len::()]:,
     [(); 2 * (MAX_NUM_PREDICATE_OPS + MAX_NUM_RESULT_OPS)]:,
     [(); MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS]:,
 {
@@ -291,7 +283,7 @@ where
     ) -> Result {
         let query_proof = ProofWithVK::deserialize(&query_proof)?;
         let preprocessing_proof = deserialize_proof(&preprocessing_proof)?;
-        let placeholder_hash_ids = query::api::CircuitInput::<
+        let placeholder_hash_ids = UniversalCircuitInput::<
             MAX_NUM_COLUMNS,
             MAX_NUM_PREDICATE_OPS,
             MAX_NUM_RESULT_OPS,
@@ -309,6 +301,6 @@ where
                 placeholder_hash_ids,
             )?,
         };

         Ok(CircuitInput::NoResultsTree {
             query_proof,
@@ -377,10 +370,10 @@ where
             .map(|(i, row)| {
                 row_paths[i] = row.path.clone();
                 result_values[i] = row.result.clone();
-                ProofWithVK::deserialize(&row.proof)
+                deserialize_proof(&row.proof)
             })
             .collect::>>()?;
-        let placeholder_hash_ids = query::api::CircuitInput::<
+        let placeholder_hash_ids = UniversalCircuitInput::<
             MAX_NUM_COLUMNS,
             MAX_NUM_PREDICATE_OPS,
             MAX_NUM_RESULT_OPS,
@@ -396,6 +389,7 @@ where

         let revelation_circuit = RevelationCircuitUnprovenOffset::new(
             row_paths,
+            [column_ids.primary, column_ids.secondary],
             &results_structure.output_ids,
             result_values,
             limit,
@@ -436,20 +430,18 @@ impl<
     >
 where
     [(); MAX_NUM_ITEMS_PER_OUTPUT - 1]:,
-    [(); num_query_io::()]:,
+    [(); query_pi_len::()]:,
     [(); >::HASH_SIZE]:,
     [(); ROW_TREE_MAX_DEPTH - 1]:,
     [(); INDEX_TREE_MAX_DEPTH - 1]:,
     [(); MAX_NUM_ITEMS_PER_OUTPUT * MAX_NUM_OUTPUTS]:,
     [(); MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS]:,
-    [(); query_pi_len::()]:,
     [(); 2 * (MAX_NUM_PREDICATE_OPS + MAX_NUM_RESULT_OPS)]:,
     [(); pi_len::()]:,
-    [(); num_query_io_no_results_tree::()]:,
 {
     pub fn build(
-        _batching_query_circuit_set: &RecursiveCircuits,
         query_circuit_set: &RecursiveCircuits,
+        universal_circuit_vk: VerifierCircuitData,
         preprocessing_circuit_set: &RecursiveCircuits,
         preprocessing_vk: &VerifierOnlyCircuitData,
     ) -> Self {
@@ -458,22 +450,17 @@ where
             D,
             { pi_len::() },
         >::new::(default_config(), REVELATION_CIRCUIT_SET_SIZE);
-        let build_parameters = CircuitBuilderParams {
+        let build_parameters = CircuitBuilderParamsNoResultsTree {
             query_circuit_set: query_circuit_set.clone(),
             preprocessing_circuit_set: preprocessing_circuit_set.clone(),
             preprocessing_vk: preprocessing_vk.clone(),
         };
-        #[cfg(feature = "batching_circuits")]
-        let revelation_no_results_tree = {
-            let batching_build_params = CircuitBuilderParams {
-                query_circuit_set: _batching_query_circuit_set.clone(),
-                preprocessing_circuit_set: preprocessing_circuit_set.clone(),
-                preprocessing_vk: preprocessing_vk.clone(),
-            };
-            builder.build_circuit(batching_build_params)
+        let revelation_no_results_tree = builder.build_circuit(build_parameters);
+        let build_parameters = CircuitBuilderParams {
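+            // For the unproven-offset revelation circuit, row proofs are now
+            // verified against the fixed verifier key of the universal query
+            // circuit (via `verify_proof_fixed_circuit`) rather than against a
+            // recursive circuit set, so the verifier data is passed in directly.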
universal_query_vk: universal_circuit_vk, + preprocessing_circuit_set: preprocessing_circuit_set.clone(), + preprocessing_vk: preprocessing_vk.clone(), }; - #[cfg(not(feature = "batching_circuits"))] - let revelation_no_results_tree = builder.build_circuit(build_parameters.clone()); let revelation_unproven_offset = builder.build_circuit(build_parameters); let circuits = vec![ @@ -490,7 +477,7 @@ where } } - pub fn generate_proof( + pub(crate) fn generate_proof( &self, input: CircuitInput< ROW_TREE_MAX_DEPTH, @@ -502,14 +489,14 @@ where MAX_NUM_ITEMS_PER_OUTPUT, MAX_NUM_PLACEHOLDERS, >, - _batching_query_circuit_set: &RecursiveCircuits, query_circuit_set: &RecursiveCircuits, query_params: Option< - &QueryParams< + &UniversalQueryCircuitParams< MAX_NUM_COLUMNS, MAX_NUM_PREDICATE_OPS, MAX_NUM_RESULT_OPS, MAX_NUM_ITEMS_PER_OUTPUT, + OutputNoAggCircuit, >, >, ) -> Result> { @@ -519,14 +506,6 @@ where preprocessing_proof, revelation_circuit, } => { - #[cfg(feature = "batching_circuits")] - let input = RecursiveCircuitInputs { - inputs: revelation_circuit, - query_proof, - preprocessing_proof, - query_circuit_set: _batching_query_circuit_set.clone(), - }; - #[cfg(not(feature = "batching_circuits"))] let input = RecursiveCircuitInputs { inputs: revelation_circuit, query_proof, @@ -550,8 +529,11 @@ where dummy_row_proof_input, } => { let row_proofs = if let Some(input) = dummy_row_proof_input { - let proof = query_params.unwrap().generate_proof(input)?; - let proof = ProofWithVK::deserialize(&proof)?; + let proof = if let UniversalCircuitInput::QueryNoAgg(input) = input { + query_params.unwrap().generate_proof(&input)? + } else { + unreachable!("Universal circuit should only be used for queries with no aggregation operations") + }; row_proofs .into_iter() .chain(repeat(proof)) @@ -588,12 +570,11 @@ where } #[cfg(test)] -#[cfg(not(feature = "batching_circuits"))] mod tests { - use crate::test_utils::{ + use crate::{query::pi_len as query_pi_len, test_utils::{ TestRevelationData, MAX_NUM_COLUMNS, MAX_NUM_ITEMS_PER_OUTPUT, MAX_NUM_OUTPUTS, MAX_NUM_PLACEHOLDERS, MAX_NUM_PREDICATE_OPS, MAX_NUM_RESULT_OPS, - }; + }}; use itertools::Itertools; use mp2_common::{ array::ToField, @@ -601,7 +582,7 @@ mod tests { types::HashOutput, C, D, F, }; - use mp2_test::log::init_logging; + use mp2_test::{circuit::TestDummyCircuit, log::init_logging}; use plonky2::{ field::types::PrimeField64, hash::hash_types::HashOut, plonk::config::GenericHashOut, }; @@ -615,7 +596,6 @@ mod tests { }, revelation::{ api::{CircuitInput, Parameters}, - num_query_io, tests::compute_results_from_query_proof_outputs, PublicInputs, NUM_PREPROCESSING_IO, }, @@ -636,10 +616,11 @@ mod tests { F, C, D, - { num_query_io::() }, + { query_pi_len::() }, >::default(); let preprocessing_circuits = TestingRecursiveCircuits::::default(); + let dummy_universal_circuit = TestDummyCircuit::<{query_pi_len::()}>::build(); println!("building params"); let params = Parameters::< ROW_TREE_MAX_DEPTH, @@ -651,8 +632,8 @@ mod tests { MAX_NUM_ITEMS_PER_OUTPUT, MAX_NUM_PLACEHOLDERS, >::build( - query_circuits.get_recursive_circuit_set(), // unused, so we use a dummy one query_circuits.get_recursive_circuit_set(), + dummy_universal_circuit.circuit_data().verifier_data(), preprocessing_circuits.get_recursive_circuit_set(), preprocessing_circuits .verifier_data_for_input_proofs::<1>() @@ -664,7 +645,6 @@ mod tests { let test_data = TestRevelationData::sample(42, 76); let query_pi = QueryPI::::from_slice(test_data.query_pi_raw()); - // generate query proof let 
[query_proof] = query_circuits .generate_input_proofs::<1>([test_data.query_pi_raw().try_into().unwrap()]) @@ -679,7 +659,6 @@ mod tests { .unwrap(); let preprocessing_pi = PreprocessingPI::from_slice(&preprocessing_proof.public_inputs); let preprocessing_proof = serialize_proof(&preprocessing_proof).unwrap(); - let input = CircuitInput::new_revelation_aggregated( query_proof, preprocessing_proof, @@ -692,7 +671,6 @@ mod tests { let proof = params .generate_proof( input, - query_circuits.get_recursive_circuit_set(), // unused in this test, so we provide a dummy one query_circuits.get_recursive_circuit_set(), None, ) diff --git a/verifiable-db/src/revelation/mod.rs b/verifiable-db/src/revelation/mod.rs index e45644c1d..d27c49e06 100644 --- a/verifiable-db/src/revelation/mod.rs +++ b/verifiable-db/src/revelation/mod.rs @@ -1,14 +1,10 @@ //! Module including the revelation circuits for query -use crate::{ivc::NUM_IO, query::pi_len as query_pi_len}; +use crate::ivc::NUM_IO; use mp2_common::F; -#[cfg(feature = "batching_circuits")] -use crate::query::batching::circuits::api::num_io as num_batching_io; pub mod api; -#[cfg(feature = "batching_circuits")] -mod batching; pub(crate) mod placeholders_check; mod public_inputs; mod revelation_unproven_offset; @@ -25,19 +21,6 @@ pub const fn pi_len() -> usize } pub const NUM_PREPROCESSING_IO: usize = NUM_IO; -#[cfg(feature = "batching_circuits")] -pub const fn num_query_io_no_results_tree() -> usize { - num_batching_io::() -} - -#[cfg(not(feature = "batching_circuits"))] -pub const fn num_query_io_no_results_tree() -> usize { - query_pi_len::() -} - -pub const fn num_query_io() -> usize { - query_pi_len::() -} #[cfg(test)] pub(crate) mod tests { use super::*; diff --git a/verifiable-db/src/revelation/revelation_unproven_offset.rs b/verifiable-db/src/revelation/revelation_unproven_offset.rs index fc3042f0f..e1c68c5ab 100644 --- a/verifiable-db/src/revelation/revelation_unproven_offset.rs +++ b/verifiable-db/src/revelation/revelation_unproven_offset.rs @@ -16,7 +16,7 @@ use mp2_common::{ default_config, group_hashing::CircuitBuilderGroupHashing, poseidon::{flatten_poseidon_hash_target, H}, - proof::ProofWithVK, + proof::verify_proof_fixed_circuit, public_inputs::PublicInputCommon, serialization::{ deserialize, deserialize_array, deserialize_long_array, serialize, serialize_array, @@ -28,22 +28,20 @@ use mp2_common::{ C, D, F, }; use plonky2::{ - field::types::PrimeField64, hash::hash_types::HashOutTarget, iop::{ target::{BoolTarget, Target}, witness::{PartialWitness, WitnessWrite}, }, plonk::{ - config::Hasher, - proof::{ProofWithPublicInputs, ProofWithPublicInputsTarget}, + circuit_data::{VerifierCircuitData, VerifierOnlyCircuitData}, config::Hasher, proof::{ProofWithPublicInputs, ProofWithPublicInputsTarget} }, }; use plonky2_ecgfp5::gadgets::curve::CircuitBuilderEcGFp5; use recursion_framework::{ circuit_builder::CircuitLogicWires, framework::{ - RecursiveCircuits, RecursiveCircuitsVerifierGagdet, RecursiveCircuitsVerifierTarget, + RecursiveCircuits, RecursiveCircuitsVerifierGagdet, }, }; use serde::{Deserialize, Serialize}; @@ -51,23 +49,16 @@ use serde::{Deserialize, Serialize}; use crate::{ ivc::PublicInputs as OriginalTreePublicInputs, query::{ - aggregation::{ChildPosition, NodeInfo, QueryBounds, QueryHashNonExistenceCircuits}, - api::CircuitInput as QueryCircuitInput, - computational_hash_ids::{AggregationOperation, ColumnIDs, ResultIdentifier}, - merkle_path::{MerklePathGadget, MerklePathTargetInputs}, - pi_len, - public_inputs::PublicInputs as 
QueryProofPublicInputs, - universal_circuit::{ + aggregation::{ChildPosition, NodeInfo, QueryBounds}, public_inputs::PublicInputsUniversalCircuit as QueryProofPublicInputs, computational_hash_ids::{ColumnIDs, ResultIdentifier}, merkle_path::{MerklePathGadget, MerklePathTargetInputs}, universal_circuit::{ build_cells_tree, - universal_circuit_inputs::{BasicOperation, Placeholders, ResultStructure}, - }, + universal_circuit_inputs::{BasicOperation, ColumnCell, Placeholders, ResultStructure, RowCells}, universal_query_circuit::{UniversalCircuitInput, UniversalQueryCircuitInputs}, + } }, }; use super::{ - num_query_io, pi_len as revelation_pi_len, + pi_len as revelation_pi_len, placeholders_check::{CheckPlaceholderGadget, CheckPlaceholderInputWires}, - revelation_without_results_tree::CircuitBuilderParams, PublicInputs, NUM_PREPROCESSING_IO, }; @@ -192,6 +183,7 @@ pub(crate) struct RevelationWires< deserialize_with = "deserialize_array" )] is_row_node_leaf: [BoolTarget; L], + index_column_ids: [Target; 2], #[serde( serialize_with = "serialize_array", deserialize_with = "deserialize_array" @@ -252,6 +244,8 @@ pub struct RevelationCircuit< /// Info about the nodes of the index tree that stores the rows trees where each of /// the L rows being proven are located index_node_info: [NodeInfo; L], + /// Identifiers of the indexed columns + index_column_ids: [F; 2], /// Actual number of items per-row included in the results. num_actual_items_per_row: usize, /// Ids of the output items included in the results for each row @@ -290,6 +284,7 @@ where { pub(crate) fn new( row_paths: [RowPath; L], + index_column_ids: [F; 2], item_ids: &[F], results: [Vec; L], limit: u32, @@ -338,6 +333,7 @@ where index_tree_paths, row_node_info, index_node_info, + index_column_ids, num_actual_items_per_row, ids: padded_ids.try_into().unwrap(), results: results.try_into().unwrap(), @@ -366,8 +362,12 @@ where // computed by the universal query circuit // closure to access the output items of the i-th result let get_result = |i| &results[S * i..S * (i + 1)]; - let [min_query, max_query] = b.add_virtual_u256_arr_unsafe(); // unsafe should be ok since they are later included in placeholder hash + let (min_query_primary, max_query_primary) = ( + row_proofs[0].min_primary_target(), + row_proofs[0].max_primary_target(), + ); let [limit, offset] = b.add_virtual_target_arr(); + let index_column_ids = b.add_virtual_target_arr(); let tree_hash = original_tree_proof.merkle_hash(); let zero = b.zero(); let one = b.one(); @@ -385,7 +385,6 @@ where // this is a requirement to ensure that the check for DISTINCT is sound let mut only_matching_rows = _true; row_proofs.iter().enumerate().for_each(|(i, row_proof)| { - let index_ids = row_proof.index_ids_target(); let is_matching_row = b.is_equal(row_proof.num_matching_rows_target(), one); // ensure that once `is_matching_row = false`, then it will be false for all // subsequent iterations @@ -401,8 +400,8 @@ where .flat_map(|hash| hash.to_targets()) .chain(row_node_info[i].node_min.to_targets()) .chain(row_node_info[i].node_max.to_targets()) - .chain(once(index_ids[1])) - .chain(row_proof.min_value_target().to_targets()) + .chain(once(index_column_ids[1])) + .chain(row_proof.secondary_index_value_target().to_targets()) .chain(row_proof.tree_hash_target().to_targets()) .collect_vec(); let row_node_hash = b.hash_n_to_hash_no_pad::(inputs); @@ -412,7 +411,7 @@ where &row_node_hash, ) }; - let row_path_wires = MerklePathGadget::build(b, row_node_hash, index_ids[1]); + let row_path_wires = 
MerklePathGadget::build(b, row_node_hash, index_column_ids[1]); let row_tree_root = row_path_wires.root; // compute hash of the index node storing the rows tree containing the current row let index_node_hash = { @@ -422,13 +421,13 @@ where .flat_map(|hash| hash.to_targets()) .chain(index_node_info[i].node_min.to_targets()) .chain(index_node_info[i].node_max.to_targets()) - .chain(once(index_ids[0])) - .chain(row_proof.index_value_target().to_targets()) + .chain(once(index_column_ids[0])) + .chain(row_proof.primary_index_value_target().to_targets()) .chain(row_tree_root.to_targets()) .collect_vec(); b.hash_n_to_hash_no_pad::(inputs) }; - let index_path_wires = MerklePathGadget::build(b, index_node_hash, index_ids[0]); + let index_path_wires = MerklePathGadget::build(b, index_node_hash, index_column_ids[0]); // if the current row is valid, check that the root is the same of the original tree, completing // membership proof for the current row; otherwise, we don't care let root = b.select_hash(is_matching_row, &index_path_wires.root, &tree_hash); @@ -436,14 +435,6 @@ where row_paths.push(row_path_wires.inputs); index_paths.push(index_path_wires.inputs); - // check that the primary index value for the current row is within the query - // bounds (only if the row is valid) - let index_value = row_proof.index_value_target(); - let greater_than_min = b.is_less_or_equal_than_u256(&min_query, &index_value); - let smaller_than_max = b.is_less_or_equal_than_u256(&index_value, &max_query); - let in_range = b.and(greater_than_min, smaller_than_max); - let in_range = b.and(is_matching_row, in_range); - b.connect(in_range.target, is_matching_row.target); // enforce DISTINCT only for actual results: we enforce the i-th actual result is strictly smaller // than the (i+1)-th actual result @@ -489,6 +480,15 @@ where // the proofs b.connect_hashes(row_proof.computational_hash_target(), computational_hash); b.connect_hashes(row_proof.placeholder_hash_target(), placeholder_hash); + // check that query bounds on primary index are the same for all the proofs + b.enforce_equal_u256( + &row_proof.min_primary_target(), + &min_query_primary, + ); + b.enforce_equal_u256( + &row_proof.max_primary_target(), + &max_query_primary, + ); overflow = b.or(overflow, row_proof.overflow_flag_target()); }); @@ -499,19 +499,19 @@ where let inputs = placeholder_hash .to_targets() .into_iter() - .chain(min_query.to_targets()) - .chain(max_query.to_targets()) + .chain(min_query_primary.to_targets()) + .chain(max_query_primary.to_targets()) .collect_vec(); b.hash_n_to_hash_no_pad::(inputs) }; let check_placeholder_wires = CheckPlaceholderGadget::build(b, &final_placeholder_hash); b.enforce_equal_u256( - &min_query, + &min_query_primary, &check_placeholder_wires.input_wires.placeholder_values[0], ); b.enforce_equal_u256( - &max_query, + &max_query_primary, &check_placeholder_wires.input_wires.placeholder_values[1], ); @@ -566,6 +566,7 @@ where row_node_info, index_node_info, is_row_node_leaf, + index_column_ids, is_item_included, ids, results, @@ -614,6 +615,7 @@ where .zip(wires.results.iter()) .for_each(|(&value, target)| pw.set_u256_target(target, value)); pw.set_target_arr(&wires.ids, &self.ids); + pw.set_target_arr(&wires.index_column_ids, &self.index_column_ids); pw.set_target(wires.limit, self.limit.to_field()); pw.set_target(wires.offset, self.offset.to_field()); pw.set_bool_target(wires.distinct, self.distinct); @@ -637,7 +639,7 @@ pub(crate) fn generate_dummy_row_proof_inputs< placeholders: &Placeholders, query_bounds: 
&QueryBounds, ) -> Result< - QueryCircuitInput< + UniversalCircuitInput< MAX_NUM_COLUMNS, MAX_NUM_PREDICATE_OPS, MAX_NUM_RESULT_OPS, @@ -648,57 +650,49 @@ where [(); MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS]:, [(); 2 * (MAX_NUM_PREDICATE_OPS + MAX_NUM_RESULT_OPS)]:, [(); MAX_NUM_ITEMS_PER_OUTPUT - 1]:, - [(); pi_len::()]:, [(); >::HASH_SIZE]:, { - // we generate a dummy proof for a dummy node of the index tree with an index value out of range - let query_hashes = QueryHashNonExistenceCircuits::new::< - MAX_NUM_COLUMNS, - MAX_NUM_PREDICATE_OPS, - MAX_NUM_RESULT_OPS, - MAX_NUM_ITEMS_PER_OUTPUT, - >( - column_ids, + // we generate dummy column cells; we can use all dummy values, except for the + // primary index value which must be in the query range + let primary_index_value = query_bounds.min_query_primary(); + let primary_index_column = ColumnCell { + value: primary_index_value, + id: column_ids.primary, + }; + let secondary_index_column = ColumnCell { + value: U256::default(), + id: column_ids.secondary, + }; + let non_indexed_columns = column_ids.non_indexed_columns().iter().map(|id| + ColumnCell::new(*id, U256::default()) + ).collect_vec(); + let cells = RowCells::new( + primary_index_column, + secondary_index_column, + non_indexed_columns + ); + let universal_query_circuit = UniversalQueryCircuitInputs::new( + &cells, predicate_operations, - results, placeholders, - query_bounds, - false, - )?; - // we generate info about the proven index-tree node; we can use all dummy values, except for the - // node value which must be out of the query range - let node_value = query_bounds.max_query_primary() + U256::from(1); - let node_info = NodeInfo::new( - &HashOutput::default(), - None, // no children, for simplicity - None, - node_value, - U256::default(), - U256::default(), - ); - // The query has no aggregation operations, so by construction of the circuits we - // know that the first aggregate operation is ID, while the remaining ones are dummies - let aggregation_ops = once(AggregationOperation::IdOp) - .chain(repeat(AggregationOperation::default())) - .take(MAX_NUM_ITEMS_PER_OUTPUT) - .collect_vec(); - QueryCircuitInput::new_non_existence_input( - node_info, - None, - None, - node_value, - &[ - column_ids.primary.to_canonical_u64(), - column_ids.secondary.to_canonical_u64(), - ], - &aggregation_ops, - query_hashes, false, query_bounds, - placeholders, + results, + true, // we generate proof for a dummy row + )?; + Ok( + UniversalCircuitInput::QueryNoAgg( + universal_query_circuit + ) ) } +pub struct CircuitBuilderParams { + pub(crate) universal_query_vk: VerifierCircuitData, + pub(crate) preprocessing_circuit_set: RecursiveCircuits, + pub(crate) preprocessing_vk: VerifierOnlyCircuitData, +} + #[derive(Serialize, Deserialize, Clone, Debug)] pub struct RecursiveCircuitWires< const ROW_TREE_MAX_DEPTH: usize, @@ -714,10 +708,10 @@ pub struct RecursiveCircuitWires< { revelation_circuit: RevelationWires, #[serde( - serialize_with = "serialize_long_array", - deserialize_with = "deserialize_long_array" + serialize_with = "serialize_array", + deserialize_with = "deserialize_array" )] - row_verifiers: [RecursiveCircuitsVerifierTarget; L], + row_verifiers: [ProofWithPublicInputsTarget; L], #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] preprocessing_proof: ProofWithPublicInputsTarget, } @@ -740,7 +734,7 @@ pub struct RecursiveCircuitInputs< serialize_with = "serialize_long_array", deserialize_with = "deserialize_long_array" )] - pub(crate) row_proofs: [ProofWithVK; L], + 
pub(crate) row_proofs: [ProofWithPublicInputs; L], pub(crate) preprocessing_proof: ProofWithPublicInputs, pub(crate) query_circuit_set: RecursiveCircuits, } @@ -758,7 +752,6 @@ where [(); ROW_TREE_MAX_DEPTH - 1]:, [(); INDEX_TREE_MAX_DEPTH - 1]:, [(); S * L]:, - [(); num_query_io::()]:, [(); >::HASH_SIZE]:, { type CircuitBuilderParams = CircuitBuilderParams; @@ -772,11 +765,10 @@ where _verified_proofs: [&ProofWithPublicInputsTarget; 0], builder_parameters: Self::CircuitBuilderParams, ) -> Self { - let row_verifier = RecursiveCircuitsVerifierGagdet::() }>::new( - default_config(), - &builder_parameters.query_circuit_set, - ); - let row_verifiers = [0; L].map(|_| row_verifier.verify_proof_in_circuit_set(builder)); + let row_verifiers = [0; L].map(|_| verify_proof_fixed_circuit( + builder, + &builder_parameters.universal_query_vk, + )); let preprocessing_verifier = RecursiveCircuitsVerifierGagdet::::new( default_config(), @@ -790,7 +782,7 @@ where .iter() .map(|verifier| { QueryProofPublicInputs::from_slice( - verifier.get_public_input_targets::() }>(), + &verifier.public_inputs, ) }) .collect_vec(); @@ -808,8 +800,7 @@ where fn assign_input(&self, inputs: Self::Inputs, pw: &mut PartialWitness) -> Result<()> { for (verifier_target, row_proof) in self.row_verifiers.iter().zip(inputs.row_proofs) { - let (proof, verifier_data) = (&row_proof).into(); - verifier_target.set_target(pw, &inputs.query_circuit_set, proof, verifier_data)?; + pw.set_proof_with_pis_target(verifier_target, &row_proof); } pw.set_proof_with_pis_target(&self.preprocessing_proof, &inputs.preprocessing_proof); inputs.inputs.assign(pw, &self.revelation_circuit); @@ -854,13 +845,13 @@ mod tests { }, query::{ aggregation::{ChildPosition, NodeInfo}, - public_inputs::{PublicInputs as QueryProofPublicInputs, QueryPublicInputs}, + public_inputs::{PublicInputsUniversalCircuit as QueryProofPublicInputs, QueryPublicInputsUniversalCircuit}, pi_len as query_pi_len, }, revelation::{ - num_query_io, revelation_unproven_offset::RowPath, tests::TestPlaceholders, + revelation_unproven_offset::RowPath, tests::TestPlaceholders, NUM_PREPROCESSING_IO, }, - test_utils::{random_aggregation_operations, random_aggregation_public_inputs}, + test_utils::random_aggregation_operations, }; use super::{RevelationCircuit, RevelationWires}; @@ -907,7 +898,7 @@ mod tests { fn build(c: &mut CircuitBuilder) -> Self::Wires { let row_pis_raw: [Vec; L] = (0..L) - .map(|_| c.add_virtual_targets(num_query_io::())) + .map(|_| c.add_virtual_targets(query_pi_len::())) .collect_vec() .try_into() .unwrap(); @@ -943,51 +934,41 @@ mod tests { const PH: usize = 10; const PP: usize = 30; let ops = random_aggregation_operations::(); - let mut row_pis = random_aggregation_public_inputs(&ops); + let mut row_pis = QueryProofPublicInputs::sample_from_ops(&ops); let rng = &mut thread_rng(); let mut original_tree_pis = (0..NUM_PREPROCESSING_IO) .map(|_| rng.gen()) .collect::>() .to_fields(); + let index_ids = F::rand_array(); const NUM_PLACEHOLDERS: usize = 5; let test_placeholders = TestPlaceholders::sample(NUM_PLACEHOLDERS); - let (index_ids, computational_hash) = { + let computational_hash = { let row_pi_0 = QueryProofPublicInputs::<_, S>::from_slice(&row_pis[0]); - let index_ids = row_pi_0.index_ids(); - let computational_hash = row_pi_0.computational_hash(); - - (index_ids, computational_hash) + row_pi_0.computational_hash() }; let placeholder_hash = test_placeholders.query_placeholder_hash; - // set same index_ids, computational hash and placeholder hash for all proofs; set 
also num matching rows to 1 - // for all proofs + let min_query_primary = test_placeholders.min_query; + let max_query_primary = test_placeholders.max_query; + // set same primary index query bounds, computational hash and placeholder hash for all proofs; + // set also num matching rows to 1 for all proofs row_pis.iter_mut().for_each(|pis| { - let [index_id_range, ch_range, ph_range, count_range] = [ - QueryPublicInputs::IndexIds, - QueryPublicInputs::ComputationalHash, - QueryPublicInputs::PlaceholderHash, - QueryPublicInputs::NumMatching, + let [min_primary_range, max_primary_range, ch_range, ph_range, count_range] = [ + QueryPublicInputsUniversalCircuit::MinPrimary, + QueryPublicInputsUniversalCircuit::MaxPrimary, + QueryPublicInputsUniversalCircuit::ComputationalHash, + QueryPublicInputsUniversalCircuit::PlaceholderHash, + QueryPublicInputsUniversalCircuit::NumMatching, ] .map(QueryProofPublicInputs::::to_range); - pis[index_id_range].copy_from_slice(&index_ids); + pis[min_primary_range].copy_from_slice(&min_query_primary.to_fields()); + pis[max_primary_range].copy_from_slice(&max_query_primary.to_fields()); pis[ch_range].copy_from_slice(&computational_hash.to_fields()); pis[ph_range].copy_from_slice(&placeholder_hash.to_fields()); pis[count_range].copy_from_slice(&[F::ONE]); }); - let index_value_range = - QueryProofPublicInputs::::to_range(QueryPublicInputs::IndexValue); - let hash_range = QueryProofPublicInputs::::to_range(QueryPublicInputs::TreeHash); - let min_query = test_placeholders.min_query; - let max_query = test_placeholders.max_query; - // closure that modifies a set of row public inputs to ensure that the index value lies - // within the query bounds; the new index value set in the public inputs is returned by the closure - let enforce_index_value_in_query_range = |pis: &mut [F], index_value: U256| { - let query_range_size = max_query - min_query + U256::from(1); - let new_index_value = min_query + index_value % query_range_size; - pis[index_value_range.clone()].copy_from_slice(&new_index_value.to_fields()); - assert!(new_index_value >= min_query && new_index_value <= max_query); - new_index_value - }; + let hash_range = QueryProofPublicInputs::::to_range(QueryPublicInputsUniversalCircuit::TreeHash); + let index_value_range = QueryProofPublicInputs::::to_range(QueryPublicInputsUniversalCircuit::PrimaryIndexValue); // build a test tree containing the rows 0..5 found in row_pis // Index tree: // A @@ -1004,7 +985,7 @@ mod tests { let row_pi = QueryProofPublicInputs::<_, S>::from_slice(&row_pis[1]); let embedded_tree_hash = HashOutput::try_from(gen_random_field_hash::().to_bytes()).unwrap(); - let node_value = row_pi.min_value(); + let node_value = row_pi.secondary_index_value(); NodeInfo::new( &embedded_tree_hash, None, @@ -1019,10 +1000,10 @@ mod tests { row_pis[1][hash_range.clone()].copy_from_slice(&node_1_hash.to_fields()); let node_0 = { let row_pi = QueryProofPublicInputs::<_, S>::from_slice(&row_pis[0]); - let embedded_tree_hash = HashOutput::try_from(row_pi.tree_hash().to_bytes()).unwrap(); - let node_value = row_pi.min_value(); + let embedded_tree_hash = HashOutput::from(row_pi.tree_hash()); + let node_value = row_pi.secondary_index_value(); // left child is node 1 - let left_child_hash = HashOutput::try_from(node_1_hash.to_bytes()).unwrap(); + let left_child_hash = HashOutput::from(node_1_hash); NodeInfo::new( &embedded_tree_hash, Some(&left_child_hash), @@ -1035,8 +1016,8 @@ mod tests { let node_2 = { let row_pi = QueryProofPublicInputs::<_, 
S>::from_slice(&row_pis[2]); let embedded_tree_hash = - HashOutput::try_from(gen_random_field_hash::().to_bytes()).unwrap(); - let node_value = row_pi.min_value(); + HashOutput::from(gen_random_field_hash::()); + let node_value = row_pi.secondary_index_value(); NodeInfo::new( &embedded_tree_hash, None, @@ -1052,8 +1033,8 @@ mod tests { let node_4 = { let row_pi = QueryProofPublicInputs::<_, S>::from_slice(&row_pis[4]); let embedded_tree_hash = - HashOutput::try_from(gen_random_field_hash::().to_bytes()).unwrap(); - let node_value = row_pi.min_value(); + HashOutput::from(gen_random_field_hash::()); + let node_value = row_pi.secondary_index_value(); NodeInfo::new( &embedded_tree_hash, None, @@ -1069,7 +1050,7 @@ mod tests { let node_5 = { // can use all dummy values for this node, since there is no proof associated to it let embedded_tree_hash = - HashOutput::try_from(gen_random_field_hash::().to_bytes()).unwrap(); + HashOutput::from(gen_random_field_hash::()); let [node_value, node_min, node_max] = array::from_fn(|_| gen_random_u256(rng)); NodeInfo::new( &embedded_tree_hash, @@ -1080,13 +1061,13 @@ mod tests { node_max, ) }; - let node_4_hash = HashOutput::try_from(node_4_hash.to_bytes()).unwrap(); + let node_4_hash = HashOutput::from(node_4_hash); let node_5_hash = - HashOutput::try_from(node_5.compute_node_hash(index_ids[1]).to_bytes()).unwrap(); + HashOutput::from(node_5.compute_node_hash(index_ids[1])); let node_3 = { let row_pi = QueryProofPublicInputs::<_, S>::from_slice(&row_pis[3]); - let embedded_tree_hash = HashOutput::try_from(row_pi.tree_hash().to_bytes()).unwrap(); - let node_value = row_pi.min_value(); + let embedded_tree_hash = HashOutput::from(row_pi.tree_hash()); + let node_value = row_pi.secondary_index_value(); NodeInfo::new( &embedded_tree_hash, Some(&node_4_hash), // left child is node 4 @@ -1099,9 +1080,8 @@ mod tests { let node_b = { let row_pi = QueryProofPublicInputs::<_, S>::from_slice(&row_pis[2]); let embedded_tree_hash = - HashOutput::try_from(node_2.compute_node_hash(index_ids[1]).to_bytes()).unwrap(); - let index_value = row_pi.index_value(); - let node_value = enforce_index_value_in_query_range(&mut row_pis[2], index_value); + HashOutput::from(node_2.compute_node_hash(index_ids[1])); + let node_value = row_pi.primary_index_value(); NodeInfo::new( &embedded_tree_hash, None, @@ -1112,13 +1092,13 @@ mod tests { ) }; let node_c = { - let row_pi = QueryProofPublicInputs::<_, S>::from_slice(&row_pis[4]); + let row_pi = QueryProofPublicInputs::<_, S>::from_slice(&row_pis[3]); let embedded_tree_hash = - HashOutput::try_from(node_3.compute_node_hash(index_ids[1]).to_bytes()).unwrap(); - let index_value = row_pi.index_value(); - let node_value = enforce_index_value_in_query_range(&mut row_pis[4], index_value); - // we need also to set index value PI in row_pis[3] to the same value of row_pis[4], as they are in the same index tree - row_pis[3][index_value_range.clone()].copy_from_slice(&node_value.to_fields()); + HashOutput::from(node_3.compute_node_hash(index_ids[1])); + let node_value = row_pi.primary_index_value(); + // we need to set index value in `row_pis[4]` to the same value of `row_pis[3]`, as + // they are in the same index tree + row_pis[4][index_value_range.clone()].copy_from_slice(&node_value.to_fields()); NodeInfo::new( &embedded_tree_hash, None, @@ -1129,17 +1109,17 @@ mod tests { ) }; let node_b_hash = - HashOutput::try_from(node_b.compute_node_hash(index_ids[0]).to_bytes()).unwrap(); + HashOutput::from(node_b.compute_node_hash(index_ids[0])); let 
node_c_hash = - HashOutput::try_from(node_c.compute_node_hash(index_ids[0]).to_bytes()).unwrap(); + HashOutput::from(node_c.compute_node_hash(index_ids[0])); let node_a = { let row_pi = QueryProofPublicInputs::<_, S>::from_slice(&row_pis[0]); let embedded_tree_hash = - HashOutput::try_from(node_0.compute_node_hash(index_ids[1]).to_bytes()).unwrap(); - let index_value = row_pi.index_value(); - let node_value = enforce_index_value_in_query_range(&mut row_pis[0], index_value); - // we need also to set index value PI in row_pis[1] to the same value of row_pis[0], as they are in the same index tree - row_pis[1][index_value_range].copy_from_slice(&node_value.to_fields()); + HashOutput::from(node_0.compute_node_hash(index_ids[1])); + let node_value = row_pi.primary_index_value(); + // we need to set index value in `row_pis[1]` to the same value of `row_pis[0]`, as + // they are in the same index tree + row_pis[1][index_value_range].copy_from_slice(&node_value.to_fields()); NodeInfo::new( &embedded_tree_hash, Some(&node_b_hash), // left child is node B @@ -1203,7 +1183,7 @@ mod tests { row_pis.iter_mut().zip(digests).for_each(|(pis, digest)| { let values_range = - QueryProofPublicInputs::::to_range(QueryPublicInputs::OutputValues); + QueryProofPublicInputs::::to_range(QueryPublicInputsUniversalCircuit::OutputValues); pis[values_range.start..values_range.start + CURVE_TARGET_LEN] .copy_from_slice(&digest.to_fields()) }); @@ -1254,6 +1234,7 @@ mod tests { TestRevelationCircuit:: { circuit: RevelationCircuit::new( [row_path_0, row_path_1, row_path_2, row_path_3, row_path_4], + index_ids, &ids, results.map(|res| res.to_vec()), 0, diff --git a/verifiable-db/src/revelation/revelation_without_results_tree.rs b/verifiable-db/src/revelation/revelation_without_results_tree.rs index 52f500813..c37e8c0c1 100644 --- a/verifiable-db/src/revelation/revelation_without_results_tree.rs +++ b/verifiable-db/src/revelation/revelation_without_results_tree.rs @@ -3,12 +3,7 @@ use crate::{ ivc::PublicInputs as OriginalTreePublicInputs, query::{ - computational_hash_ids::AggregationOperation, - public_inputs::PublicInputs as QueryProofPublicInputs, - universal_circuit::{ - universal_query_gadget::OutputValuesTarget, ComputationalHashTarget, - MembershipHashTarget, PlaceholderHashTarget, - }, + public_inputs::PublicInputs as QueryProofPublicInputs, computational_hash_ids::AggregationOperation, pi_len as query_pi_len, }, revelation::PublicInputs, }; @@ -23,7 +18,7 @@ use mp2_common::{ serialization::{deserialize, serialize}, types::CBuilder, u256::{CircuitBuilderU256, UInt256Target}, - utils::{FromTargets, ToTargets}, + utils::ToTargets, C, D, F, }; use plonky2::{ @@ -47,16 +42,11 @@ use recursion_framework::{ use serde::{Deserialize, Serialize}; use super::{ - num_query_io, num_query_io_no_results_tree, pi_len as revelation_pi_len, + pi_len as revelation_pi_len, placeholders_check::{CheckPlaceholderGadget, CheckPlaceholderInputWires}, NUM_PREPROCESSING_IO, }; -#[cfg(feature = "batching_circuits")] -use super::batching::RevelationCircuitBatching; -#[cfg(feature = "batching_circuits")] -use crate::query::batching::public_inputs::PublicInputs as BatchingPublicInputs; - // L: maximum number of results // S: maximum number of items in each result // PH: maximum number of unique placeholder IDs and values bound for query @@ -81,46 +71,6 @@ pub struct RevelationWithoutResultsTreeCircuit< pub(crate) check_placeholder: CheckPlaceholderGadget, } -/// Data structure containing the wires corresponding to public inputs -/// of the 
query proof employed in the revelation circuit. It is a -/// data structure employed to represent the public inputs of the query -/// proof in a revelation circuit both for query proofs generated with -/// batching circuits, and for query proofs generated with non-batching -/// circuits -pub(crate) struct QueryProofInputWires -where - [(); S - 1]:, -{ - pub(crate) tree_hash: MembershipHashTarget, - pub(crate) results: OutputValuesTarget, - pub(crate) entry_count: Target, - pub(crate) overflow: Target, - pub(crate) placeholder_hash: PlaceholderHashTarget, - pub(crate) computational_hash: ComputationalHashTarget, - pub(crate) min_primary: UInt256Target, - pub(crate) max_primary: UInt256Target, - pub(crate) ops: [Target; S], -} - -impl<'a, const S: usize> From<&'a QueryProofPublicInputs<'a, Target, S>> for QueryProofInputWires -where - [(); S - 1]:, -{ - fn from(value: &'a QueryProofPublicInputs) -> Self { - Self { - tree_hash: value.tree_hash_target(), - results: OutputValuesTarget::from_targets(value.to_values_raw()), - entry_count: value.num_matching_rows_target(), - overflow: value.overflow_flag_target().target, - placeholder_hash: value.placeholder_hash_target(), - computational_hash: value.computational_hash_target(), - min_primary: value.min_query_target(), - max_primary: value.max_query_target(), - ops: value.operation_ids_target(), - } - } -} - impl RevelationWithoutResultsTreeCircuit where @@ -132,18 +82,6 @@ where query_proof: &QueryProofPublicInputs, // proof of construction of the original tree in the pre-processing stage (IVC proof) original_tree_proof: &OriginalTreePublicInputs, - ) -> RevelationWithoutResultsTreeWires { - Self::build_core(b, query_proof.into(), original_tree_proof) - } - - // Internal build method taking as input `QueryProofInputWires` instead of the public inputs of the - // query proof circuits without batching. The core logic is placed in this method because it is shared - // with the `RevelationCircuitBatchingCircuits` circuit - pub(crate) fn build_core( - b: &mut CBuilder, - query_proof: QueryProofInputWires, - // proof of construction of the original tree in the pre-processing stage (IVC proof) - original_tree_proof: &OriginalTreePublicInputs, ) -> RevelationWithoutResultsTreeWires { let zero = b.zero(); let u256_zero = b.zero_u256(); @@ -153,16 +91,16 @@ where .map(|op| b.constant(op.to_field())); // Convert the entry count to an Uint256. - let entry_count = UInt256Target::new_from_target(b, query_proof.entry_count); + let entry_count = UInt256Target::new_from_target(b, query_proof.num_matching_rows_target()); // Compute the output results array, and deal with AVG and COUNT operations if any. let mut results = Vec::with_capacity(L * S); // flag to determine whether entry count is zero let is_entry_count_zero = b.add_virtual_bool_target_unsafe(); - query_proof.ops.into_iter().enumerate().for_each(|(i, op)| { + query_proof.operation_ids_target().into_iter().enumerate().for_each(|(i, op)| { let is_op_avg = b.is_equal(op, op_avg); let is_op_count = b.is_equal(op, op_count); - let result = query_proof.results.value_target_at_index(i); + let result = query_proof.value_target_at_index(i); // Compute the AVG result (and it's set to zero if the divisor is zero). 
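            // `div_u256` returns the quotient in `avg_result`, an ignored middle value, and
            // `is_divisor_zero`, which is set when `entry_count` is zero; that flag is what
            // lets the circuit reveal 0 for AVG over an empty set instead of relying on an
            // unconstrained division result.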
let (avg_result, _, is_divisor_zero) = b.div_u256(&result, &entry_count); @@ -180,11 +118,11 @@ where // `check_placeholders` function: // H(pQ.H_p || pQ.MIN_I || pQ.MAX_I) let inputs = query_proof - .placeholder_hash + .placeholder_hash_target() .to_targets() .into_iter() - .chain(query_proof.min_primary.to_targets()) - .chain(query_proof.max_primary.to_targets()) + .chain(query_proof.min_primary_target().to_targets()) + .chain(query_proof.max_primary_target().to_targets()) .collect(); let final_placeholder_hash = b.hash_n_to_hash_no_pad::(inputs); @@ -194,13 +132,13 @@ where // Check that the tree employed to build the queries is the same as the // tree constructed in pre-processing. - b.connect_hashes(query_proof.tree_hash, original_tree_proof.merkle_hash()); + b.connect_hashes(query_proof.tree_hash_target(), original_tree_proof.merkle_hash()); // Add the hash of placeholder identifiers and pre-processing metadata // hash to the computational hash: // H(pQ.C || placeholder_ids_hash || pQ.M) let inputs = query_proof - .computational_hash + .computational_hash_target() .to_targets() .iter() .chain(&check_placeholder_wires.placeholder_id_hash.to_targets()) @@ -221,6 +159,97 @@ where let flat_computational_hash = flatten_poseidon_hash_target(b, computational_hash); + // additional constraints on boundary rows to ensure completeness of proven rows + // (i.e., that we look at all the rows with primary and secondary index values in the query range) + + let left_boundary_row = query_proof.left_boundary_row_target(); + + // 1. Either the index tree node of left boundary row has no predecessor, or + // the value of the predecessor is smaller than MIN_primary + let smaller_than_min_primary = b.is_less_than_u256( + &left_boundary_row.index_node_info.predecessor_info.value, + &query_proof.min_primary_target(), + ); + // assert not pQ.left_boundary_row.index_node_data.predecessor_info.is_found or + // pQ.left_boundary_row.index_node_data.predecessor_value < pQ.MIN_primary + let constraint = b.and( + left_boundary_row.index_node_info.predecessor_info.is_found, + smaller_than_min_primary, + ); + b.connect( + left_boundary_row + .index_node_info + .predecessor_info + .is_found + .target, + constraint.target, + ); + + // 2. Either the rows tree node storing left boundary row has no predecessor, or + // the value of the predecessor is smaller than MIN_secondary + let smaller_than_min_secondary = b.is_less_than_u256( + &left_boundary_row.row_node_info.predecessor_info.value, + &query_proof.min_secondary_target(), + ); + // assert not pQ.left_boundary_row.row_node_data.predecessor_info.is_found or + // pQ.left_boundary_row.row_node_data.predecessor_value < pQ.MIN_secondary + let constraint = b.and( + left_boundary_row.row_node_info.predecessor_info.is_found, + smaller_than_min_secondary, + ); + b.connect( + left_boundary_row + .row_node_info + .predecessor_info + .is_found + .target, + constraint.target, + ); + + let right_boundary_row = query_proof.right_boundary_row_target(); + + // 3. 
Either the index tree node of right boundary row has no successor, or
+        //    the value of the successor is greater than MAX_primary
+        let greater_than_max_primary = b.is_greater_than_u256(
+            &right_boundary_row.index_node_info.successor_info.value,
+            &query_proof.max_primary_target(),
+        );
+        // assert not pQ.right_boundary_row.index_node_data.successor_info.is_found or
+        // pQ.right_boundary_row.index_node_data.successor_value > pQ.MAX_primary
+        let constraint = b.and(
+            right_boundary_row.index_node_info.successor_info.is_found,
+            greater_than_max_primary,
+        );
+        b.connect(
+            right_boundary_row
+                .index_node_info
+                .successor_info
+                .is_found
+                .target,
+            constraint.target,
+        );
+
+        // 4. Either the rows tree node storing right boundary row has no successor, or
+        //    the value of the successor is greater than MAX_secondary
+        let greater_than_max_secondary = b.is_greater_than_u256(
+            &right_boundary_row.row_node_info.successor_info.value,
+            &query_proof.max_secondary_target(),
+        );
+        // assert not pQ.right_boundary_row.row_node_data.successor_info.is_found or
+        // pQ.right_boundary_row.row_node_data.successor_value > pQ.MAX_secondary
+        let constraint = b.and(
+            right_boundary_row.row_node_info.successor_info.is_found,
+            greater_than_max_secondary,
+        );
+        b.connect(
+            right_boundary_row
+                .row_node_info
+                .successor_info
+                .is_found
+                .target,
+            constraint.target,
+        );
+
         // Register the public inputs.
         PublicInputs::<_, L, S, PH>::new(
             &original_tree_proof.block_hash(),
@@ -230,8 +259,8 @@ where
             &[check_placeholder_wires.num_placeholders],
             // The aggregation query proof only has one result.
             &[num_results.target],
-            &[query_proof.entry_count],
-            &[query_proof.overflow],
+            &[query_proof.num_matching_rows_target()],
+            &[query_proof.overflow_flag_target().target],
             // Query limit
             &[zero],
             // Query offset
@@ -280,8 +309,7 @@ impl<const L: usize, const S: usize, const PH: usize, const PP: usize> CircuitLo
     for RecursiveCircuitWires<L, S, PH, PP>
 where
     [(); S - 1]:,
-    [(); num_query_io::<S>()]:,
-    [(); num_query_io_no_results_tree::<S>()]:,
+    [(); query_pi_len::<S>()]:,
     [(); <H as Hasher<F>>::HASH_SIZE]:,
 {
     type CircuitBuilderParams = CircuitBuilderParams;
@@ -299,7 +327,7 @@ where
             F,
             C,
             D,
-            { num_query_io_no_results_tree::<S>() },
+            { query_pi_len::<S>() },
         >::new(default_config(), &builder_parameters.query_circuit_set);
         let query_verifier = query_verifier.verify_proof_in_circuit_set(builder);
         let preprocessing_verifier =
@@ -313,19 +341,10 @@ where
         );
         let preprocessing_pi =
             OriginalTreePublicInputs::from_slice(&preprocessing_proof.public_inputs);
-        #[cfg(feature = "batching_circuits")]
-        let revelation_circuit = {
-            let query_pi = BatchingPublicInputs::from_slice(
-                query_verifier
-                    .get_public_input_targets::() }>(),
-            );
-            RevelationCircuitBatching::build(builder, &query_pi, &preprocessing_pi)
-        };
-        #[cfg(not(feature = "batching_circuits"))]
         let revelation_circuit = {
             let query_pi = QueryProofPublicInputs::from_slice(
                 query_verifier
-                    .get_public_input_targets::<F, { num_query_io_no_results_tree::<S>() }>(),
+                    .get_public_input_targets::<F, { query_pi_len::<S>() }>(),
             );
             RevelationWithoutResultsTreeCircuit::build(builder, &query_pi, &preprocessing_pi)
         };
@@ -349,27 +368,50 @@ where

 #[cfg(test)]
 mod tests {
-    use super::*;
-    use crate::{
-        query::{
-            public_inputs::QueryPublicInputs,
-            universal_circuit::universal_query_gadget::OutputValues,
-        },
-        revelation::tests::{compute_results_from_query_proof_outputs, TestPlaceholders},
-        test_utils::{
-            random_aggregation_operations, random_aggregation_public_inputs,
-            random_original_tree_proof,
-        },
-    };
+    use std::array;
+    use alloy::primitives::U256;
+    use itertools::Itertools;
     use mp2_common::{
-
        poseidon::flatten_poseidon_hash_value,
+        array::ToField,
+        poseidon::{flatten_poseidon_hash_value, H},
+        types::CBuilder,
         utils::{FromFields, ToFields},
-        C, D,
+        C, D, F,
     };
     use mp2_test::circuit::{run_circuit, UserCircuit};
-    use plonky2::{field::types::Field, plonk::config::Hasher};
-    use rand::{prelude::SliceRandom, thread_rng, Rng};
+    use plonky2::{
+        field::types::Field,
+        iop::{
+            target::Target,
+            witness::{PartialWitness, WitnessWrite},
+        },
+        plonk::config::Hasher,
+    };
+    use rand::{seq::SliceRandom, thread_rng, Rng};
+
+    use crate::{
+        ivc::PublicInputs as OriginalTreePublicInputs,
+        query::{
+            aggregation::{QueryBoundSource, QueryBounds},
+            public_inputs::{
+                PublicInputs as QueryProofPublicInputs,
+                QueryPublicInputs,
+            },
+            computational_hash_ids::AggregationOperation,
+            universal_circuit::{
+                universal_circuit_inputs::Placeholders, universal_query_gadget::OutputValues,
+            },
+        },
+        revelation::{
+            revelation_without_results_tree::{
+                RevelationWithoutResultsTreeCircuit, RevelationWithoutResultsTreeWires,
+            },
+            tests::{compute_results_from_query_proof_outputs, TestPlaceholders},
+            PublicInputs, NUM_PREPROCESSING_IO,
+        },
+        test_utils::{random_aggregation_operations, random_original_tree_proof, sample_boundary_rows_for_revelation},
+    };

     // L: maximum number of results
     // S: maximum number of items in each result
@@ -381,9 +423,9 @@ mod tests {
     const PP: usize = 20;

     // Real number of the placeholders
-    const NUM_PLACEHOLDERS: usize = 5;
+    const NUM_PLACEHOLDERS: usize = 6;

-    const QUERY_PI_LEN: usize = crate::query::pi_len::<S>();
+    const QUERY_PI_LEN: usize = QueryProofPublicInputs::<F, S>::total_len();

     impl From<&TestPlaceholders> for RevelationWithoutResultsTreeCircuit<L, S, PH, PP> {
         fn from(test_placeholders: &TestPlaceholders) -> Self {
@@ -392,15 +434,16 @@ mod tests {
             }
         }
     }
+
     #[derive(Clone, Debug)]
-    struct TestRevelationWithoutResultsTreeCircuit<'a> {
+    struct TestRevelationCircuit<'a> {
         c: RevelationWithoutResultsTreeCircuit<L, S, PH, PP>,
         query_proof: &'a [F],
         original_tree_proof: &'a [F],
     }

-    impl UserCircuit<F, D> for TestRevelationWithoutResultsTreeCircuit<'_> {
+    impl UserCircuit<F, D> for TestRevelationCircuit<'_> {
         // Circuit wires + query proof + original tree proof (IVC proof)
         type Wires = (
            RevelationWithoutResultsTreeWires<L, S, PH, PP>,
            Vec<F>,
            Vec<F>,
        );
@@ -433,21 +476,26 @@ mod tests {
         ops: &[F; S],
         test_placeholders: &TestPlaceholders,
     ) -> Vec<F> {
-        let [mut proof] = random_aggregation_public_inputs(ops);
-
-        let [count_range, min_query_range, max_query_range, p_hash_range] = [
-            QueryPublicInputs::NumMatching,
-            QueryPublicInputs::MinQuery,
-            QueryPublicInputs::MaxQuery,
-            QueryPublicInputs::PlaceholderHash,
-        ]
-        .map(QueryProofPublicInputs::<F, S>::to_range);
+        let [mut proof] = QueryProofPublicInputs::sample_from_ops(ops);
+
+        let [count_range, min_query_primary, max_query_primary, min_query_secondary, max_query_secondary, p_hash_range, left_row_range, right_row_range] =
+            [
+                QueryPublicInputs::NumMatching,
+                QueryPublicInputs::MinPrimary,
+                QueryPublicInputs::MaxPrimary,
+                QueryPublicInputs::MinSecondary,
+                QueryPublicInputs::MaxSecondary,
+                QueryPublicInputs::PlaceholderHash,
+                QueryPublicInputs::LeftBoundaryRow,
+                QueryPublicInputs::RightBoundaryRow,
+            ]
+            .map(QueryProofPublicInputs::<F, S>::to_range);

         // Set the count, minimum, maximum query and the placeholder hash.
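        // each (range, fields) pair below overwrites one slice of the randomly sampled
        // proof, so that only these public inputs carry the values the circuit checks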
        [
            (count_range, vec![entry_count.to_field()]),
-            (min_query_range, test_placeholders.min_query.to_fields()),
-            (max_query_range, test_placeholders.max_query.to_fields()),
+            (min_query_primary, test_placeholders.min_query.to_fields()),
+            (max_query_primary, test_placeholders.max_query.to_fields()),
             (
                 p_hash_range,
                 test_placeholders.query_placeholder_hash.to_fields(),
             ),
         ]
         .into_iter()
         .for_each(|(range, fields)| proof[range].copy_from_slice(&fields));

+        // Set boundary rows to satisfy constraints for completeness
+        let rng = &mut thread_rng();
+        let min_secondary = U256::from_fields(&proof[min_query_secondary]);
+        let max_secondary = U256::from_fields(&proof[max_query_secondary]);
+        let placeholders =
+            Placeholders::new_empty(test_placeholders.min_query, test_placeholders.max_query);
+        let query_bounds = QueryBounds::new(
+            &placeholders,
+            Some(QueryBoundSource::Constant(min_secondary)),
+            Some(QueryBoundSource::Constant(max_secondary)),
+        )
+        .unwrap();
+        let (left_boundary_row, right_boundary_row) = sample_boundary_rows_for_revelation(&query_bounds, rng);
+
+        proof[left_row_range].copy_from_slice(&left_boundary_row.to_fields());
+        proof[right_row_range].copy_from_slice(&right_boundary_row.to_fields());
+
         proof
     }

     /// Utility function for testing the revelation circuit without results tree
-    fn test_revelation_without_results_tree_circuit(ops: &[F; S], entry_count: Option) {
+    fn test_revelation_batching_circuit(ops: &[F; S], entry_count: Option) {
         let rng = &mut thread_rng();

         // Generate the testing placeholder data.
@@ -476,7 +541,7 @@ mod tests {
         let original_tree_pi = OriginalTreePublicInputs::from_slice(&original_tree_proof);

         // Construct the test circuit.
-        let test_circuit = TestRevelationWithoutResultsTreeCircuit {
+        let test_circuit = TestRevelationCircuit {
             c: (&test_placeholders).into(),
             query_proof: &query_proof,
             original_tree_proof: &original_tree_proof,
@@ -486,7 +551,6 @@ mod tests {
         let proof = run_circuit::<F, C, D, _>(test_circuit);
         let pi = PublicInputs::<_, L, S, PH>::from_slice(&proof.public_inputs);

-        // Initialize the overflow flag to false.
         let entry_count = query_pi.num_matching_rows();

         // Check the public inputs.
@@ -547,57 +611,58 @@ mod tests {
     }

     #[test]
-    fn test_revelation_without_results_tree_simple() {
+    fn test_revelation_batching_simple() {
         // Generate the random operations and set the first operation to SUM
         // (not ID which should not be present in the aggregation).
         let mut ops: [_; S] = random_aggregation_operations();
         ops[0] = AggregationOperation::SumOp.to_field();

-        test_revelation_without_results_tree_circuit(&ops, None);
+        test_revelation_batching_circuit(&ops, None);
     }

     // Test for COUNT operation.
     #[test]
-    fn test_revelation_without_results_tree_for_op_count() {
+    fn test_revelation_batching_for_op_count() {
         // Set the first operation to COUNT.
         let mut ops: [_; S] = random_aggregation_operations();
         ops[0] = AggregationOperation::CountOp.to_field();

-        test_revelation_without_results_tree_circuit(&ops, None);
+        test_revelation_batching_circuit(&ops, None);
     }

     // Test for AVG operation.
     #[test]
-    fn test_revelation_without_results_tree_for_op_avg() {
+    fn test_revelation_batching_for_op_avg() {
         // Set the first operation to AVG.
         let mut ops: [_; S] = random_aggregation_operations();
         ops[0] = AggregationOperation::AvgOp.to_field();

-        test_revelation_without_results_tree_circuit(&ops, None);
+        test_revelation_batching_circuit(&ops, None);
     }

     // Test for AVG operation with zero entry count.
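    // (the circuit must reveal 0 for AVG here rather than divide by the zero entry count)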
#[test] - fn test_revelation_without_results_tree_for_op_avg_with_no_entries() { + fn test_revelation_batching_for_op_avg_with_no_entries() { // Set the first operation to AVG. let mut ops: [_; S] = random_aggregation_operations(); ops[0] = AggregationOperation::AvgOp.to_field(); - test_revelation_without_results_tree_circuit(&ops, Some(0)); + test_revelation_batching_circuit(&ops, Some(0)); } // Test for no AVG operation with zero entry count. #[test] - fn test_revelation_without_results_tree_for_no_op_avg_with_no_entries() { + fn test_revelation_batching_for_no_op_avg_with_no_entries() { // Initialize the all operations to SUM or COUNT (not AVG). let mut rng = thread_rng(); - let ops = std::array::from_fn(|_| { + let ops = array::from_fn(|_| { [AggregationOperation::SumOp, AggregationOperation::CountOp] .choose(&mut rng) .unwrap() .to_field() }); - test_revelation_without_results_tree_circuit(&ops, Some(0)); + test_revelation_batching_circuit(&ops, Some(0)); } } + diff --git a/verifiable-db/src/test_utils.rs b/verifiable-db/src/test_utils.rs index c42a1205b..3fa940683 100644 --- a/verifiable-db/src/test_utils.rs +++ b/verifiable-db/src/test_utils.rs @@ -3,15 +3,12 @@ use crate::{ ivc::public_inputs::H_RANGE as ORIGINAL_TREE_H_RANGE, query::{ - aggregation::{QueryBounds, QueryHashNonExistenceCircuits}, - computational_hash_ids::{ + aggregation::{QueryBounds, QueryHashNonExistenceCircuits}, batching::row_chunk::tests::BoundaryRowData, computational_hash_ids::{ AggregationOperation, ColumnIDs, Identifiers, Operation, PlaceholderIdentifier, - }, - pi_len, - public_inputs::{PublicInputs as QueryPI, PublicInputs, QueryPublicInputs}, - universal_circuit::universal_circuit_inputs::{ + }, universal_circuit::universal_circuit_inputs::{ BasicOperation, ColumnCell, InputOperand, OutputItem, Placeholders, ResultStructure, }, + public_inputs::{tests::gen_values_in_range, PublicInputs as QueryPI, QueryPublicInputs} }, revelation::NUM_PREPROCESSING_IO, }; @@ -19,16 +16,14 @@ use alloy::primitives::U256; use itertools::Itertools; use mp2_common::{ array::ToField, - types::CURVE_TARGET_LEN, utils::{Fieldable, ToFields}, F, }; use plonky2::{ - field::types::{Field, PrimeField64, Sample}, - hash::hash_types::{HashOut, NUM_HASH_OUT_ELTS}, + field::types::PrimeField64, + hash::hash_types::HashOut, plonk::config::GenericHashOut, }; -use plonky2_ecgfp5::curve::curve::Point; use rand::{prelude::SliceRandom, thread_rng, Rng}; use std::array; @@ -79,65 +74,6 @@ pub fn random_aggregation_operations() -> [F; S] { }) } -/// Generate S number of proof public input slices by the specified operations for testing. -/// The each returned proof public inputs could be constructed by -/// `PublicInputs::from_slice` function. -pub fn random_aggregation_public_inputs( - ops: &[F; S], -) -> [Vec; N] { - let [ops_range, overflow_range, index_ids_range, c_hash_range, p_hash_range] = [ - QueryPublicInputs::OpIds, - QueryPublicInputs::Overflow, - QueryPublicInputs::IndexIds, - QueryPublicInputs::ComputationalHash, - QueryPublicInputs::PlaceholderHash, - ] - .map(PublicInputs::::to_range); - - let first_value_start = PublicInputs::::to_range(QueryPublicInputs::OutputValues).start; - let is_first_op_id = - ops[0] == Identifiers::AggregationOperations(AggregationOperation::IdOp).to_field(); - - // Generate the index ids, computational hash and placeholder hash, - // they should be same for a series of public inputs. 
-    let mut rng = thread_rng();
-    let index_ids = (0..2).map(|_| rng.gen()).collect::>().to_fields();
-    let [computational_hash, placeholder_hash]: [Vec<_>; 2] = array::from_fn(|_| {
-        (0..NUM_HASH_OUT_ELTS)
-            .map(|_| rng.gen())
-            .collect::>()
-            .to_fields()
-    });
-
-    array::from_fn(|_| {
-        let mut pi = (0..pi_len::<S>())
-            .map(|_| rng.gen())
-            .collect::>()
-            .to_fields();
-
-        // Copy the specified operations to the proofs.
-        pi[ops_range.clone()].copy_from_slice(ops);
-
-        // Set the overflow flag to a random boolean.
-        let overflow = F::from_bool(rng.gen());
-        pi[overflow_range.clone()].copy_from_slice(&[overflow]);
-
-        // Set the index ids, computational hash and placeholder hash,
-        pi[index_ids_range.clone()].copy_from_slice(&index_ids);
-        pi[c_hash_range.clone()].copy_from_slice(&computational_hash);
-        pi[p_hash_range.clone()].copy_from_slice(&placeholder_hash);
-
-        // If the first operation is ID, set the value to a random point.
-        if is_first_op_id {
-            let first_value = Point::sample(&mut rng).to_weierstrass().to_fields();
-            pi[first_value_start..first_value_start + CURVE_TARGET_LEN]
-                .copy_from_slice(&first_value);
-        }
-
-        pi
-    })
-}
-
 /// Revelation related data used for testing
 #[derive(Debug)]
 pub struct TestRevelationData {
@@ -244,34 +180,67 @@ impl TestRevelationData {
         let computational_hash = non_existence_circuits.computational_hash();
         let placeholder_hash = non_existence_circuits.placeholder_hash();

-        let [mut query_pi_raw] = random_aggregation_public_inputs::<1, MAX_NUM_ITEMS_PER_OUTPUT>(
+        let [mut query_pi_raw] = QueryPI::<F, MAX_NUM_ITEMS_PER_OUTPUT>::sample_from_ops(
             &ops_ids.try_into().unwrap(),
         );

-        let [min_query_range, max_query_range, p_hash_range, c_hash_range] = [
-            QueryPublicInputs::MinQuery,
-            QueryPublicInputs::MaxQuery,
+        let [min_query_primary, max_query_primary, min_query_secondary, max_query_secondary, p_hash_range, c_hash_range, left_row_range, right_row_range] = [
+            QueryPublicInputs::MinPrimary,
+            QueryPublicInputs::MaxPrimary,
+            QueryPublicInputs::MinSecondary,
+            QueryPublicInputs::MaxSecondary,
             QueryPublicInputs::PlaceholderHash,
             QueryPublicInputs::ComputationalHash,
+            QueryPublicInputs::LeftBoundaryRow,
+            QueryPublicInputs::RightBoundaryRow,
         ]
         .map(QueryPI::<F, MAX_NUM_ITEMS_PER_OUTPUT>::to_range);
+
+        // sample left boundary row and right boundary row to satisfy revelation circuit constraints
+        let (left_boundary_row, right_boundary_row) = sample_boundary_rows_for_revelation(&query_bounds, rng);

         // Set the minimum, maximum query, placeholder hash and computational hash to expected values.
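        // the sampled boundary rows are written into the raw public inputs below so that
        // the revelation circuit's completeness checks on predecessor/successor values pass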
        [
            (
-                min_query_range,
+                min_query_primary,
                 query_bounds.min_query_primary().to_fields(),
             ),
             (
-                max_query_range,
+                max_query_primary,
                 query_bounds.max_query_primary().to_fields(),
             ),
+            (
+                min_query_secondary,
+                query_bounds.min_query_secondary().value().to_fields(),
+            ),
+            (
+                max_query_secondary,
+                query_bounds.max_query_secondary().value().to_fields(),
+            ),
             (p_hash_range, placeholder_hash.to_vec()),
             (c_hash_range, computational_hash.to_vec()),
+            (left_row_range, left_boundary_row.to_fields()),
+            (right_row_range, right_boundary_row.to_fields())
         ]
         .into_iter()
         .for_each(|(range, fields)| query_pi_raw[range].copy_from_slice(&fields));

         let query_pi = QueryPI::<F, MAX_NUM_ITEMS_PER_OUTPUT>::from_slice(&query_pi_raw);
+        assert_eq!(
+            query_pi.min_primary(),
+            query_bounds.min_query_primary(),
+        );
+        assert_eq!(
+            query_pi.max_primary(),
+            query_bounds.max_query_primary(),
+        );
+        assert_eq!(
+            query_pi.min_secondary(),
+            query_bounds.min_query_secondary().value,
+        );
+        assert_eq!(
+            query_pi.max_secondary(),
+            query_bounds.max_query_secondary().value,
+        );

         // generate preprocessing proof public inputs
         let preprocessing_pi_raw = random_original_tree_proof(query_pi.tree_hash());
@@ -313,3 +282,57 @@ impl TestRevelationData {
         &self.query_pi_raw
     }
 }
+
+pub(crate) fn sample_boundary_rows_for_revelation<R: Rng>(
+    query_bounds: &QueryBounds,
+    rng: &mut R,
+) -> (BoundaryRowData, BoundaryRowData) {
+    let min_secondary = *query_bounds.min_query_secondary().value();
+    let max_secondary = *query_bounds.max_query_secondary().value();
+    let mut left_boundary_row = BoundaryRowData::sample(rng, &query_bounds);
+    // for predecessor of `left_boundary_row` in index tree, we need to either mark it as
+    // non-existent or to make its value out of range
+    if rng.gen() || query_bounds.min_query_primary() == U256::ZERO {
+        left_boundary_row.index_node_info.predecessor_info.is_found = false;
+    } else {
+        let [predecessor_value] = gen_values_in_range(
+            rng,
+            U256::ZERO,
+            query_bounds.min_query_primary() - U256::from(1),
+        );
+        left_boundary_row.index_node_info.predecessor_info.value = predecessor_value;
+    }
+    // for predecessor of `left_boundary_row` in rows tree, we need to either mark it as
+    // non-existent or to make its value out of range
+    if rng.gen() || min_secondary == U256::ZERO {
+        left_boundary_row.row_node_info.predecessor_info.is_found = false;
+    } else {
+        let [predecessor_value] =
+            gen_values_in_range(rng, U256::ZERO, min_secondary - U256::from(1));
+        left_boundary_row.row_node_info.predecessor_info.value = predecessor_value;
+    }
+    let mut right_boundary_row = BoundaryRowData::sample(rng, &query_bounds);
+    // for successor of `right_boundary_row` in index tree, we need to either mark it as
+    // non-existent or to make its value out of range
+    if rng.gen() || query_bounds.max_query_primary() == U256::MAX {
+        right_boundary_row.index_node_info.successor_info.is_found = false;
+    } else {
+        let [successor_value] = gen_values_in_range(
+            rng,
+            query_bounds.max_query_primary() + U256::from(1),
+            U256::MAX,
+        );
+        right_boundary_row.index_node_info.successor_info.value = successor_value;
+    }
+    // for successor of `right_boundary_row` in rows tree, we need to either mark it as
+    // non-existent or to make its value out of range
+    if rng.gen() || max_secondary == U256::MAX {
+        right_boundary_row.row_node_info.successor_info.is_found = false;
+    } else {
+        let [successor_value] =
+            gen_values_in_range(rng, max_secondary + U256::from(1), U256::MAX);
+        right_boundary_row.row_node_info.successor_info.value = successor_value;
+    }
+
+    
(left_boundary_row, right_boundary_row) +} \ No newline at end of file From 593c937490b35b331e37b59a386a3ece51d95e68 Mon Sep 17 00:00:00 2001 From: nicholas-mainardi Date: Fri, 6 Dec 2024 12:16:39 +0100 Subject: [PATCH 02/12] Restructure verifiable-db query + fix dependent crates --- Cargo.lock | 1 + groth16-framework/Cargo.toml | 1 + groth16-framework/tests/common/context.rs | 12 +- groth16-framework/tests/common/query.rs | 1 - mp2-v1/src/query/batching_planner.rs | 2 +- mp2-v1/src/query/planner.rs | 47 +- mp2-v1/tests/common/cases/mod.rs | 1 - mp2-v1/tests/common/cases/planner.rs | 418 ------- .../common/cases/query/aggregated_queries.rs | 1049 +++-------------- mp2-v1/tests/common/cases/query/mod.rs | 29 +- .../cases/query/simple_select_queries.rs | 127 +- mp2-v1/tests/common/proof_storage.rs | 16 +- parsil/src/assembler.rs | 2 +- parsil/src/bracketer.rs | 2 +- parsil/src/isolator.rs | 2 +- parsil/src/lib.rs | 2 +- parsil/src/queries.rs | 2 +- verifiable-db/Cargo.toml | 3 +- verifiable-db/src/api.rs | 2 - verifiable-db/src/lib.rs | 1 - .../child_proven_single_path_node.rs | 366 ------ .../embedded_tree_proven_single_path_node.rs | 572 --------- .../query/aggregation/full_node_index_leaf.rs | 246 ---- .../aggregation/full_node_with_one_child.rs | 412 ------- .../full_node_with_two_children.rs | 398 ------- .../query/aggregation/non_existence_inter.rs | 761 ------------ .../src/query/aggregation/partial_node.rs | 519 -------- verifiable-db/src/query/aggregation/utils.rs | 153 --- verifiable-db/src/query/api.rs | 24 +- .../src/query/batching/circuits/api.rs | 756 ------------ verifiable-db/src/query/batching/mod.rs | 2 - .../circuits/chunk_aggregation.rs | 4 +- .../src/query/{batching => }/circuits/mod.rs | 12 +- .../{batching => }/circuits/non_existence.rs | 13 +- .../circuits/row_chunk_processing.rs | 22 +- .../src/query/computational_hash_ids.rs | 2 +- verifiable-db/src/query/merkle_path.rs | 139 +-- verifiable-db/src/query/mod.rs | 6 +- .../{aggregation => }/output_computation.rs | 2 +- verifiable-db/src/query/public_inputs.rs | 128 +- .../aggregate_chunks.rs | 9 +- .../consecutive_rows.rs | 2 +- .../row_chunk => row_chunk_gadgets}/mod.rs | 277 +++-- .../row_process_gadget.rs | 0 .../universal_query_circuit.rs | 6 +- .../universal_query_gadget.rs | 4 +- .../query/{aggregation/mod.rs => utils.rs} | 115 -- verifiable-db/src/revelation/api.rs | 5 +- .../src/revelation/placeholders_check.rs | 2 +- .../revelation/revelation_unproven_offset.rs | 4 +- .../revelation_without_results_tree.rs | 2 +- verifiable-db/src/test_utils.rs | 116 +- 52 files changed, 643 insertions(+), 6156 deletions(-) delete mode 100644 mp2-v1/tests/common/cases/planner.rs delete mode 100644 verifiable-db/src/query/aggregation/child_proven_single_path_node.rs delete mode 100644 verifiable-db/src/query/aggregation/embedded_tree_proven_single_path_node.rs delete mode 100644 verifiable-db/src/query/aggregation/full_node_index_leaf.rs delete mode 100644 verifiable-db/src/query/aggregation/full_node_with_one_child.rs delete mode 100644 verifiable-db/src/query/aggregation/full_node_with_two_children.rs delete mode 100644 verifiable-db/src/query/aggregation/non_existence_inter.rs delete mode 100644 verifiable-db/src/query/aggregation/partial_node.rs delete mode 100644 verifiable-db/src/query/aggregation/utils.rs delete mode 100644 verifiable-db/src/query/batching/circuits/api.rs delete mode 100644 verifiable-db/src/query/batching/mod.rs rename verifiable-db/src/query/{batching => }/circuits/chunk_aggregation.rs (99%) rename 
verifiable-db/src/query/{batching => }/circuits/mod.rs (97%)
 rename verifiable-db/src/query/{batching => }/circuits/non_existence.rs (97%)
 rename verifiable-db/src/query/{batching => }/circuits/row_chunk_processing.rs (99%)
 rename verifiable-db/src/query/{aggregation => }/output_computation.rs (99%)
 rename verifiable-db/src/query/{batching/row_chunk => row_chunk_gadgets}/aggregate_chunks.rs (98%)
 rename verifiable-db/src/query/{batching/row_chunk => row_chunk_gadgets}/consecutive_rows.rs (99%)
 rename verifiable-db/src/query/{batching/row_chunk => row_chunk_gadgets}/mod.rs (71%)
 rename verifiable-db/src/query/{batching/row_chunk => row_chunk_gadgets}/row_process_gadget.rs (100%)
 rename verifiable-db/src/query/{aggregation/mod.rs => utils.rs} (81%)

diff --git a/Cargo.lock b/Cargo.lock
index db84f41e6..e2187e8da 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2977,6 +2977,7 @@ dependencies = [
 "itertools 0.13.0",
 "log",
 "mp2_common",
+ "mp2_test",
 "plonky2",
 "plonky2x",
 "rand",
diff --git a/groth16-framework/Cargo.toml b/groth16-framework/Cargo.toml
index a8ce89abe..7922499e1 100644
--- a/groth16-framework/Cargo.toml
+++ b/groth16-framework/Cargo.toml
@@ -23,6 +23,7 @@ itertools.workspace = true
 rand.workspace = true
 serial_test.workspace = true
 sha2.workspace = true
+mp2_test = { path = "../mp2-test" }
 recursion_framework = { path = "../recursion-framework" }
 verifiable-db = { path = "../verifiable-db" }

diff --git a/groth16-framework/tests/common/context.rs b/groth16-framework/tests/common/context.rs
index 81cb6dbf7..dc38470bf 100644
--- a/groth16-framework/tests/common/context.rs
+++ b/groth16-framework/tests/common/context.rs
@@ -3,14 +3,13 @@
 use super::{NUM_PREPROCESSING_IO, NUM_QUERY_IO};
 use groth16_framework::{compile_and_generate_assets, utils::clone_circuit_data};
 use mp2_common::{C, D, F};
+use mp2_test::circuit::TestDummyCircuit;
 use recursion_framework::framework_testing::TestingRecursiveCircuits;
 use verifiable_db::{
-    api::WrapCircuitParams,
-    revelation::api::Parameters as RevelationParameters,
-    test_utils::{
+    api::WrapCircuitParams, query::pi_len, revelation::api::Parameters as RevelationParameters, test_utils::{
         INDEX_TREE_MAX_DEPTH, MAX_NUM_COLUMNS, MAX_NUM_ITEMS_PER_OUTPUT, MAX_NUM_OUTPUTS,
         MAX_NUM_PLACEHOLDERS, MAX_NUM_PREDICATE_OPS, MAX_NUM_RESULT_OPS, ROW_TREE_MAX_DEPTH,
-    },
+    }
 };

 /// Test context
@@ -40,7 +39,8 @@ impl TestContext {
         // Generate a fake query circuit set.
         let query_circuits = TestingRecursiveCircuits::<F, C, D, NUM_QUERY_IO>::default();
-
+        let dummy_universal_circuit = TestDummyCircuit::<{pi_len::<MAX_NUM_ITEMS_PER_OUTPUT>()}>::build();
+
         // Create the revelation parameters.
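        // (the dummy circuit's verifier data stands in for the universal query circuit,
        // which this test never actually runs)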
let revelation_params = RevelationParameters::< ROW_TREE_MAX_DEPTH, @@ -53,7 +53,7 @@ impl TestContext { MAX_NUM_PLACEHOLDERS, >::build( query_circuits.get_recursive_circuit_set(), // unused, so we provide a dummy one - query_circuits.get_recursive_circuit_set(), + dummy_universal_circuit.circuit_data().verifier_data(), preprocessing_circuits.get_recursive_circuit_set(), preprocessing_circuits .verifier_data_for_input_proofs::<1>() diff --git a/groth16-framework/tests/common/query.rs b/groth16-framework/tests/common/query.rs index 2a90a9468..75a7db8bd 100644 --- a/groth16-framework/tests/common/query.rs +++ b/groth16-framework/tests/common/query.rs @@ -59,7 +59,6 @@ impl TestContext { .generate_proof( input, self.query_circuits.get_recursive_circuit_set(), - self.query_circuits.get_recursive_circuit_set(), None, ) .unwrap(); diff --git a/mp2-v1/src/query/batching_planner.rs b/mp2-v1/src/query/batching_planner.rs index 29cd5c115..d0c21d029 100644 --- a/mp2-v1/src/query/batching_planner.rs +++ b/mp2-v1/src/query/batching_planner.rs @@ -11,7 +11,7 @@ use ryhope::{ Epoch, }; use verifiable_db::query::{ - batching::{NodePath, RowInput, TreePathInputs}, + api::{NodePath, RowInput, TreePathInputs}, computational_hash_ids::ColumnIDs, universal_circuit::universal_circuit_inputs::{ColumnCell, RowCells}, }; diff --git a/mp2-v1/src/query/planner.rs b/mp2-v1/src/query/planner.rs index 0cc0c1c91..305d1a848 100644 --- a/mp2-v1/src/query/planner.rs +++ b/mp2-v1/src/query/planner.rs @@ -19,8 +19,8 @@ use ryhope::{ use std::{fmt::Debug, future::Future}; use tokio_postgres::{row::Row as PsqlRow, types::ToSql, NoTls}; use verifiable_db::query::{ - aggregation::{ChildPosition, NodeInfo, QueryBounds}, - batching::TreePathInputs, + utils::{ChildPosition, NodeInfo, QueryBounds}, + api::TreePathInputs, }; use crate::indexing::{ @@ -375,50 +375,7 @@ impl< } } -/// Returns the proving plan to prove the non existence of node of the query in this row tree at -/// the epoch primary. It also returns the leaf node chosen. -/// -/// The row tree is given and specialized to psql storage since that is the only official storage -/// supported. -/// The `table_name` must be the one given to parsil settings, it is the human friendly table -/// name, i.e. the vTable name. -/// The pool is to issue specific query -/// Primary is indicating the primary index over which this row tree is looked at. -/// Settings are the parsil settings corresponding to the current SQL and current table looked at. -/// Pis contain the bounds and placeholders values. -/// TODO: we should extend ryhope to offer this API directly on the tree since it's very related. -pub async fn proving_plan_for_non_existence( - row_tree: &MerkleTreeKvDb, DBRowStorage>, - table_name: String, - pool: &DBPool, - primary: BlockPrimaryIndex, - settings: &ParsilSettings, - bounds: &QueryBounds, -) -> anyhow::Result<(RowTreeKey, UpdateTree)> -where - C: ContextProvider, -{ - let to_be_proven_node = { - let input = NonExistenceInput { - row_tree, - table_name, - pool, - settings, - bounds: bounds.clone(), - }; - input.find_row_node_for_non_existence(primary).await - }?; - let path = row_tree - // since the epoch starts at genesis we can directly give the block number ! 
- .lineage_at(&to_be_proven_node, primary as Epoch) - .await - .expect("node doesn't have a lineage?") - .into_full_path() - .collect_vec(); - let proving_tree = UpdateTree::from_paths([path], primary as Epoch); - Ok((to_be_proven_node.clone(), proving_tree)) -} /// Fetch a key `k` from a tree, assuming that the key is in the /// tree. Therefore, it handles differently the case when `k` is not found: /// - If `T::WIDE_LINEAGE` is true, then `k` might not be found because the diff --git a/mp2-v1/tests/common/cases/mod.rs b/mp2-v1/tests/common/cases/mod.rs index c6445467e..991c2eef8 100644 --- a/mp2-v1/tests/common/cases/mod.rs +++ b/mp2-v1/tests/common/cases/mod.rs @@ -11,7 +11,6 @@ use super::table::Table; pub mod contract; pub mod indexing; -pub mod planner; pub mod query; pub mod table_source; diff --git a/mp2-v1/tests/common/cases/planner.rs b/mp2-v1/tests/common/cases/planner.rs deleted file mode 100644 index 7b9bf58b4..000000000 --- a/mp2-v1/tests/common/cases/planner.rs +++ /dev/null @@ -1,418 +0,0 @@ -use std::{collections::HashSet, future::Future}; - -use anyhow::Result; -use log::info; -use mp2_v1::indexing::{ - block::BlockPrimaryIndex, - index::IndexNode, - row::{RowPayload, RowTreeKey}, -}; -use parsil::{assembler::DynamicCircuitPis, ParsilSettings}; -use ryhope::{storage::WideLineage, tree::NodeContext, Epoch}; - -use crate::common::{ - cases::query::aggregated_queries::prove_non_existence_row, - index_tree::MerkleIndexTree, - proof_storage::{PlaceholderValues, ProofKey, ProofStorage, QueryID}, - rowtree::MerkleRowTree, - table::{Table, TableColumns}, - TestContext, -}; - -use super::query::{aggregated_queries::prove_single_row, QueryCooking}; - -pub(crate) struct QueryPlanner<'a> { - pub(crate) query: QueryCooking, - pub(crate) pis: &'a DynamicCircuitPis, - pub(crate) ctx: &'a mut TestContext, - pub(crate) settings: &'a ParsilSettings<&'a Table>, - // useful for non existence since we need to search in both trees the places to prove - // the fact a given node doesn't exist - pub(crate) table: &'a Table, - pub(crate) columns: TableColumns, -} - -pub trait TreeInfo { - fn is_row_tree(&self) -> bool; - fn is_satisfying_query(&self, k: &K) -> bool; - fn load_proof( - &self, - ctx: &TestContext, - query_id: &QueryID, - primary: BlockPrimaryIndex, - key: &K, - placeholder_values: PlaceholderValues, - ) -> Result>; - fn save_proof( - &self, - ctx: &mut TestContext, - query_id: &QueryID, - primary: BlockPrimaryIndex, - key: &K, - placeholder_values: PlaceholderValues, - proof: Vec, - ) -> Result<()>; - - async fn load_or_prove_embedded( - &self, - planner: &mut QueryPlanner<'_>, - primary: BlockPrimaryIndex, - k: &K, - v: &V, - ) -> Result>>; - - fn fetch_ctx_and_payload_at( - &self, - epoch: Epoch, - key: &K, - ) -> impl Future, V)>> + Send; -} - -impl TreeInfo> - for WideLineage> -{ - fn is_row_tree(&self) -> bool { - true - } - - fn is_satisfying_query(&self, k: &RowTreeKey) -> bool { - self.is_touched_key(k) - } - - fn load_proof( - &self, - ctx: &TestContext, - query_id: &QueryID, - primary: BlockPrimaryIndex, - key: &RowTreeKey, - placeholder_values: PlaceholderValues, - ) -> Result> { - // TODO export that in single function - let proof_key = ProofKey::QueryAggregateRow(( - query_id.clone(), - placeholder_values, - primary, - key.clone(), - )); - ctx.storage.get_proof_exact(&proof_key) - } - - fn save_proof( - &self, - ctx: &mut TestContext, - query_id: &QueryID, - primary: BlockPrimaryIndex, - key: &RowTreeKey, - placeholder_values: PlaceholderValues, - proof: Vec, - ) -> 
Result<()> { - // TODO export that in single function - let proof_key = ProofKey::QueryAggregateRow(( - query_id.clone(), - placeholder_values, - primary, - key.clone(), - )); - ctx.storage.store_proof(proof_key, proof) - } - - async fn load_or_prove_embedded( - &self, - planner: &mut QueryPlanner<'_>, - primary: BlockPrimaryIndex, - k: &RowTreeKey, - _v: &RowPayload, - ) -> Result>> { - // TODO export that in single function - Ok(if self.is_satisfying_query(k) { - let ctx = &mut planner.ctx; - Some( - prove_single_row( - ctx, - self, - &planner.columns, - primary, - k, - planner.pis, - &planner.query, - ) - .await?, - ) - } else { - None - }) - } - - async fn fetch_ctx_and_payload_at( - &self, - epoch: Epoch, - key: &RowTreeKey, - ) -> Option<(NodeContext, RowPayload)> { - self.ctx_and_payload_at(epoch, key) - } -} - -pub struct RowInfo<'a> { - pub(crate) satisfiying_rows: HashSet, - pub(crate) tree: &'a MerkleRowTree, -} - -impl<'a> RowInfo<'a> { - pub fn no_satisfying_rows(tree: &'a MerkleRowTree) -> Self { - Self { - satisfiying_rows: Default::default(), - tree, - } - } -} - -impl TreeInfo> for RowInfo<'_> { - fn is_row_tree(&self) -> bool { - true - } - - fn is_satisfying_query(&self, k: &RowTreeKey) -> bool { - self.satisfiying_rows.contains(k) - } - - fn load_proof( - &self, - ctx: &TestContext, - query_id: &QueryID, - primary: BlockPrimaryIndex, - key: &RowTreeKey, - placeholder_values: PlaceholderValues, - ) -> Result> { - let proof_key = ProofKey::QueryAggregateRow(( - query_id.clone(), - placeholder_values, - primary, - key.clone(), - )); - ctx.storage.get_proof_exact(&proof_key) - } - - fn save_proof( - &self, - ctx: &mut TestContext, - query_id: &QueryID, - primary: BlockPrimaryIndex, - key: &RowTreeKey, - placeholder_values: PlaceholderValues, - proof: Vec, - ) -> Result<()> { - let proof_key = ProofKey::QueryAggregateRow(( - query_id.clone(), - placeholder_values, - primary, - key.clone(), - )); - ctx.storage.store_proof(proof_key, proof) - } - - async fn load_or_prove_embedded( - &self, - planner: &mut QueryPlanner<'_>, - primary: BlockPrimaryIndex, - k: &RowTreeKey, - _v: &RowPayload, - ) -> Result>> { - Ok(if self.is_satisfying_query(k) { - let ctx = &mut planner.ctx; - Some( - prove_single_row( - ctx, - self, - &planner.columns, - primary, - k, - planner.pis, - &planner.query, - ) - .await?, - ) - } else { - None - }) - } - - async fn fetch_ctx_and_payload_at( - &self, - epoch: Epoch, - key: &RowTreeKey, - ) -> Option<(NodeContext, RowPayload)> { - self.tree.try_fetch_with_context_at(key, epoch).await - } -} - -impl TreeInfo> - for WideLineage> -{ - fn is_row_tree(&self) -> bool { - false - } - - fn is_satisfying_query(&self, k: &BlockPrimaryIndex) -> bool { - self.is_touched_key(k) - } - - fn load_proof( - &self, - ctx: &TestContext, - query_id: &QueryID, - primary: BlockPrimaryIndex, - key: &BlockPrimaryIndex, - placeholder_values: PlaceholderValues, - ) -> Result> { - // TODO export that in single function - repetition - info!("loading proof for {primary} -> {key:?}"); - let proof_key = ProofKey::QueryAggregateIndex((query_id.clone(), placeholder_values, *key)); - ctx.storage.get_proof_exact(&proof_key) - } - - fn save_proof( - &self, - ctx: &mut TestContext, - query_id: &QueryID, - _primary: BlockPrimaryIndex, - key: &BlockPrimaryIndex, - placeholder_values: PlaceholderValues, - proof: Vec, - ) -> Result<()> { - // TODO export that in single function - let proof_key = ProofKey::QueryAggregateIndex((query_id.clone(), placeholder_values, *key)); - 
ctx.storage.store_proof(proof_key, proof) - } - - async fn load_or_prove_embedded( - &self, - planner: &mut QueryPlanner<'_>, - primary: BlockPrimaryIndex, - k: &BlockPrimaryIndex, - v: &IndexNode, - ) -> Result>> { - load_or_prove_embedded_index(self, planner, primary, k, v).await - } - - async fn fetch_ctx_and_payload_at( - &self, - epoch: Epoch, - key: &BlockPrimaryIndex, - ) -> Option<(NodeContext, IndexNode)> { - self.ctx_and_payload_at(epoch, key) - } -} - -pub struct IndexInfo<'a> { - pub(crate) bounds: (BlockPrimaryIndex, BlockPrimaryIndex), - pub(crate) tree: &'a MerkleIndexTree, -} - -impl<'a> IndexInfo<'a> { - pub fn non_satisfying_info(tree: &'a MerkleIndexTree) -> Self { - Self { - // so it never returns true to is satisfying query - bounds: (BlockPrimaryIndex::MAX, BlockPrimaryIndex::MIN), - tree, - } - } -} - -impl TreeInfo> for IndexInfo<'_> { - fn is_row_tree(&self) -> bool { - false - } - - fn is_satisfying_query(&self, k: &BlockPrimaryIndex) -> bool { - self.bounds.0 <= *k && *k <= self.bounds.1 - } - - fn load_proof( - &self, - ctx: &TestContext, - query_id: &QueryID, - primary: BlockPrimaryIndex, - key: &BlockPrimaryIndex, - placeholder_values: PlaceholderValues, - ) -> Result> { - //assert_eq!(primary, *key); - info!("loading proof for {primary} -> {key:?}"); - let proof_key = ProofKey::QueryAggregateIndex((query_id.clone(), placeholder_values, *key)); - ctx.storage.get_proof_exact(&proof_key) - } - - fn save_proof( - &self, - ctx: &mut TestContext, - query_id: &QueryID, - _primary: BlockPrimaryIndex, - key: &BlockPrimaryIndex, - placeholder_values: PlaceholderValues, - proof: Vec, - ) -> Result<()> { - //assert_eq!(primary, *key); - let proof_key = ProofKey::QueryAggregateIndex((query_id.clone(), placeholder_values, *key)); - ctx.storage.store_proof(proof_key, proof) - } - - async fn load_or_prove_embedded( - &self, - planner: &mut QueryPlanner<'_>, - primary: BlockPrimaryIndex, - k: &BlockPrimaryIndex, - v: &IndexNode, - ) -> Result>> { - load_or_prove_embedded_index(self, planner, primary, k, v).await - } - - async fn fetch_ctx_and_payload_at( - &self, - epoch: Epoch, - key: &BlockPrimaryIndex, - ) -> Option<(NodeContext, IndexNode)> { - self.tree.try_fetch_with_context_at(key, epoch).await - } -} - -async fn load_or_prove_embedded_index< - T: TreeInfo>, ->( - info: &T, - planner: &mut QueryPlanner<'_>, - primary: BlockPrimaryIndex, - k: &BlockPrimaryIndex, - v: &IndexNode, -) -> Result>> { - //assert_eq!(primary, *k); - info!("loading embedded proof for node {primary} -> {k:?}"); - Ok(if info.is_satisfying_query(k) { - // load the proof of the row root for this query, if it is already proven; - // otherwise, it means that there are no rows in the rows tree embedded in this - // node that satisfies the query bounds on secondary index, so we need to - // generate a non-existence proof for the row tree - let row_root_proof_key = ProofKey::QueryAggregateRow(( - planner.query.query.clone(), - planner.query.placeholders.placeholder_values(), - *k, - v.row_tree_root_key.clone(), - )); - let proof = match planner.ctx.storage.get_proof_exact(&row_root_proof_key) { - Ok(proof) => proof, - Err(_) => { - prove_non_existence_row(planner, *k).await?; - info!("non existence proved for {primary} -> {k:?}"); - // fetch again the generated proof - planner - .ctx - .storage - .get_proof_exact(&row_root_proof_key) - .unwrap_or_else(|_| { - panic!("non-existence root proof not found for key {row_root_proof_key:?}") - }) - } - }; - Some(proof) - } else { - None - }) -} diff --git 
a/mp2-v1/tests/common/cases/query/aggregated_queries.rs b/mp2-v1/tests/common/cases/query/aggregated_queries.rs index 570f8a29c..fe8daf45d 100644 --- a/mp2-v1/tests/common/cases/query/aggregated_queries.rs +++ b/mp2-v1/tests/common/cases/query/aggregated_queries.rs @@ -1,22 +1,17 @@ use plonky2::{ field::types::PrimeField64, hash::hash_types::HashOut, plonk::config::GenericHashOut, }; -use std::{ - collections::{HashMap, HashSet}, - fmt::Debug, - hash::Hash, -}; +use std::collections::HashMap; use crate::common::{ cases::{ indexing::BLOCK_COLUMN_NAME, - planner::{IndexInfo, QueryPlanner, RowInfo, TreeInfo}, - query::{QueryCooking, SqlReturn, SqlType}, + query::{QueryCooking, SqlReturn, SqlType, NUM_CHUNKS, NUM_ROWS}, table_source::BASE_VALUE, }, proof_storage::{ProofKey, ProofStorage}, rowtree::MerkleRowTree, - table::{Table, TableColumns}, + table::Table, TableInfo, }; @@ -40,389 +35,119 @@ use mp2_v1::{ block::BlockPrimaryIndex, cell::MerkleCell, row::{Row, RowPayload, RowTreeKey}, - LagrangeNode, }, - query::planner::{execute_row_query, proving_plan_for_non_existence}, - values_extraction::identifier_block_column, + query::{batching_planner::{generate_chunks_and_update_tree, UTForChunkProofs, UTKey}, planner::{execute_row_query, NonExistenceInput, TreeFetcher}}, }; use parsil::{ assembler::{DynamicCircuitPis, StaticCircuitPis}, queries::{core_keys_for_index_tree, core_keys_for_row_tree}, - ParsilSettings, DEFAULT_MAX_BLOCK_PLACEHOLDER, DEFAULT_MIN_BLOCK_PLACEHOLDER, + DEFAULT_MAX_BLOCK_PLACEHOLDER, DEFAULT_MIN_BLOCK_PLACEHOLDER, }; use ryhope::{ storage::{ - updatetree::{Next, UpdateTree, WorkplanItem}, + updatetree::{Next, WorkplanItem}, EpochKvStorage, RoEpochKvStorage, TreeTransactionalStorage, }, - tree::NodeContext, - Epoch, NodePayload, + Epoch, }; use sqlparser::ast::Query; use tokio_postgres::Row as PsqlRow; use verifiable_db::{ ivc::PublicInputs as IndexingPIS, query::{ - aggregation::{ChildPosition, NodeInfo, QueryHashNonExistenceCircuits, SubProof}, computational_hash_ids::{ColumnIDs, Identifiers}, universal_circuit::universal_circuit_inputs::{ - ColumnCell, PlaceholderId, Placeholders, RowCells, + ColumnCell, PlaceholderId, Placeholders, }, }, revelation::PublicInputs, }; use super::{ - GlobalCircuitInput, QueryCircuitInput, RevelationCircuitInput, MAX_NUM_COLUMNS, - MAX_NUM_ITEMS_PER_OUTPUT, MAX_NUM_OUTPUTS, MAX_NUM_PLACEHOLDERS, MAX_NUM_PREDICATE_OPS, - MAX_NUM_RESULT_OPS, + GlobalCircuitInput, QueryCircuitInput, QueryPlanner, RevelationCircuitInput, MAX_NUM_ITEMS_PER_OUTPUT, MAX_NUM_OUTPUTS, MAX_NUM_PLACEHOLDERS }; pub type RevelationPublicInputs<'a> = PublicInputs<'a, F, MAX_NUM_OUTPUTS, MAX_NUM_ITEMS_PER_OUTPUT, MAX_NUM_PLACEHOLDERS>; /// Execute a query to know all the touched rows, and then call the universal circuit on all rows -#[allow(clippy::too_many_arguments)] pub(crate) async fn prove_query( - ctx: &mut TestContext, - table: &Table, - query: QueryCooking, - parsed: Query, - settings: &ParsilSettings<&Table>, - res: Vec, - metadata: MetadataHash, - pis: DynamicCircuitPis, -) -> Result<()> { - #[cfg(not(feature = "batching_circuits"))] - let res = - prove_query_non_batching(ctx, table, query, parsed, settings, res, metadata, pis).await; - #[cfg(feature = "batching_circuits")] - let res = query_batching::prove_query_batching( - ctx, table, query, parsed, settings, res, metadata, pis, - ) - .await; - res -} - -#[cfg(feature = "batching_circuits")] -mod query_batching { - use super::*; - use crate::common::cases::query::{BatchingQueryCircuitInput, NUM_CHUNKS, 
NUM_ROWS}; - use mp2_v1::query::{ - batching_planner::{generate_chunks_and_update_tree, UTForChunkProofs, UTKey}, - planner::{NonExistenceInput, TreeFetcher}, - }; - - #[allow(clippy::too_many_arguments)] - pub(crate) async fn prove_query_batching( - ctx: &mut TestContext, - table: &Table, - query: QueryCooking, - mut parsed: Query, - settings: &ParsilSettings<&Table>, - res: Vec, - metadata: MetadataHash, - pis: DynamicCircuitPis, - ) -> Result<()> { - let row_cache = table - .row - .wide_lineage_between( - table.row.current_epoch(), - &core_keys_for_row_tree(&query.query, settings, &pis.bounds, &query.placeholders)?, - (query.min_block as Epoch, query.max_block as Epoch), - ) - .await?; - // prove the index tree, on a single version. Both path can be taken depending if we do have - // some nodes or not - let initial_epoch = table.index.initial_epoch() as BlockPrimaryIndex; - let current_epoch = table.index.current_epoch() as BlockPrimaryIndex; - let block_range = - query.min_block.max(initial_epoch + 1)..=query.max_block.min(current_epoch); - info!( - "found {} blocks in range: {:?}", - block_range.clone().count(), - block_range - ); - let column_ids = ColumnIDs::from(&table.columns); - let query_proof_id = if block_range.is_empty() { - info!("Running INDEX TREE proving for EMPTY query"); - // no valid blocks in the query range, so we need to choose a block to prove - // non-existence. Either the one after genesis or the last one - let to_be_proven_node = if query.max_block < initial_epoch { - initial_epoch + 1 - } else if query.min_block > current_epoch { - current_epoch - } else { - bail!( - "Empty block range to be proven for query bounds {}, {}, but no node - to be proven with non-existence circuit was found. Something is wrong", - query.min_block, - query.max_block - ); - } as BlockPrimaryIndex; - let index_path = table - .index - .compute_path(&to_be_proven_node, current_epoch as Epoch) - .await - .expect( - format!("Compute path for index node with key {to_be_proven_node} failed") - .as_str(), - ); - let input = BatchingQueryCircuitInput::new_non_existence_input( - index_path, - &column_ids, - &pis.predication_operations, - &pis.result, - &query.placeholders, - &pis.bounds, - )?; - let query_proof = ctx.run_query_proof( - "batching::non_existence", - GlobalCircuitInput::BatchingQuery(input), - )?; - let proof_key = ProofKey::QueryAggregate(( - query.query.clone(), - query.placeholders.placeholder_values(), - UTKey::default(), - )); - ctx.storage.store_proof(proof_key.clone(), query_proof)?; - proof_key - } else { - info!("Running INDEX tree proving from cache"); - // Only here we can run the SQL query for index so it doesn't crash - let index_query = core_keys_for_index_tree( - current_epoch as Epoch, - (query.min_block, query.max_block), - )?; - let big_index_cache = table - .index - // The bounds here means between which versions of the tree should we look. For index tree, - // we only look at _one_ version of the tree. 
- .wide_lineage_between( - current_epoch as Epoch, - &index_query, - (current_epoch as Epoch, current_epoch as Epoch), - ) - .await?; - let (proven_chunks, update_tree) = - generate_chunks_and_update_tree::( - row_cache, - big_index_cache, - &column_ids, - NonExistenceInput::new( - &table.row, - table.public_name.clone(), - &table.db_pool, - settings, - &pis.bounds, - ), - current_epoch as Epoch, - ) - .await?; - info!("Root of update tree is {:?}", update_tree.root()); - let mut workplan = update_tree.into_workplan(); - let mut proof_id = None; - while let Some(Next::Ready(wk)) = workplan.next() { - let (k, is_path_end) = if let WorkplanItem::Node { k, is_path_end } = &wk { - (k, *is_path_end) - } else { - unreachable!("this update tree has been created with a batch size of 1") - }; - let proof = if is_path_end { - // this is a row chunk to be proven - let to_be_proven_chunk = proven_chunks - .get(k) - .expect(format!("chunk for key {:?} not found", k).as_str()); - let input = BatchingQueryCircuitInput::new_row_chunks_input( - &to_be_proven_chunk, - &pis.predication_operations, - &query.placeholders, - &pis.bounds, - &pis.result, - )?; - info!("Proving chunk {:?}", k); - ctx.run_query_proof( - "batching::chunk_processing", - GlobalCircuitInput::BatchingQuery(input), - ) - } else { - let children_keys = workplan.t.get_children_keys(&k); - info!("children keys: {:?}", children_keys); - // fetch the proof for each child from the storage - let child_proofs = children_keys - .into_iter() - .map(|child_key| { - let proof_key = ProofKey::QueryAggregate(( - query.query.clone(), - query.placeholders.placeholder_values(), - child_key, - )); - ctx.storage.get_proof_exact(&proof_key) - }) - .collect::>>()?; - let input = - BatchingQueryCircuitInput::new_chunk_aggregation_input(&child_proofs)?; - info!("Aggregating chunk {:?}", k); - ctx.run_query_proof( - "batching::chunk_aggregation", - GlobalCircuitInput::BatchingQuery(input), - ) - }?; - let proof_key = ProofKey::QueryAggregate(( - query.query.clone(), - query.placeholders.placeholder_values(), - k.clone(), - )); - ctx.storage.store_proof(proof_key.clone(), proof)?; - proof_id = Some(proof_key); - workplan.done(&wk)?; - } - proof_id.unwrap() - }; - - info!("proving revelation"); - - let proof = prove_revelation( - ctx, - &query, - &pis, - table.index.current_epoch(), - &query_proof_id, - ) - .await?; - info!("Revelation proof done! Checking public inputs..."); - - // get `StaticPublicInputs`, i.e., the data about the query available only at query registration time, - // to check the public inputs - let pis = parsil::assembler::assemble_static(&parsed, &settings)?; - // get number of matching rows - let mut exec_query = parsil::executor::generate_query_keys(&mut parsed, &settings)?; - let query_params = exec_query.convert_placeholders(&query.placeholders); - let num_touched_rows = execute_row_query( - &table.db_pool, - &exec_query - .normalize_placeholder_names() - .to_pgsql_string_with_placeholder(), - &query_params, - ) - .await? 
- .len(); - - check_final_outputs( - proof, - ctx, - table, - &query, - &pis, - table.index.current_epoch(), - num_touched_rows, - res, - metadata, - )?; - info!("Revelation done!"); - Ok(()) - } -} - -/// Execute a query to know all the touched rows, and then call the universal circuit on all rows -#[allow(clippy::too_many_arguments)] -pub(crate) async fn prove_query_non_batching( - ctx: &mut TestContext, - table: &Table, - query: QueryCooking, mut parsed: Query, - settings: &ParsilSettings<&Table>, res: Vec, metadata: MetadataHash, - pis: DynamicCircuitPis, + planner: &mut QueryPlanner<'_>, ) -> Result<()> { - let row_cache = table + let row_cache = planner.table .row .wide_lineage_between( - table.row.current_epoch(), - &core_keys_for_row_tree(&query.query, settings, &pis.bounds, &query.placeholders)?, - (query.min_block as Epoch, query.max_block as Epoch), + planner.table.row.current_epoch(), + &core_keys_for_row_tree(&planner.query.query, planner.settings, &planner.pis.bounds, &planner.query.placeholders)?, + (planner.query.min_block as Epoch, planner.query.max_block as Epoch), ) .await?; - // the query to use to fetch all the rows keys involved in the result tree. - let pis = parsil::assembler::assemble_dynamic(&parsed, settings, &query.placeholders)?; - let row_keys_per_epoch = row_cache.keys_by_epochs(); - let mut planner = QueryPlanner { - ctx, - query: query.clone(), - settings, - pis: &pis, - table, - columns: table.columns.clone(), - }; - - // prove the different versions of the row tree for each of the involved rows for each block - for (epoch, keys) in row_keys_per_epoch { - let up = row_cache - .update_tree_for(epoch as Epoch) - .expect("this epoch should exist"); - let info = RowInfo { - tree: &table.row, - satisfiying_rows: keys, - }; - prove_query_on_tree(&mut planner, info, up, epoch as BlockPrimaryIndex).await?; - } - // prove the index tree, on a single version. Both path can be taken depending if we do have // some nodes or not - let initial_epoch = table.index.initial_epoch() as BlockPrimaryIndex; - let current_epoch = table.index.current_epoch() as BlockPrimaryIndex; - let block_range = query.min_block.max(initial_epoch + 1)..=query.max_block.min(current_epoch); + let initial_epoch = planner.table.index.initial_epoch() as BlockPrimaryIndex; + let current_epoch = planner.table.index.current_epoch() as BlockPrimaryIndex; + let block_range = + planner.query.min_block.max(initial_epoch + 1)..=planner.query.max_block.min(current_epoch); info!( "found {} blocks in range: {:?}", block_range.clone().count(), block_range ); - if block_range.is_empty() { + let column_ids = ColumnIDs::from(&planner.table.columns); + let query_proof_id = if block_range.is_empty() { info!("Running INDEX TREE proving for EMPTY query"); // no valid blocks in the query range, so we need to choose a block to prove // non-existence. Either the one after genesis or the last one - let to_be_proven_node = if query.max_block < initial_epoch { + let to_be_proven_node = if planner.query.max_block < initial_epoch { initial_epoch + 1 - } else if query.min_block > current_epoch { + } else if planner.query.min_block > current_epoch { current_epoch } else { bail!( "Empty block range to be proven for query bounds {}, {}, but no node to be proven with non-existence circuit was found. 
Something is wrong", - query.min_block, - query.max_block + planner.query.min_block, + planner.query.max_block ); } as BlockPrimaryIndex; - prove_non_existence_index(&mut planner, to_be_proven_node).await?; - // we get the lineage of the node that proves the non existence of the index nodes - // required for the query. We specify the epoch at which we want to get this lineage as - // of the current epoch. - let lineage = table + let index_path = planner.table .index - .lineage_at(&to_be_proven_node, current_epoch as Epoch) + .compute_path(&to_be_proven_node, current_epoch as Epoch) .await - .expect("can't get lineage") - .into_full_path() - .collect(); - let up = UpdateTree::from_path(lineage, current_epoch as Epoch); - let info = IndexInfo { - tree: &table.index, - bounds: (query.min_block, query.max_block), - }; - prove_query_on_tree( - &mut planner, - info, - up, - table.index.current_epoch() as BlockPrimaryIndex, - ) - .await?; + .expect( + format!("Compute path for index node with key {to_be_proven_node} failed") + .as_str(), + ); + let input = QueryCircuitInput::new_non_existence_input( + index_path, + &column_ids, + &planner.pis.predication_operations, + &planner.pis.result, + &planner.query.placeholders, + &planner.pis.bounds, + )?; + let query_proof = planner.ctx.run_query_proof( + "batching::non_existence", + GlobalCircuitInput::Query(input), + )?; + let proof_key = ProofKey::QueryAggregate(( + planner.query.query.clone(), + planner.query.placeholders.placeholder_values(), + UTKey::default(), + )); + planner.ctx.storage.store_proof(proof_key.clone(), query_proof)?; + proof_key } else { info!("Running INDEX tree proving from cache"); // Only here we can run the SQL query for index so it doesn't crash - let index_query = - core_keys_for_index_tree(current_epoch as Epoch, (query.min_block, query.max_block))?; - let big_index_cache = table + let index_query = core_keys_for_index_tree( + current_epoch as Epoch, + (planner.query.min_block, planner.query.max_block), + )?; + let big_index_cache = planner.table .index // The bounds here means between which versions of the tree should we look. For index tree, // we only look at _one_ version of the tree. @@ -432,47 +157,102 @@ pub(crate) async fn prove_query_non_batching( (current_epoch as Epoch, current_epoch as Epoch), ) .await?; - // since we only analyze the index tree for one epoch - assert_eq!(big_index_cache.keys_by_epochs().len(), 1); - // This is ok because the cache only have the block that are in the range so the - // filter_check is gonna return the same thing - // TOOD: @franklin is that correct ? 
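The empty-range branch above reduces to a simple rule for picking the single index node whose non-existence proof covers the whole query range. A minimal sketch of that selection rule, with illustrative names rather than the planner's real API:

```rust
/// Pick the index-tree node to prove non-existence for when the query's block
/// range does not intersect [initial_epoch, current_epoch]. Returns None when
/// the range overlaps the tree and a normal proof is required instead.
fn non_existence_node(
    min_block: usize,
    max_block: usize,
    initial_epoch: usize,
    current_epoch: usize,
) -> Option<usize> {
    if max_block < initial_epoch {
        // the whole query range predates the tree: prove the first real block
        Some(initial_epoch + 1)
    } else if min_block > current_epoch {
        // the whole query range is in the future: prove the last block
        Some(current_epoch)
    } else {
        // the ranges overlap, so at least one node must be proven the normal way
        None
    }
}
```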
- let up = big_index_cache - // this is the epoch we choose how to prove - .update_tree_for(current_epoch as Epoch) - .expect("this epoch should exist"); - prove_query_on_tree( - &mut planner, - big_index_cache, - up, - table.index.current_epoch() as BlockPrimaryIndex, - ) - .await?; - } + let (proven_chunks, update_tree) = + generate_chunks_and_update_tree::( + row_cache, + big_index_cache, + &column_ids, + NonExistenceInput::new( + &planner.table.row, + planner.table.public_name.clone(), + &planner.table.db_pool, + planner.settings, + &planner.pis.bounds, + ), + current_epoch as Epoch, + ) + .await?; + info!("Root of update tree is {:?}", update_tree.root()); + let mut workplan = update_tree.into_workplan(); + let mut proof_id = None; + while let Some(Next::Ready(wk)) = workplan.next() { + let (k, is_path_end) = if let WorkplanItem::Node { k, is_path_end } = &wk { + (k, *is_path_end) + } else { + unreachable!("this update tree has been created with a batch size of 1") + }; + let proof = if is_path_end { + // this is a row chunk to be proven + let to_be_proven_chunk = proven_chunks + .get(k) + .expect(format!("chunk for key {:?} not found", k).as_str()); + let input = QueryCircuitInput::new_row_chunks_input( + &to_be_proven_chunk, + &planner.pis.predication_operations, + &planner.query.placeholders, + &planner.pis.bounds, + &planner.pis.result, + )?; + info!("Proving chunk {:?}", k); + planner.ctx.run_query_proof( + "batching::chunk_processing", + GlobalCircuitInput::Query(input), + ) + } else { + let children_keys = workplan.t.get_children_keys(&k); + info!("children keys: {:?}", children_keys); + // fetch the proof for each child from the storage + let child_proofs = children_keys + .into_iter() + .map(|child_key| { + let proof_key = ProofKey::QueryAggregate(( + planner.query.query.clone(), + planner.query.placeholders.placeholder_values(), + child_key, + )); + planner.ctx.storage.get_proof_exact(&proof_key) + }) + .collect::>>()?; + let input = + QueryCircuitInput::new_chunk_aggregation_input(&child_proofs)?; + info!("Aggregating chunk {:?}", k); + planner.ctx.run_query_proof( + "batching::chunk_aggregation", + GlobalCircuitInput::Query(input), + ) + }?; + let proof_key = ProofKey::QueryAggregate(( + planner.query.query.clone(), + planner.query.placeholders.placeholder_values(), + k.clone(), + )); + planner.ctx.storage.store_proof(proof_key.clone(), proof)?; + proof_id = Some(proof_key); + workplan.done(&wk)?; + } + proof_id.unwrap() + }; - info!("Query proofs done! Generating revelation proof..."); - let root_key = table - .index - .root_at(table.index.current_epoch()) - .await - .unwrap(); - let proof_key = ProofKey::QueryAggregateIndex(( - query.query.clone(), - query.placeholders.placeholder_values(), - root_key, - )); - let proof = - prove_revelation(ctx, &query, &pis, table.index.current_epoch(), &proof_key).await?; + info!("proving revelation"); + + let proof = prove_revelation( + planner.ctx, + &planner.query, + &planner.pis, + planner.table.index.current_epoch(), + &query_proof_id, + ) + .await?; info!("Revelation proof done! 
Checking public inputs..."); // get `StaticPublicInputs`, i.e., the data about the query available only at query registration time, // to check the public inputs - let pis = parsil::assembler::assemble_static(&parsed, settings)?; + let pis = parsil::assembler::assemble_static(&parsed, planner.settings)?; // get number of matching rows - let mut exec_query = parsil::executor::generate_query_keys(&mut parsed, settings)?; - let query_params = exec_query.convert_placeholders(&query.placeholders); + let mut exec_query = parsil::executor::generate_query_keys(&mut parsed, planner.settings)?; + let query_params = exec_query.convert_placeholders(&planner.query.placeholders); let num_touched_rows = execute_row_query( - &table.db_pool, + &planner.table.db_pool, &exec_query .normalize_placeholder_names() .to_pgsql_string_with_placeholder(), @@ -483,11 +263,11 @@ pub(crate) async fn prove_query_non_batching( check_final_outputs( proof, - ctx, - table, - &query, + planner.ctx, + planner.table, + &planner.query, &pis, - table.index.current_epoch(), + planner.table.index.current_epoch(), num_touched_rows, res, metadata, @@ -629,541 +409,6 @@ pub(crate) fn check_final_outputs( Ok(()) } -/// Generic function as to how to handle the aggregation. It handles both aggregation in the row -/// tree as well as in the index tree the same way. The TreeInfo trait is just here to bring some -/// context, so savign and loading the proof at the right location depending if it's a row or index -/// tree -/// clippy doesn't see that it can not be done -#[allow(clippy::needless_lifetimes)] -async fn prove_query_on_tree<'a, I, K, V>( - planner: &mut QueryPlanner<'a>, - info: I, - update: UpdateTree, - primary: BlockPrimaryIndex, -) -> Result> -where - I: TreeInfo, - V: NodePayload + Send + Sync + LagrangeNode + Clone, - K: Debug + Hash + Clone + Eq + Sync + Send, -{ - let query_id = planner.query.query.clone(); - let placeholder_values = planner.query.placeholders.placeholder_values(); - let mut workplan = update.into_workplan(); - let mut proven_nodes = HashSet::new(); - let fetch_only_proven_child = |nctx: NodeContext, - cctx: &TestContext, - proven: &HashSet| - -> (ChildPosition, Vec) { - let (child_key, pos) = match (nctx.left, nctx.right) { - (Some(left), Some(right)) => { - assert!( - proven.contains(&left) ^ proven.contains(&right), - "only one child should be already proven, not both" - ); - if proven.contains(&left) { - (left, ChildPosition::Left) - } else { - (right, ChildPosition::Right) - } - } - (Some(left), None) if proven.contains(&left) => (left, ChildPosition::Left), - (None, Some(right)) if proven.contains(&right) => (right, ChildPosition::Right), - _ => panic!("stg's wrong in the tree"), - }; - let child_proof = info - .load_proof( - cctx, - &query_id, - primary, - &child_key, - placeholder_values.clone(), - ) - .expect("key should already been proven"); - (pos, child_proof) - }; - while let Some(Next::Ready(wk)) = workplan.next() { - let k = wk.k(); - // closure performing all the operations necessary beofre jumping to the next iteration - let mut end_iteration = |proven_nodes: &mut HashSet| -> Result<()> { - proven_nodes.insert(k.clone()); - workplan.done(&wk)?; - Ok(()) - }; - // since epoch starts at genesis now, we can directly give the value of the block - // number as epoch number - let (node_ctx, node_payload) = info - .fetch_ctx_and_payload_at(primary as Epoch, k) - .await - .expect("cache is not full"); - let is_satisfying_query = info.is_satisfying_query(k); - let embedded_proof = info - 
.load_or_prove_embedded(planner, primary, k, &node_payload) - .await; - if node_ctx.is_leaf() && info.is_row_tree() { - // NOTE: if it is a leaf of the row tree, then there is no need to prove anything, - // since we're not "aggregating" any from below. For the index tree however, we - // need to always generate an aggregate proof. Therefore, in this test, we just copy the - // proof to the expected aggregation location and move on. Note that we need to - // save the proof only if the current row is satisfying the query: indeed, if - // this not the case, then the proof should have already been generated and stored - // with the non-existence circuit - if is_satisfying_query { - // unwrap is safe since we are guaranteed the row is satisfying the query - info.save_proof( - planner.ctx, - &query_id, - primary, - k, - placeholder_values.clone(), - embedded_proof?.unwrap(), - )?; - } - - end_iteration(&mut proven_nodes)?; - continue; - } - - // In the case we haven't proven anything under this node, it's the single path case - // It is sufficient to check if this node is one of the leaves we in this update tree.Note - // it is not the same meaning as a "leaf of a tree", here it just means is it the first - // node in the merkle path. - let (k, is_path_end) = if let WorkplanItem::Node { k, is_path_end } = &wk { - (k, *is_path_end) - } else { - unreachable!("this update tree has been created with a batch size of 1") - }; - - let (name, input) = if is_path_end { - info!("node {primary} -> {k:?} is at path end"); - if !is_satisfying_query { - // if the node of the key does not satisfy the query, but this node is at the end of - // a path to be proven, then it means we are in a tree with no satisfying nodes, and - // so the current node is the node chosen to be proven with non-existence circuits. - // Since the node has already been proven, we just save the proof and we continue - end_iteration(&mut proven_nodes)?; - continue; - } - assert!( - info.is_satisfying_query(k), - "first node in merkle path should always be a valid query one" - ); - let (node_info, left_info, right_info) = - // we can use primary as epoch now that tree stores epoch from genesis - get_node_info(&info, k, primary as Epoch).await; - ( - "querying::aggregation::single", - QueryCircuitInput::new_single_path( - SubProof::new_embedded_tree_proof(embedded_proof?.unwrap())?, - left_info, - right_info, - node_info, - info.is_row_tree(), - &planner.pis.bounds, - ) - .expect("can't create leaf input"), - ) - } else { - // here we are guaranteed there is a node below that we have already proven - // It can not be a single path with the embedded tree only since that falls into the - // previous category ("is_path_end" == true) since update plan starts by the "leaves" - // of all the paths it has been given. - // So it means There is at least one child of this node that we have proven before. - // If this node is satisfying query, then we use One/TwoProvenChildNode, - // If this node is not in the query touched rows, we use a SinglePath with proven child tree. 
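The comments above describe the dispatch the removed per-node aggregation performed. A hedged restatement of that case analysis in plain Rust (the types are simplified stand-ins, not the circuit API) may help when comparing it with the chunk-based loop that replaces it:

```rust
/// Which circuit input the removed per-node loop would build for a node.
enum AggregationCase {
    /// Path end and the node satisfies the query: single path over the embedded tree proof.
    SinglePathEmbedded,
    /// Node outside the query with one already-proven child: single path over the child proof.
    SinglePathProvenChild,
    /// Node satisfies the query and both children are proven: full-node aggregation.
    FullNode,
    /// Node satisfies the query but only one child is proven: partial-node aggregation.
    PartialNode,
}

fn classify(
    is_path_end: bool,
    satisfies_query: bool,
    both_children_proven: bool,
) -> Option<AggregationCase> {
    match (is_path_end, satisfies_query) {
        // a path-end node that does not satisfy the query was already proven
        // by the non-existence circuit, so no new proof is generated for it
        (true, false) => None,
        (true, true) => Some(AggregationCase::SinglePathEmbedded),
        (false, false) => Some(AggregationCase::SinglePathProvenChild),
        (false, true) if both_children_proven => Some(AggregationCase::FullNode),
        (false, true) => Some(AggregationCase::PartialNode),
    }
}
```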
- // - if !is_satisfying_query { - let (child_pos, child_proof) = - fetch_only_proven_child(node_ctx, planner.ctx, &proven_nodes); - let (node_info, left_info, right_info) = get_node_info( - &info, - k, - // we can use primary as epoch since storage starts epoch at genesis - primary as Epoch, - ) - .await; - // we look which child is the one to load from storage, the one we already proved - ( - "querying::aggregation::single", - QueryCircuitInput::new_single_path( - SubProof::new_child_proof(child_proof, child_pos)?, - left_info, - right_info, - node_info, - info.is_row_tree(), - &planner.pis.bounds, - ) - .expect("can't create leaf input"), - ) - } else { - // this case is easy, since all that's left is partial or full where both - // child(ren) and current node belong to query - let is_correct_left = node_ctx.left.is_some() - && proven_nodes.contains(node_ctx.left.as_ref().unwrap()); - let is_correct_right = node_ctx.right.is_some() - && proven_nodes.contains(node_ctx.right.as_ref().unwrap()); - if is_correct_left && is_correct_right { - // full node case - let left_proof = info.load_proof( - planner.ctx, - &query_id, - primary, - node_ctx.left.as_ref().unwrap(), - placeholder_values.clone(), - )?; - let right_proof = info.load_proof( - planner.ctx, - &query_id, - primary, - node_ctx.right.as_ref().unwrap(), - placeholder_values.clone(), - )?; - ( - "querying::aggregation::full", - QueryCircuitInput::new_full_node( - left_proof, - right_proof, - embedded_proof?.expect("should be a embedded_proof here"), - info.is_row_tree(), - &planner.pis.bounds, - ) - .expect("can't create full node circuit input"), - ) - } else { - // partial case - let (child_pos, child_proof) = - fetch_only_proven_child(node_ctx, planner.ctx, &proven_nodes); - let (_, left_info, right_info) = - get_node_info(&info, k, primary as Epoch).await; - let unproven = match child_pos { - ChildPosition::Left => right_info, - ChildPosition::Right => left_info, - }; - ( - "querying::aggregation::partial", - QueryCircuitInput::new_partial_node( - child_proof, - embedded_proof?.expect("should be an embedded_proof here too"), - unproven, - child_pos, - info.is_row_tree(), - &planner.pis.bounds, - ) - .expect("can't build new partial node input"), - ) - } - } - }; - info!("AGGREGATE query proof RUNNING for {primary} -> {k:?} "); - let proof = planner - .ctx - .run_query_proof(name, GlobalCircuitInput::Query(input))?; - info.save_proof( - planner.ctx, - &query_id, - primary, - k, - placeholder_values.clone(), - proof, - )?; - info!("query proof DONE for {primary} -> {k:?} "); - end_iteration(&mut proven_nodes)?; - } - Ok(vec![]) -} - -// TODO: make it recursive with async - tentative in `fetch_child_info` but it doesn't work, -// recursion with async is weird. -pub(crate) async fn get_node_info>( - lookup: &T, - k: &K, - at: Epoch, -) -> (NodeInfo, Option, Option) -where - K: Debug + Hash + Clone + Send + Sync + Eq, - // NOTICE the ToValue here to get the value associated to a node - V: NodePayload + Send + Sync + LagrangeNode + Clone, -{ - // look at the left child first then right child, then build the node info - let (ctx, node_payload) = lookup - .fetch_ctx_and_payload_at(at, k) - .await - .expect("cache not filled"); - // this looks at the value of a child node (left and right), and fetches the grandchildren - // information to be able to build their respective node info. 
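The grandchild fetches below exist because a child's NodeInfo must reproduce the node-hash preimage used throughout these circuits. A sketch of that preimage, assuming a generic hasher and byte concatenation in place of the Poseidon hash over field elements that the real code uses:

```rust
/// Illustrative node commitment; field names mirror NodeInfo, widths are simplified.
struct NodeInfoSketch {
    left_child_hash: Option<[u8; 32]>,
    right_child_hash: Option<[u8; 32]>,
    min: u64,
    max: u64,
    column_id: u64,
    value: u64,
    embedded_tree_hash: [u8; 32],
}

fn node_hash(n: &NodeInfoSketch, hash: impl Fn(&[u8]) -> [u8; 32]) -> [u8; 32] {
    const EMPTY: [u8; 32] = [0u8; 32]; // stand-in for the empty Poseidon hash
    // H(left_hash || right_hash || min || max || column_id || value || embedded_tree_hash)
    let mut preimage = Vec::new();
    preimage.extend_from_slice(&n.left_child_hash.unwrap_or(EMPTY));
    preimage.extend_from_slice(&n.right_child_hash.unwrap_or(EMPTY));
    preimage.extend_from_slice(&n.min.to_le_bytes());
    preimage.extend_from_slice(&n.max.to_le_bytes());
    preimage.extend_from_slice(&n.column_id.to_le_bytes());
    preimage.extend_from_slice(&n.value.to_le_bytes());
    preimage.extend_from_slice(&n.embedded_tree_hash);
    hash(&preimage)
}
```

Building this preimage for a node's child requires the hashes of the child's own children, which is why the code fetches two levels below the starting key.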
- let fetch_ni = async |k: Option| -> (Option, Option) { - match k { - None => (None, None), - Some(child_k) => { - let (child_ctx, child_payload) = lookup - .fetch_ctx_and_payload_at(at, &child_k) - .await - .expect("cache not filled"); - // we need the grand child hashes for constructing the node info of the - // children of the node in argument - let child_left_hash = match child_ctx.left { - Some(left_left_k) => { - let (_, payload) = lookup - .fetch_ctx_and_payload_at(at, &left_left_k) - .await - .expect("cache not filled"); - Some(payload.hash()) - } - None => None, - }; - let child_right_hash = match child_ctx.right { - Some(left_right_k) => { - let (_, payload) = lookup - .fetch_ctx_and_payload_at(at, &left_right_k) - .await - .expect("cache not full"); - Some(payload.hash()) - } - None => None, - }; - let left_ni = NodeInfo::new( - &child_payload.embedded_hash(), - child_left_hash.as_ref(), - child_right_hash.as_ref(), - child_payload.value(), - child_payload.min(), - child_payload.max(), - ); - (Some(left_ni), Some(child_payload.hash())) - } - } - }; - let (left_node, left_hash) = fetch_ni(ctx.left).await; - let (right_node, right_hash) = fetch_ni(ctx.right).await; - ( - NodeInfo::new( - &node_payload.embedded_hash(), - left_hash.as_ref(), - right_hash.as_ref(), - node_payload.value(), - node_payload.min(), - node_payload.max(), - ), - left_node, - right_node, - ) -} - -pub fn generate_non_existence_proof( - node_info: NodeInfo, - left_child_info: Option, - right_child_info: Option, - primary: BlockPrimaryIndex, - planner: &mut QueryPlanner<'_>, - is_rows_tree_node: bool, -) -> Result> { - let index_ids = [ - planner.table.columns.primary_column().identifier, - planner.table.columns.secondary_column().identifier, - ]; - assert_eq!(index_ids[0], identifier_block_column()); - let column_ids = ColumnIDs::new( - index_ids[0], - index_ids[1], - planner - .table - .columns - .non_indexed_columns() - .iter() - .map(|column| column.identifier) - .collect_vec(), - ); - let query_hashes = QueryHashNonExistenceCircuits::new::< - MAX_NUM_COLUMNS, - MAX_NUM_PREDICATE_OPS, - MAX_NUM_RESULT_OPS, - MAX_NUM_ITEMS_PER_OUTPUT, - >( - &column_ids, - &planner.pis.predication_operations, - &planner.pis.result, - &planner.query.placeholders, - &planner.pis.bounds, - is_rows_tree_node, - )?; - let input = QueryCircuitInput::new_non_existence_input( - node_info, - left_child_info, - right_child_info, - U256::from(primary), - &index_ids, - &planner.pis.query_aggregations, - query_hashes, - is_rows_tree_node, - &planner.pis.bounds, - &planner.query.placeholders, - )?; - planner - .ctx - .run_query_proof("querying::non_existence", GlobalCircuitInput::Query(input)) -} - -/// Generate a proof for a node of the index tree which is outside of the query bounds -async fn prove_non_existence_index( - planner: &mut QueryPlanner<'_>, - primary: BlockPrimaryIndex, -) -> Result<()> { - let tree = &planner.table.index; - let current_epoch = tree.current_epoch(); - let (node_info, left_child_info, right_child_info) = get_node_info( - &IndexInfo::non_satisfying_info(tree), - &primary, - current_epoch, - ) - .await; - let proof_key = ProofKey::QueryAggregateIndex(( - planner.query.query.clone(), - planner.query.placeholders.placeholder_values(), - primary, - )); - info!("Non-existence circuit proof RUNNING for {current_epoch} -> {primary} "); - let proof = generate_non_existence_proof( - node_info, - left_child_info, - right_child_info, - primary, - planner, - false, - ) - .unwrap_or_else(|_| { - panic!("unable to 
generate non-existence proof for {current_epoch} -> {primary}") - }); - info!("Non-existence circuit proof DONE for {current_epoch} -> {primary} "); - planner.ctx.storage.store_proof(proof_key, proof.clone())?; - - Ok(()) -} - -pub async fn prove_non_existence_row( - planner: &mut QueryPlanner<'_>, - primary: BlockPrimaryIndex, -) -> Result<()> { - let (chosen_node, plan) = proving_plan_for_non_existence( - &planner.table.row, - planner.table.public_name.clone(), - &planner.table.db_pool, - primary, - planner.settings, - &planner.pis.bounds, - ) - .await?; - let (node_info, left_child_info, right_child_info) = - mp2_v1::query::planner::get_node_info(&planner.table.row, &chosen_node, primary as Epoch) - .await; - - let proof_key = ProofKey::QueryAggregateRow(( - planner.query.query.clone(), - planner.query.placeholders.placeholder_values(), - primary, - chosen_node.clone(), - )); - info!( - "Non-existence circuit proof RUNNING for {primary} -> {:?} ", - proof_key - ); - let proof = generate_non_existence_proof( - node_info, - left_child_info, - right_child_info, - primary, - planner, - true, - ) - .unwrap_or_else(|_| { - panic!( - "unable to generate non-existence proof for {primary} -> {:?}", - chosen_node - ) - }); - info!( - "Non-existence circuit proof DONE for {primary} -> {:?} ", - chosen_node - ); - planner.ctx.storage.store_proof(proof_key, proof.clone())?; - - let tree_info = RowInfo::no_satisfying_rows(&planner.table.row); - let mut planner = QueryPlanner { - ctx: planner.ctx, - table: planner.table, - query: planner.query.clone(), - pis: planner.pis, - columns: planner.columns.clone(), - settings: planner.settings, - }; - prove_query_on_tree(&mut planner, tree_info, plan, primary).await?; - - Ok(()) -} - -pub async fn prove_single_row>>( - ctx: &mut TestContext, - tree: &T, - columns: &TableColumns, - primary: BlockPrimaryIndex, - row_key: &RowTreeKey, - pis: &DynamicCircuitPis, - query: &QueryCooking, -) -> Result> { - // 1. Get the all the cells including primary and secondary index - // Note we can use the primary as epoch since now epoch == primary in the storage - let (row_ctx, row_payload) = tree - .fetch_ctx_and_payload_at(primary as Epoch, row_key) - .await - .expect("cache not full"); - - // API is gonna change on this but right now, we have to sort all the "rest" cells by index - // in the tree, and put the primary one and secondary one in front - let rest_cells = columns - .non_indexed_columns() - .iter() - .map(|tc| tc.identifier) - .filter_map(|id| { - row_payload - .cells - .find_by_column(id) - .map(|info| ColumnCell::new(id, info.value)) - }) - .collect::>(); - - let secondary_cell = ColumnCell::new( - row_payload.secondary_index_column, - row_payload.secondary_index_value(), - ); - let primary_cell = ColumnCell::new(identifier_block_column(), U256::from(primary)); - let row = RowCells::new(primary_cell, secondary_cell, rest_cells); - // 2. create input - let input = QueryCircuitInput::new_universal_circuit( - &row, - &pis.predication_operations, - &pis.result, - &query.placeholders, - row_ctx.is_leaf(), - &pis.bounds, - ) - .expect("unable to create universal query circuit inputs"); - // 3. 
run proof if not ran already - let proof_key = ProofKey::QueryUniversal(( - query.query.clone(), - query.placeholders.placeholder_values(), - primary, - row_key.clone(), - )); - let proof = { - info!("Universal query proof RUNNING for {primary} -> {row_key:?} "); - let proof = ctx - .run_query_proof("querying::universal", GlobalCircuitInput::Query(input)) - .expect("unable to generate universal proof for {epoch} -> {row_key:?}"); - info!("Universal query proof DONE for {primary} -> {row_key:?} "); - ctx.storage.store_proof(proof_key, proof.clone())?; - proof - }; - Ok(proof) -} - type BlockRange = (BlockPrimaryIndex, BlockPrimaryIndex); pub(crate) async fn cook_query_between_blocks( diff --git a/mp2-v1/tests/common/cases/query/mod.rs b/mp2-v1/tests/common/cases/query/mod.rs index abd11bf9b..d5d0aad4e 100644 --- a/mp2-v1/tests/common/cases/query/mod.rs +++ b/mp2-v1/tests/common/cases/query/mod.rs @@ -11,7 +11,7 @@ use log::info; use mp2_v1::{ api::MetadataHash, indexing::block::BlockPrimaryIndex, query::planner::execute_row_query, }; -use parsil::{parse_and_validate, utils::ParsilSettingsBuilder, PlaceholderSettings}; +use parsil::{assembler::DynamicCircuitPis, parse_and_validate, utils::ParsilSettingsBuilder, ParsilSettings, PlaceholderSettings}; use simple_select_queries::{ cook_query_no_matching_rows, cook_query_too_big_offset, cook_query_with_distinct, cook_query_with_matching_rows, cook_query_with_max_num_matching_rows, @@ -23,7 +23,7 @@ use verifiable_db::query::{ computational_hash_ids::Output, universal_circuit::universal_circuit_inputs::Placeholders, }; -use crate::common::{cases::planner::QueryPlanner, table::Table, TableInfo, TestContext}; +use crate::common::{table::{Table, TableColumns}, TableInfo, TestContext}; use super::table_source::TableSource; @@ -55,14 +55,6 @@ pub type GlobalCircuitInput = verifiable_db::api::QueryCircuitInput< >; pub type QueryCircuitInput = verifiable_db::query::api::CircuitInput< - MAX_NUM_COLUMNS, - MAX_NUM_PREDICATE_OPS, - MAX_NUM_RESULT_OPS, - MAX_NUM_ITEMS_PER_OUTPUT, ->; - -#[cfg(feature = "batching_circuits")] -pub type BatchingQueryCircuitInput = verifiable_db::query::batching::CircuitInput< NUM_CHUNKS, NUM_ROWS, ROW_TREE_MAX_DEPTH, @@ -94,6 +86,17 @@ pub struct QueryCooking { pub(crate) offset: Option, } +pub(crate) struct QueryPlanner<'a> { + pub(crate) query: QueryCooking, + pub(crate) pis: &'a DynamicCircuitPis, + pub(crate) ctx: &'a mut TestContext, + pub(crate) settings: &'a ParsilSettings<&'a Table>, + // useful for non existence since we need to search in both trees the places to prove + // the fact a given node doesn't exist + pub(crate) table: &'a Table, + pub(crate) columns: TableColumns, +} + pub async fn test_query(ctx: &mut TestContext, table: Table, t: TableInfo) -> Result<()> { match &t.source { TableSource::Mapping(_) | TableSource::Merge(_) => query_mapping(ctx, &table, &t).await?, @@ -215,14 +218,10 @@ async fn test_query_mapping( match pis.result.query_variant() { Output::Aggregation => { prove_aggregation_query( - ctx, - table, - query_info, parsed, - &settings, res, *table_hash, - pis, + &mut planner, ) .await } diff --git a/mp2-v1/tests/common/cases/query/simple_select_queries.rs b/mp2-v1/tests/common/cases/query/simple_select_queries.rs index e29226a8b..370839426 100644 --- a/mp2-v1/tests/common/cases/query/simple_select_queries.rs +++ b/mp2-v1/tests/common/cases/query/simple_select_queries.rs @@ -5,12 +5,11 @@ use log::info; use mp2_common::types::HashOutput; use mp2_v1::{ api::MetadataHash, - 
indexing::{block::BlockPrimaryIndex, row::RowTreeKey, LagrangeNode}, - query::planner::execute_row_query, + indexing::{block::BlockPrimaryIndex, row::{RowPayload, RowTreeKey}, LagrangeNode}, + query::planner::{execute_row_query, get_node_info, TreeFetcher}, values_extraction::identifier_block_column, }; use parsil::{ - executor::generate_query_execution_with_keys, DEFAULT_MAX_BLOCK_PLACEHOLDER, - DEFAULT_MIN_BLOCK_PLACEHOLDER, + assembler::DynamicCircuitPis, executor::generate_query_execution_with_keys, DEFAULT_MAX_BLOCK_PLACEHOLDER, DEFAULT_MIN_BLOCK_PLACEHOLDER }; use ryhope::{ storage::{pgsql::ToFromBytea, RoEpochKvStorage}, @@ -21,9 +20,7 @@ use std::{fmt::Debug, hash::Hash}; use tokio_postgres::Row as PgSqlRow; use verifiable_db::{ query::{ - aggregation::{ChildPosition, NodeInfo}, - computational_hash_ids::ColumnIDs, - universal_circuit::universal_circuit_inputs::{PlaceholderId, Placeholders}, + computational_hash_ids::ColumnIDs, universal_circuit::universal_circuit_inputs::{ColumnCell, PlaceholderId, Placeholders, RowCells}, utils::{ChildPosition, NodeInfo} }, revelation::{api::MatchingRow, RowPath}, test_utils::MAX_NUM_OUTPUTS, @@ -32,20 +29,19 @@ use verifiable_db::{ use crate::common::{ cases::{ indexing::BLOCK_COLUMN_NAME, - planner::{IndexInfo, QueryPlanner, RowInfo, TreeInfo}, query::{ aggregated_queries::{ - check_final_outputs, find_longest_lived_key, get_node_info, prove_single_row, + check_final_outputs, find_longest_lived_key, }, - GlobalCircuitInput, RevelationCircuitInput, SqlReturn, SqlType, + GlobalCircuitInput, RevelationCircuitInput, SqlReturn, SqlType, QueryPlanner, }, }, proof_storage::{ProofKey, ProofStorage}, - table::Table, + table::{Table, TableColumns}, TableInfo, }; -use super::QueryCooking; +use super::{QueryCircuitInput, QueryCooking, TestContext}; pub(crate) async fn prove_query( mut parsed: Query, @@ -80,24 +76,12 @@ pub(crate) async fn prove_query( }) .collect::>>()?; // compute input for each matching row - let row_tree_info = RowInfo { - satisfiying_rows: matching_rows - .iter() - .map(|(key, _, _)| key) - .cloned() - .collect(), - tree: &planner.table.row, - }; - let index_tree_info = IndexInfo { - bounds: (planner.query.min_block, planner.query.max_block), - tree: &planner.table.index, - }; - let current_epoch = index_tree_info.tree.current_epoch(); + let current_epoch = planner.table.index.current_epoch(); let mut matching_rows_input = vec![]; for (key, epoch, result) in matching_rows.into_iter() { let row_proof = prove_single_row( planner.ctx, - &row_tree_info, + &planner.table.row, &planner.columns, epoch as BlockPrimaryIndex, &key, @@ -105,13 +89,21 @@ pub(crate) async fn prove_query( &planner.query, ) .await?; - let (row_node_info, _, _) = get_node_info(&row_tree_info, &key, epoch).await; - let (row_tree_path, row_tree_siblings) = get_path_info(&key, &row_tree_info, epoch).await?; + let (row_node_info, _, _) = get_node_info(&planner.table.row, &key, epoch).await; + let (row_tree_path, row_tree_siblings) = get_path_info( + &key, + &planner.table.row, + epoch) + .await?; let index_node_key = epoch as BlockPrimaryIndex; let (index_node_info, _, _) = - get_node_info(&index_tree_info, &index_node_key, current_epoch).await; + get_node_info(&planner.table.index, &index_node_key, current_epoch).await; let (index_tree_path, index_tree_siblings) = - get_path_info(&index_node_key, &index_tree_info, current_epoch).await?; + get_path_info( + &index_node_key, + &planner.table.index, + current_epoch + ).await?; let path = RowPath::new( row_node_info, 
             row_tree_path,
@@ -163,7 +155,7 @@ pub(crate) async fn prove_query(
     Ok(())
 }
 
-async fn get_path_info<K, V, T: TreeInfo<K, V>>(
+async fn get_path_info<K, V, T: TreeFetcher<K, V>>(
     key: &K,
     tree_info: &T,
     epoch: Epoch,
@@ -175,7 +167,7 @@ where
     let mut tree_path = vec![];
     let mut siblings = vec![];
     let (mut node_ctx, mut node_payload) = tree_info
-        .fetch_ctx_and_payload_at(epoch, key)
+        .fetch_ctx_and_payload_at(key, epoch)
         .await
         .ok_or(Error::msg(format!("Node not found for key {:?}", key)))?;
     let mut previous_node_hash = node_payload.hash();
@@ -183,7 +175,7 @@ where
     while node_ctx.parent.is_some() {
         let parent_key = node_ctx.parent.unwrap();
         (node_ctx, node_payload) = tree_info
-            .fetch_ctx_and_payload_at(epoch, &parent_key)
+            .fetch_ctx_and_payload_at(&parent_key, epoch)
            .await
            .ok_or(Error::msg(format!(
                "Node not found for key {:?}",
                parent_key
            )))?;
@@ -199,7 +191,7 @@ where
             match node_ctx.right {
                 Some(k) => {
                     let (_, payload) = tree_info
-                        .fetch_ctx_and_payload_at(epoch, &k)
+                        .fetch_ctx_and_payload_at(&k, epoch)
                         .await
                         .ok_or(Error::msg(format!("Node not found for key {:?}", k)))?;
                     Some(payload.hash())
@@ -212,7 +204,7 @@ where
             match node_ctx.left {
                 Some(k) => {
                     let (_, payload) = tree_info
-                        .fetch_ctx_and_payload_at(epoch, &k)
+                        .fetch_ctx_and_payload_at(&k, epoch)
                         .await
                         .ok_or(Error::msg(format!("Node not found for key {:?}", k)))?;
                     Some(payload.hash())
@@ -250,6 +242,71 @@ where
     Ok((tree_path, siblings))
 }
 
+pub(crate) async fn prove_single_row<T: TreeFetcher<RowTreeKey, RowPayload<BlockPrimaryIndex>>>(
+    ctx: &mut TestContext,
+    tree: &T,
+    columns: &TableColumns,
+    primary: BlockPrimaryIndex,
+    row_key: &RowTreeKey,
+    pis: &DynamicCircuitPis,
+    query: &QueryCooking,
+) -> Result<Vec<u8>> {
+    // 1. Get all the cells, including the primary and secondary index
+    // Note we can use the primary as epoch since now epoch == primary in the storage
+    let (row_ctx, row_payload) = tree
+        .fetch_ctx_and_payload_at(row_key, primary as Epoch)
+        .await
+        .expect("cache not full");
+
+    // The API is going to change here, but right now we have to sort all the "rest" cells
+    // by index in the tree, and put the primary and secondary ones in front
+    let rest_cells = columns
+        .non_indexed_columns()
+        .iter()
+        .map(|tc| tc.identifier)
+        .filter_map(|id| {
+            row_payload
+                .cells
+                .find_by_column(id)
+                .map(|info| ColumnCell::new(id, info.value))
+        })
+        .collect::<Vec<_>>();
+
+    let secondary_cell = ColumnCell::new(
+        row_payload.secondary_index_column,
+        row_payload.secondary_index_value(),
+    );
+    let primary_cell = ColumnCell::new(identifier_block_column(), U256::from(primary));
+    let row = RowCells::new(primary_cell, secondary_cell, rest_cells);
+    // 2. create input
+    let input = QueryCircuitInput::new_universal_circuit(
+        &row,
+        &pis.predication_operations,
+        &pis.result,
+        &query.placeholders,
+        row_ctx.is_leaf(),
+        &pis.bounds,
+    )
+    .expect("unable to create universal query circuit inputs");
+    // 3.
run the proof if it has not been run already
+    let proof_key = ProofKey::QueryUniversal((
+        query.query.clone(),
+        query.placeholders.placeholder_values(),
+        primary,
+        row_key.clone(),
+    ));
+    let proof = {
+        info!("Universal query proof RUNNING for {primary} -> {row_key:?} ");
+        let proof = ctx
+            .run_query_proof("querying::universal", GlobalCircuitInput::Query(input))
+            .unwrap_or_else(|_| {
+                panic!("unable to generate universal proof for {primary} -> {row_key:?}")
+            });
+        info!("Universal query proof DONE for {primary} -> {row_key:?} ");
+        ctx.storage.store_proof(proof_key, proof.clone())?;
+        proof
+    };
+    Ok(proof)
+}
+
 /// Cook a query where the number of matching rows is the same as the maximum number of
 /// outputs allowed
 pub(crate) async fn cook_query_with_max_num_matching_rows(
diff --git a/mp2-v1/tests/common/proof_storage.rs b/mp2-v1/tests/common/proof_storage.rs
index fd2a96d30..059d8aa56 100644
--- a/mp2-v1/tests/common/proof_storage.rs
+++ b/mp2-v1/tests/common/proof_storage.rs
@@ -3,7 +3,7 @@ use std::{
     path::{Path, PathBuf},
 };
 
-use super::{context::TestContextConfig, mkdir_all, table::TableID};
+use super::{cases::query::NUM_CHUNKS, context::TestContextConfig, mkdir_all, table::TableID};
 use alloy::primitives::{Address, U256};
 use anyhow::{bail, Context, Result};
 use envconfig::Envconfig;
@@ -68,14 +68,11 @@ pub enum ProofKey {
     #[allow(clippy::upper_case_acronyms)]
     IVC(BlockPrimaryIndex),
     QueryUniversal((QueryID, PlaceholderValues, BlockPrimaryIndex, RowTreeKey)),
-    QueryAggregateRow((QueryID, PlaceholderValues, BlockPrimaryIndex, RowTreeKey)),
-    QueryAggregateIndex((QueryID, PlaceholderValues, BlockPrimaryIndex)),
-    #[cfg(feature = "batching_circuits")]
     QueryAggregate(
        (
            QueryID,
            PlaceholderValues,
-            mp2_v1::query::batching_planner::UTKey<{ super::cases::query::NUM_CHUNKS }>,
+            mp2_v1::query::batching_planner::UTKey<NUM_CHUNKS>,
        ),
    ),
 }
@@ -131,15 +128,6 @@ impl Hash for ProofKey {
                 "query_universal".hash(state);
                 n.hash(state);
             }
-            ProofKey::QueryAggregateRow(n) => {
-                "query_aggregate_row".hash(state);
-                n.hash(state);
-            }
-            ProofKey::QueryAggregateIndex(n) => {
-                "query_aggregate_index".hash(state);
-                n.hash(state);
-            }
-            #[cfg(feature = "batching_circuits")]
             ProofKey::QueryAggregate(n) => {
                 "query_aggregate".hash(state);
                 n.hash(state);
diff --git a/parsil/src/assembler.rs b/parsil/src/assembler.rs
index bb4e22c1d..128385e9a 100644
--- a/parsil/src/assembler.rs
+++ b/parsil/src/assembler.rs
@@ -15,7 +15,7 @@ use sqlparser::ast::{
     SelectItem, SetExpr, TableAlias, TableFactor, UnaryOperator, Value,
 };
 use verifiable_db::query::{
-    aggregation::{QueryBoundSource, QueryBounds},
+    utils::{QueryBoundSource, QueryBounds},
     computational_hash_ids::{AggregationOperation, Operation, PlaceholderIdentifier},
     universal_circuit::universal_circuit_inputs::{
         BasicOperation, InputOperand, OutputItem, Placeholders, ResultStructure,
diff --git a/parsil/src/bracketer.rs b/parsil/src/bracketer.rs
index 6b7358a2c..7a4908716 100644
--- a/parsil/src/bracketer.rs
+++ b/parsil/src/bracketer.rs
@@ -1,6 +1,6 @@
 use alloy::primitives::U256;
 use ryhope::{KEY, PAYLOAD, VALID_FROM, VALID_UNTIL};
-use verifiable_db::query::aggregation::QueryBounds;
+use verifiable_db::query::utils::QueryBounds;
 
 use crate::{symbols::ContextProvider, ParsilSettings};
 
diff --git a/parsil/src/isolator.rs b/parsil/src/isolator.rs
index 66014d903..ca145145d 100644
--- a/parsil/src/isolator.rs
+++ b/parsil/src/isolator.rs
@@ -3,7 +3,7 @@
 use anyhow::*;
 use log::warn;
 use sqlparser::ast::{BinaryOperator, Expr, Query, Select, SelectItem, TableAlias, TableFactor};
-use 
verifiable_db::query::aggregation::QueryBounds; +use verifiable_db::query::utils::QueryBounds; use crate::{ errors::ValidationError, diff --git a/parsil/src/lib.rs b/parsil/src/lib.rs index 499f4b06d..df88aeaa4 100644 --- a/parsil/src/lib.rs +++ b/parsil/src/lib.rs @@ -7,7 +7,7 @@ pub use utils::ParsilSettings; pub use utils::PlaceholderSettings; pub use utils::DEFAULT_MAX_BLOCK_PLACEHOLDER; pub use utils::DEFAULT_MIN_BLOCK_PLACEHOLDER; -use verifiable_db::query::aggregation::QueryBounds; +use verifiable_db::query::utils::QueryBounds; pub mod assembler; pub mod bracketer; diff --git a/parsil/src/queries.rs b/parsil/src/queries.rs index 2efeefc1a..92b6d7b29 100644 --- a/parsil/src/queries.rs +++ b/parsil/src/queries.rs @@ -5,7 +5,7 @@ use crate::{keys_in_index_boundaries, symbols::ContextProvider, ParsilSettings}; use anyhow::*; use ryhope::{tree::sbbst::NodeIdx, Epoch, EPOCH, KEY, VALID_FROM, VALID_UNTIL}; use verifiable_db::query::{ - aggregation::QueryBounds, universal_circuit::universal_circuit_inputs::Placeholders, + utils::QueryBounds, universal_circuit::universal_circuit_inputs::Placeholders, }; /// Return a query read to be injected in the wide lineage computation for the diff --git a/verifiable-db/Cargo.toml b/verifiable-db/Cargo.toml index 512378046..3339e8d36 100644 --- a/verifiable-db/Cargo.toml +++ b/verifiable-db/Cargo.toml @@ -20,6 +20,7 @@ serde.workspace = true mp2_common = { path = "../mp2-common" } recursion_framework = { path = "../recursion-framework" } ryhope = { path = "../ryhope" } +mp2_test = { path = "../mp2-test" } [dev-dependencies] futures.workspace = true @@ -27,8 +28,6 @@ rand.workspace = true serial_test.workspace = true tokio.workspace = true -mp2_test = { path = "../mp2-test" } - [features] original_poseidon = ["mp2_common/original_poseidon"] batching_circuits = [] \ No newline at end of file diff --git a/verifiable-db/src/api.rs b/verifiable-db/src/api.rs index 8088f605e..bd7556d53 100644 --- a/verifiable-db/src/api.rs +++ b/verifiable-db/src/api.rs @@ -1,7 +1,5 @@ //! 
Main APIs and related structures -#[cfg(feature = "batching_circuits")] -use crate::query::batching::circuits::api::Parameters as BatchingQueryParams; use crate::{ block_tree, cells_tree, extraction::{ExtractionPI, ExtractionPIWrap}, diff --git a/verifiable-db/src/lib.rs b/verifiable-db/src/lib.rs index 3a678639e..1cac73092 100644 --- a/verifiable-db/src/lib.rs +++ b/verifiable-db/src/lib.rs @@ -16,5 +16,4 @@ pub mod results_tree; /// Module for the query revelation circuits pub mod revelation; pub mod row_tree; -#[cfg(test)] pub mod test_utils; diff --git a/verifiable-db/src/query/aggregation/child_proven_single_path_node.rs b/verifiable-db/src/query/aggregation/child_proven_single_path_node.rs deleted file mode 100644 index 90f0c5120..000000000 --- a/verifiable-db/src/query/aggregation/child_proven_single_path_node.rs +++ /dev/null @@ -1,366 +0,0 @@ -use std::iter; - -use alloy::primitives::U256; -use anyhow::Result; -use mp2_common::{ - hash::hash_maybe_first, - public_inputs::PublicInputCommon, - serialization::{deserialize, serialize}, - types::CBuilder, - u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256}, - utils::ToTargets, - D, F, -}; -use plonky2::{ - hash::hash_types::{HashOut, HashOutTarget}, - iop::{ - target::{BoolTarget, Target}, - witness::{PartialWitness, WitnessWrite}, - }, - plonk::proof::ProofWithPublicInputsTarget, -}; -use recursion_framework::circuit_builder::CircuitLogicWires; -use serde::{Deserialize, Serialize}; - -use crate::query::public_inputs::PublicInputs; - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct ChildProvenSinglePathNodeWires { - value: UInt256Target, - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - subtree_hash: HashOutTarget, - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - sibling_hash: HashOutTarget, - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - is_left_child: BoolTarget, - unproven_min: UInt256Target, - unproven_max: UInt256Target, - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - is_rows_tree_node: BoolTarget, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct ChildProvenSinglePathNodeCircuit { - /// Value stored in the current node - pub(crate) value: U256, - /// Hash of the row/rows tree stored in the current node - pub(crate) subtree_hash: HashOut, - /// Hash of the sibling of the proven node child - pub(crate) sibling_hash: HashOut, - /// Flag indicating whether the proven child is the left child or the right one - pub(crate) is_left_child: bool, - /// Minimum value of the indexed column to be employed to compute the hash of the current node - pub(crate) unproven_min: U256, - /// Maximum value of the indexed column to be employed to compute the hash of the current node - pub(crate) unproven_max: U256, - /// Boolean flag specifying whether the proof is being generated for a node - /// in a rows tree of for a node in the index tree - pub(crate) is_rows_tree_node: bool, -} - -impl ChildProvenSinglePathNodeCircuit { - pub fn build( - b: &mut CBuilder, - child_proof: &PublicInputs, - ) -> ChildProvenSinglePathNodeWires { - let is_rows_tree_node = b.add_virtual_bool_target_safe(); - let is_left_child = b.add_virtual_bool_target_unsafe(); - let value = b.add_virtual_u256(); - let subtree_hash = b.add_virtual_hash(); - let sibling_hash = b.add_virtual_hash(); - let unproven_min = b.add_virtual_u256_unsafe(); - let unproven_max = b.add_virtual_u256_unsafe(); - - let node_min = b.select_u256( - 
is_left_child, - &child_proof.min_value_target(), - &unproven_min, - ); - let node_max = b.select_u256( - is_left_child, - &unproven_max, - &child_proof.max_value_target(), - ); - let column_id = b.select( - is_rows_tree_node, - child_proof.index_ids_target()[1], - child_proof.index_ids_target()[0], - ); - // Compute the node hash: - // node_hash = H(left_child_hash||right_child_hash||node_min||node_max||column_id||value||subtree_hash) - let rest: Vec<_> = node_min - .to_targets() - .into_iter() - .chain(node_max.to_targets()) - .chain(iter::once(column_id)) - .chain(value.to_targets()) - .chain(subtree_hash.elements) - .collect(); - - let node_hash = hash_maybe_first( - b, - is_left_child, - sibling_hash.elements, - child_proof.tree_hash_target().elements, - &rest, - ); - - // if is_left_child: - // value > child_proof.max_query - // else: - // value < child_proof.min_query - let is_greater_than_max = b.is_greater_than_u256(&value, &child_proof.max_query_target()); - let is_less_than_min = b.is_less_than_u256(&value, &child_proof.min_query_target()); - let condition = b.select( - is_left_child, - is_greater_than_max.target, - is_less_than_min.target, - ); - let ttrue = b._true(); - b.connect(condition, ttrue.target); - - // Register the public inputs. - PublicInputs::<_, MAX_NUM_RESULTS>::new( - &node_hash.to_targets(), - child_proof.to_values_raw(), - &[child_proof.num_matching_rows_target()], - child_proof.to_ops_raw(), - child_proof.to_index_value_raw(), - &node_min.to_targets(), - &node_max.to_targets(), - child_proof.to_index_ids_raw(), - child_proof.to_min_query_raw(), - child_proof.to_max_query_raw(), - &[*child_proof.to_overflow_raw()], - child_proof.to_computational_hash_raw(), - child_proof.to_placeholder_hash_raw(), - ) - .register(b); - - ChildProvenSinglePathNodeWires { - value, - subtree_hash, - sibling_hash, - is_left_child, - unproven_min, - unproven_max, - is_rows_tree_node, - } - } - - fn assign( - &self, - pw: &mut PartialWitness, - wires: &ChildProvenSinglePathNodeWires, - ) { - pw.set_u256_target(&wires.value, self.value); - pw.set_hash_target(wires.subtree_hash, self.subtree_hash); - pw.set_hash_target(wires.sibling_hash, self.sibling_hash); - pw.set_bool_target(wires.is_left_child, self.is_left_child); - pw.set_u256_target(&wires.unproven_min, self.unproven_min); - pw.set_u256_target(&wires.unproven_max, self.unproven_max); - pw.set_bool_target(wires.is_rows_tree_node, self.is_rows_tree_node); - } -} - -pub(crate) const NUM_VERIFIED_PROOFS: usize = 1; - -impl CircuitLogicWires - for ChildProvenSinglePathNodeWires -{ - type CircuitBuilderParams = (); - type Inputs = ChildProvenSinglePathNodeCircuit; - - const NUM_PUBLIC_INPUTS: usize = PublicInputs::::total_len(); - - fn circuit_logic( - builder: &mut CBuilder, - verified_proofs: [&ProofWithPublicInputsTarget; NUM_VERIFIED_PROOFS], - _builder_parameters: Self::CircuitBuilderParams, - ) -> Self { - let child_proof = PublicInputs::from_slice(&verified_proofs[0].public_inputs); - - Self::Inputs::build(builder, &child_proof) - } - - fn assign_input(&self, inputs: Self::Inputs, pw: &mut PartialWitness) -> Result<()> { - inputs.assign(pw, self); - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - query::pi_len, - test_utils::{random_aggregation_operations, random_aggregation_public_inputs}, - }; - use mp2_common::{poseidon::H, utils::ToFields, C, D}; - use mp2_test::{ - circuit::{run_circuit, UserCircuit}, - utils::gen_random_field_hash, - }; - use plonky2::{iop::witness::WitnessWrite, 
plonk::config::Hasher}; - use rand::{thread_rng, Rng}; - - const MAX_NUM_RESULTS: usize = 20; - - #[derive(Clone, Debug)] - struct TestChildProvenSinglePathNodeCircuit<'a> { - c: ChildProvenSinglePathNodeCircuit, - child_proof: &'a [F], - } - - impl UserCircuit for TestChildProvenSinglePathNodeCircuit<'_> { - type Wires = (ChildProvenSinglePathNodeWires, Vec); - - fn build(b: &mut CBuilder) -> Self::Wires { - let child_proof = b - .add_virtual_target_arr::<{ pi_len::() }>() - .to_vec(); - let pi = PublicInputs::::from_slice(&child_proof); - - let wires = ChildProvenSinglePathNodeCircuit::build(b, &pi); - - (wires, child_proof) - } - - fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { - self.c.assign(pw, &wires.0); - pw.set_target_arr(&wires.1, self.child_proof); - } - } - - fn test_child_proven_single_path_node_circuit(is_rows_tree_node: bool, is_left_child: bool) { - // Generate the random operations. - let ops: [_; MAX_NUM_RESULTS] = random_aggregation_operations(); - - // Build the child proof. - let [child_proof] = random_aggregation_public_inputs(&ops); - let child_pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&child_proof); - - let index_ids = child_pi.index_ids(); - let index_value = child_pi.index_value(); - let min_query = child_pi.min_query_value(); - let max_query = child_pi.max_query_value(); - - // Construct the witness. - let mut rng = thread_rng(); - let mut value = U256::from_limbs(rng.gen::<[u64; 4]>()); - let subtree_hash = gen_random_field_hash(); - let sibling_hash = gen_random_field_hash(); - let unproven_min = index_value - .checked_sub(U256::from(100)) - .unwrap_or(index_value); - let unproven_max = index_value - .checked_add(U256::from(100)) - .unwrap_or(index_value); - - if is_left_child { - while value <= max_query { - value = U256::from_limbs(rng.gen::<[u64; 4]>()); - } - } else { - while value >= min_query { - value = U256::from_limbs(rng.gen::<[u64; 4]>()); - } - } - - // Construct the test circuit. - let test_circuit = TestChildProvenSinglePathNodeCircuit { - c: ChildProvenSinglePathNodeCircuit { - value, - subtree_hash, - sibling_hash, - is_left_child, - unproven_min, - unproven_max, - is_rows_tree_node, - }, - child_proof: &child_proof, - }; - - // Proof for the test circuit. - let proof = run_circuit::(test_circuit); - let pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&proof.public_inputs); - - let [node_min, node_max] = if is_left_child { - [child_pi.min_value(), unproven_max] - } else { - [unproven_min, child_pi.max_value()] - }; - // Check the public inputs. 
- // Tree hash - { - let column_id = if is_rows_tree_node { - index_ids[1] - } else { - index_ids[0] - }; - - let child_hash = child_pi.tree_hash(); - let [left_child_hash, right_child_hash] = if is_left_child { - [child_hash, sibling_hash] - } else { - [sibling_hash, child_hash] - }; - - // H(left_child_hash||right_child_hash||node_min||node_max||column_id||value||subtree_hash) - let input: Vec<_> = left_child_hash - .to_fields() - .into_iter() - .chain(right_child_hash.to_fields()) - .chain(node_min.to_fields()) - .chain(node_max.to_fields()) - .chain(iter::once(column_id)) - .chain(value.to_fields()) - .chain(subtree_hash.to_fields()) - .collect(); - let exp_hash = H::hash_no_pad(&input); - - assert_eq!(pi.tree_hash(), exp_hash); - } - // Output values - assert_eq!(pi.to_values_raw(), child_pi.to_values_raw()); - // Count - assert_eq!(pi.num_matching_rows(), child_pi.num_matching_rows()); - // Operation IDs - assert_eq!(pi.operation_ids(), child_pi.operation_ids()); - // Index value - assert_eq!(pi.index_value(), index_value); - // Minimum value - assert_eq!(pi.min_value(), node_min); - // Maximum value - assert_eq!(pi.max_value(), node_max); - // Index IDs - assert_eq!(pi.index_ids(), index_ids); - // Minimum query - assert_eq!(pi.min_query_value(), min_query); - // Maximum query - assert_eq!(pi.max_query_value(), max_query); - // Overflow flag - assert_eq!(pi.overflow_flag(), child_pi.overflow_flag()); - // Computational hash - assert_eq!(pi.computational_hash(), child_pi.computational_hash()); - // Placeholder hash - assert_eq!(pi.placeholder_hash(), child_pi.placeholder_hash()); - } - - #[test] - fn test_child_proven_node_for_row_node_with_left_child() { - test_child_proven_single_path_node_circuit(true, true); - } - #[test] - fn test_child_proven_node_for_row_node_with_right_child() { - test_child_proven_single_path_node_circuit(true, false); - } - #[test] - fn test_child_proven_node_for_index_node_with_left_child() { - test_child_proven_single_path_node_circuit(false, true); - } - #[test] - fn test_child_proven_node_for_index_node_with_right_child() { - test_child_proven_single_path_node_circuit(false, false); - } -} diff --git a/verifiable-db/src/query/aggregation/embedded_tree_proven_single_path_node.rs b/verifiable-db/src/query/aggregation/embedded_tree_proven_single_path_node.rs deleted file mode 100644 index 94ce04cec..000000000 --- a/verifiable-db/src/query/aggregation/embedded_tree_proven_single_path_node.rs +++ /dev/null @@ -1,572 +0,0 @@ -use std::iter; - -use alloy::primitives::U256; -use anyhow::Result; -use mp2_common::{ - poseidon::{empty_poseidon_hash, H}, - public_inputs::PublicInputCommon, - serialization::{deserialize, deserialize_array, serialize, serialize_array}, - types::CBuilder, - u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256}, - utils::{HashBuilder, ToTargets}, - D, F, -}; -use plonky2::{ - hash::hash_types::{HashOut, HashOutTarget}, - iop::{ - target::{BoolTarget, Target}, - witness::{PartialWitness, WitnessWrite}, - }, - plonk::proof::ProofWithPublicInputsTarget, -}; -use recursion_framework::circuit_builder::CircuitLogicWires; -use serde::{Deserialize, Serialize}; -use std::array; - -use crate::query::public_inputs::PublicInputs; - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct EmbeddedTreeProvenSinglePathNodeWires { - left_child_min: UInt256Target, - left_child_max: UInt256Target, - left_child_value: UInt256Target, - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - left_tree_hash: HashOutTarget, - 
#[serde( - serialize_with = "serialize_array", - deserialize_with = "deserialize_array" - )] - left_grand_children: [HashOutTarget; 2], - right_child_min: UInt256Target, - right_child_max: UInt256Target, - right_child_value: UInt256Target, - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - right_tree_hash: HashOutTarget, - #[serde( - serialize_with = "serialize_array", - deserialize_with = "deserialize_array" - )] - right_grand_children: [HashOutTarget; 2], - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - left_child_exists: BoolTarget, - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - right_child_exists: BoolTarget, - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - is_rows_tree_node: BoolTarget, - min_query: UInt256Target, - max_query: UInt256Target, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct EmbeddedTreeProvenSinglePathNodeCircuit { - /// Minimum value associated to the left child - pub(crate) left_child_min: U256, - /// Maximum value associated to the left child - pub(crate) left_child_max: U256, - /// Value stored in the left child - pub(crate) left_child_value: U256, - /// Hashes of the row/rows tree stored in the left child - pub(crate) left_tree_hash: HashOut, - /// Hashes of the children nodes of the left child - pub(crate) left_grand_children: [HashOut; 2], - /// Minimum value associated to the right child - pub(crate) right_child_min: U256, - /// Maximum value associated to the right child - pub(crate) right_child_max: U256, - /// Value stored in the right child - pub(crate) right_child_value: U256, - /// Hashes of the row/rows tree stored in the right child - pub(crate) right_tree_hash: HashOut, - /// Hashes of the children nodes of the right child - pub(crate) right_grand_children: [HashOut; 2], - /// Boolean flag specifying whether there is a left child for the current node - pub(crate) left_child_exists: bool, - /// Boolean flag specifying whether there is a right child for the current node - pub(crate) right_child_exists: bool, - /// Boolean flag specifying whether the proof is being generated - /// for a node in a rows tree or for a node in the index tree - pub(crate) is_rows_tree_node: bool, - /// Minimum range bound specified in the query for the indexed column - pub(crate) min_query: U256, - /// Maximum range bound specified in the query for the indexed column - pub(crate) max_query: U256, -} - -impl EmbeddedTreeProvenSinglePathNodeCircuit { - pub fn build( - b: &mut CBuilder, - embedded_tree_proof: &PublicInputs, - ) -> EmbeddedTreeProvenSinglePathNodeWires { - let empty_hash = b.constant_hash(*empty_poseidon_hash()); - - let [left_child_min, left_child_max, left_child_value, right_child_min, right_child_max, right_child_value, min_query, max_query] = - array::from_fn(|_| b.add_virtual_u256_unsafe()); - let [left_tree_hash, right_tree_hash] = array::from_fn(|_| b.add_virtual_hash()); - let left_grand_children: [HashOutTarget; 2] = array::from_fn(|_| b.add_virtual_hash()); - let right_grand_children: [HashOutTarget; 2] = array::from_fn(|_| b.add_virtual_hash()); - let [left_child_exists, right_child_exists, is_rows_tree_node] = - array::from_fn(|_| b.add_virtual_bool_target_safe()); - - let index_value = embedded_tree_proof.index_value_target(); - - let column_id = b.select( - is_rows_tree_node, - embedded_tree_proof.index_ids_target()[1], - embedded_tree_proof.index_ids_target()[0], - ); - - let node_value = b.select_u256( - 
is_rows_tree_node, - &embedded_tree_proof.min_value_target(), - &index_value, - ); - - // H(left_grandchild_1||left_grandchild_2||left_min||left_max||column_id||left_value||left_tree_hash) - let left_child_inputs = left_grand_children[0] - .to_targets() - .into_iter() - .chain(left_grand_children[1].to_targets()) - .chain(left_child_min.to_targets()) - .chain(left_child_max.to_targets()) - .chain(iter::once(column_id)) - .chain(left_child_value.to_targets()) - .chain(left_tree_hash.to_targets()) - .collect(); - let left_hash_exists = b.hash_n_to_hash_no_pad::(left_child_inputs); - let left_child_hash = b.select_hash(left_child_exists, &left_hash_exists, &empty_hash); - - // H(right_grandchild_1||right_grandchild_2||right_min||right_max||column_id||right_value||right_tree_hash) - let right_child_inputs = right_grand_children[0] - .to_targets() - .into_iter() - .chain(right_grand_children[1].to_targets()) - .chain(right_child_min.to_targets()) - .chain(right_child_max.to_targets()) - .chain(iter::once(column_id)) - .chain(right_child_value.to_targets()) - .chain(right_tree_hash.to_targets()) - .collect(); - let right_hash_exists = b.hash_n_to_hash_no_pad::(right_child_inputs); - let right_child_hash = b.select_hash(right_child_exists, &right_hash_exists, &empty_hash); - - let node_min = b.select_u256(left_child_exists, &left_child_min, &node_value); - let node_max = b.select_u256(right_child_exists, &right_child_max, &node_value); - - // If the current node is not a rows tree, we need to ensure that - // the value of the primary indexed column for all the records stored in the rows tree - // found in this node is within the range specified by the query: - // min_i1 <= index_value <= max_i1 - // -> NOT((index_value < min_i1) OR (index_value > max_i1)) - let is_less_than = b.is_less_than_u256(&index_value, &min_query); - let is_greater_than = b.is_greater_than_u256(&index_value, &max_query); - let is_out_of_range = b.or(is_less_than, is_greater_than); - let is_within_range = b.not(is_out_of_range); - - // If the current node is in a rows tree, we need to ensure that - // the query bounds exposed as public inputs are the same as the one exposed - // by the proof for the row associated to the current node - let is_min_same = b.is_equal_u256(&embedded_tree_proof.min_query_target(), &min_query); - let is_max_same = b.is_equal_u256(&embedded_tree_proof.max_query_target(), &max_query); - let are_query_bounds_same = b.and(is_min_same, is_max_same); - - // if is_rows_tree_node: - // embedded_tree_proof.min_query == min_query && - // embedded_tree_proof.max_query == max_query - // else if not is_rows_tree_node: - // min_query <= index_value <= max_query - let rows_tree_condition = b.select( - is_rows_tree_node, - are_query_bounds_same.target, - is_within_range.target, - ); - let ttrue = b._true(); - b.connect(rows_tree_condition, ttrue.target); - - // Enforce that the subtree rooted in the left child contains - // only nodes outside of the range specified by the query - let is_less_than_min = b.is_less_than_u256(&left_child_max, &min_query); - let left_condition = b.and(left_child_exists, is_less_than_min); - // (left_child_exists AND (left_child_max < min_query)) == left_child_exists - b.connect(left_condition.target, left_child_exists.target); - - // Enforce that the subtree rooted in the right child contains - // only nodes outside of the range specified by the query - let is_greater_than_max = b.is_greater_than_u256(&right_child_min, &max_query); - let right_condition = b.and(right_child_exists, 
is_greater_than_max);
-        // (right_child_exists AND (right_child_min > max_query)) == right_child_exists
-        b.connect(right_condition.target, right_child_exists.target);
-
-        // H(left_child_hash||right_child_hash||node_min||node_max||column_id||node_value||p.H)
-        let node_hash_inputs = left_child_hash
-            .elements
-            .into_iter()
-            .chain(right_child_hash.elements)
-            .chain(node_min.to_targets())
-            .chain(node_max.to_targets())
-            .chain(iter::once(column_id))
-            .chain(node_value.to_targets())
-            .chain(embedded_tree_proof.tree_hash_target().to_targets())
-            .collect();
-        let node_hash = b.hash_n_to_hash_no_pad::<H>(node_hash_inputs);
-
-        // Register the public inputs.
-        PublicInputs::<_, MAX_NUM_RESULTS>::new(
-            &node_hash.to_targets(),
-            embedded_tree_proof.to_values_raw(),
-            &[embedded_tree_proof.num_matching_rows_target()],
-            embedded_tree_proof.to_ops_raw(),
-            embedded_tree_proof.to_index_value_raw(),
-            &node_min.to_targets(),
-            &node_max.to_targets(),
-            embedded_tree_proof.to_index_ids_raw(),
-            &min_query.to_targets(),
-            &max_query.to_targets(),
-            &[*embedded_tree_proof.to_overflow_raw()],
-            embedded_tree_proof.to_computational_hash_raw(),
-            embedded_tree_proof.to_placeholder_hash_raw(),
-        )
-        .register(b);
-
-        EmbeddedTreeProvenSinglePathNodeWires {
-            left_child_min,
-            left_child_max,
-            left_child_value,
-            left_tree_hash,
-            left_grand_children,
-            right_child_min,
-            right_child_max,
-            right_child_value,
-            right_tree_hash,
-            right_grand_children,
-            left_child_exists,
-            right_child_exists,
-            is_rows_tree_node,
-            min_query,
-            max_query,
-        }
-    }
-
-    fn assign<const MAX_NUM_RESULTS: usize>(
-        &self,
-        pw: &mut PartialWitness<F>,
-        wires: &EmbeddedTreeProvenSinglePathNodeWires<MAX_NUM_RESULTS>,
-    ) {
-        [
-            (&wires.left_child_min, self.left_child_min),
-            (&wires.left_child_max, self.left_child_max),
-            (&wires.left_child_value, self.left_child_value),
-            (&wires.right_child_min, self.right_child_min),
-            (&wires.right_child_max, self.right_child_max),
-            (&wires.right_child_value, self.right_child_value),
-            (&wires.min_query, self.min_query),
-            (&wires.max_query, self.max_query),
-        ]
-        .iter()
-        .for_each(|(t, v)| pw.set_u256_target(t, *v));
-        [
-            (wires.left_tree_hash, self.left_tree_hash),
-            (wires.right_tree_hash, self.right_tree_hash),
-        ]
-        .iter()
-        .for_each(|(t, v)| pw.set_hash_target(*t, *v));
-        wires
-            .left_grand_children
-            .iter()
-            .zip(self.left_grand_children)
-            .for_each(|(t, v)| pw.set_hash_target(*t, v));
-        wires
-            .right_grand_children
-            .iter()
-            .zip(self.right_grand_children)
-            .for_each(|(t, v)| pw.set_hash_target(*t, v));
-        [
-            (wires.left_child_exists, self.left_child_exists),
-            (wires.right_child_exists, self.right_child_exists),
-            (wires.is_rows_tree_node, self.is_rows_tree_node),
-        ]
-        .iter()
-        .for_each(|(t, v)| pw.set_bool_target(*t, *v));
-    }
-}
-
-pub(crate) const NUM_VERIFIED_PROOFS: usize = 1;
-
-impl<const MAX_NUM_RESULTS: usize> CircuitLogicWires<F, D, NUM_VERIFIED_PROOFS>
-    for EmbeddedTreeProvenSinglePathNodeWires<MAX_NUM_RESULTS>
-{
-    type CircuitBuilderParams = ();
-    type Inputs = EmbeddedTreeProvenSinglePathNodeCircuit;
-
-    const NUM_PUBLIC_INPUTS: usize = PublicInputs::<F, MAX_NUM_RESULTS>::total_len();
-
-    fn circuit_logic(
-        builder: &mut CBuilder,
-        verified_proofs: [&ProofWithPublicInputsTarget<D>; NUM_VERIFIED_PROOFS],
-        _builder_parameters: Self::CircuitBuilderParams,
-    ) -> Self {
-        let child_proof = PublicInputs::from_slice(&verified_proofs[0].public_inputs);
-
-        Self::Inputs::build(builder, &child_proof)
-    }
-
-    fn assign_input(&self, inputs: Self::Inputs, pw: &mut PartialWitness<F>) -> Result<()> {
-        inputs.assign(pw, self);
-        Ok(())
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use mp2_common::{utils::ToFields, C};
-    use mp2_test::{
-        circuit::{run_circuit, UserCircuit},
-        utils::gen_random_field_hash,
-    };
-    use plonky2::plonk::config::Hasher;
-    use rand::{thread_rng, Rng};
-
-    use crate::{
-        query::pi_len,
-        test_utils::{random_aggregation_operations, random_aggregation_public_inputs},
-    };
-
-    const MAX_NUM_RESULTS: usize = 20;
-
-    #[derive(Clone, Debug)]
-    struct TestEmbeddedTreeProvenSinglePathNodeCircuit<'a> {
-        c: EmbeddedTreeProvenSinglePathNodeCircuit,
-        embedded_tree_proof: &'a [F],
-    }
-
-    impl UserCircuit<F, D> for TestEmbeddedTreeProvenSinglePathNodeCircuit<'_> {
-        type Wires = (
-            EmbeddedTreeProvenSinglePathNodeWires<MAX_NUM_RESULTS>,
-            Vec<Target>,
-        );
-
-        fn build(b: &mut CBuilder) -> Self::Wires {
-            let embedded_tree_proof = b
-                .add_virtual_target_arr::<{ pi_len::<MAX_NUM_RESULTS>() }>()
-                .to_vec();
-            let pi = PublicInputs::<Target, MAX_NUM_RESULTS>::from_slice(&embedded_tree_proof);
-
-            let wires = EmbeddedTreeProvenSinglePathNodeCircuit::build(b, &pi);
-
-            (wires, embedded_tree_proof)
-        }
-
-        fn prove(&self, pw: &mut PartialWitness<F>, wires: &Self::Wires) {
-            self.c.assign(pw, &wires.0);
-            pw.set_target_arr(&wires.1, self.embedded_tree_proof);
-        }
-    }
-
-    fn test_embedded_tree_proven_single_path_node_circuit(
-        is_rows_tree_node: bool,
-        left_child_exists: bool,
-        right_child_exists: bool,
-    ) {
-        // Generate the random operations.
-        let ops: [_; MAX_NUM_RESULTS] = random_aggregation_operations();
-
-        // Build the embedded tree proof.
-        let [embedded_tree_proof] = random_aggregation_public_inputs(&ops);
-        let embedded_tree_pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&embedded_tree_proof);
-
-        let index_ids = embedded_tree_pi.index_ids();
-        let index_value = embedded_tree_pi.index_value();
-
-        // Construct the witness.
-        let mut rng = thread_rng();
-        let [left_child_min, mut left_child_max, left_child_value, mut right_child_min, right_child_max, right_child_value] =
-            array::from_fn(|_| U256::from_limbs(rng.gen::<[u64; 4]>()));
-        let left_tree_hash = gen_random_field_hash();
-        let left_grand_children: [HashOut<F>; 2] = array::from_fn(|_| gen_random_field_hash());
-        let right_tree_hash = gen_random_field_hash();
-        let right_grand_children: [HashOut<F>; 2] = array::from_fn(|_| gen_random_field_hash());
-        let mut min_query = U256::from(100);
-        let mut max_query = U256::from(200);
-
-        if is_rows_tree_node {
-            min_query = embedded_tree_pi.min_query_value();
-            max_query = embedded_tree_pi.max_query_value();
-        } else {
-            if min_query > index_value {
-                min_query = index_value - U256::from(1);
-            }
-            if max_query < index_value {
-                max_query = index_value + U256::from(1);
-            }
-        }
-
-        if left_child_exists {
-            left_child_max = min_query - U256::from(1);
-        }
-
-        if right_child_exists {
-            right_child_min = max_query + U256::from(1);
-        }
-
-        // Construct the test circuit.
-        let test_circuit = TestEmbeddedTreeProvenSinglePathNodeCircuit {
-            c: EmbeddedTreeProvenSinglePathNodeCircuit {
-                left_child_min,
-                left_child_max,
-                left_child_value,
-                left_tree_hash,
-                left_grand_children,
-                right_child_min,
-                right_child_max,
-                right_child_value,
-                right_tree_hash,
-                right_grand_children,
-                left_child_exists,
-                right_child_exists,
-                is_rows_tree_node,
-                min_query,
-                max_query,
-            },
-            embedded_tree_proof: &embedded_tree_proof,
-        };
-
-        // Prove for the test circuit.
- let proof = run_circuit::(test_circuit); - let pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&proof.public_inputs); - - let node_value = if is_rows_tree_node { - embedded_tree_pi.min_value() - } else { - index_value - }; - let node_min = if left_child_exists { - left_child_min - } else { - node_value - }; - let node_max = if right_child_exists { - right_child_max - } else { - node_value - }; - // Check the public inputs. - // Tree hash - { - let column_id = if is_rows_tree_node { - index_ids[1] - } else { - index_ids[0] - }; - - let empty_hash = empty_poseidon_hash(); - // H(left_grandchild_1||left_grandchild_2||left_min||left_max||column_id||left_value||left_subtree_hash) - let left_child_inputs: Vec<_> = left_grand_children[0] - .to_fields() - .into_iter() - .chain(left_grand_children[1].to_fields()) - .chain(left_child_min.to_fields()) - .chain(left_child_max.to_fields()) - .chain(iter::once(column_id)) - .chain(left_child_value.to_fields()) - .chain(left_tree_hash.to_fields()) - .collect(); - let left_hash_exists = H::hash_no_pad(&left_child_inputs); - let left_child_hash = if left_child_exists { - left_hash_exists - } else { - *empty_hash - }; - // H(right_grandchild_1||right_grandchild_2||right_min||right_max||column_id||right_value||right_subtree_hash) - let right_child_inputs: Vec<_> = right_grand_children[0] - .to_fields() - .into_iter() - .chain(right_grand_children[1].to_fields()) - .chain(right_child_min.to_fields()) - .chain(right_child_max.to_fields()) - .chain(iter::once(column_id)) - .chain(right_child_value.to_fields()) - .chain(right_tree_hash.to_fields()) - .collect(); - let right_hash_exists = H::hash_no_pad(&right_child_inputs); - let right_child_hash = if right_child_exists { - right_hash_exists - } else { - *empty_hash - }; - - let node_hash_input: Vec<_> = left_child_hash - .to_fields() - .into_iter() - .chain(right_child_hash.to_fields()) - .chain(node_min.to_fields()) - .chain(node_max.to_fields()) - .chain(iter::once(column_id)) - .chain(node_value.to_fields()) - .chain(embedded_tree_pi.tree_hash().to_fields()) - .collect(); - let exp_hash = H::hash_no_pad(&node_hash_input); - - assert_eq!(pi.tree_hash(), exp_hash); - } - // Output values - assert_eq!(pi.to_values_raw(), embedded_tree_pi.to_values_raw()); - // Count - assert_eq!(pi.num_matching_rows(), embedded_tree_pi.num_matching_rows()); - // Operation IDs - assert_eq!(pi.operation_ids(), embedded_tree_pi.operation_ids()); - // Index value - assert_eq!(pi.index_value(), index_value); - // Minimum value - assert_eq!(pi.min_value(), node_min); - // Maximum value - assert_eq!(pi.max_value(), node_max); - // Index IDs - assert_eq!(pi.index_ids(), index_ids); - // Minimum query - assert_eq!(pi.min_query_value(), min_query); - // Maximum query - assert_eq!(pi.max_query_value(), max_query); - // Overflow flag - assert_eq!(pi.overflow_flag(), embedded_tree_pi.overflow_flag()); - // Computational hash - assert_eq!( - pi.computational_hash(), - embedded_tree_pi.computational_hash() - ); - // Placeholder hash - assert_eq!(pi.placeholder_hash(), embedded_tree_pi.placeholder_hash()); - } - - #[test] - fn test_embedded_tree_proven_node_for_row_node_with_no_child() { - test_embedded_tree_proven_single_path_node_circuit(true, false, false); - } - #[test] - fn test_embedded_tree_proven_node_for_row_node_with_left_child() { - test_embedded_tree_proven_single_path_node_circuit(true, true, false); - } - #[test] - fn test_embedded_tree_proven_node_for_row_node_with_right_child() { - 
test_embedded_tree_proven_single_path_node_circuit(true, false, true); - } - #[test] - fn test_embedded_tree_proven_node_for_row_node_with_both_children() { - test_embedded_tree_proven_single_path_node_circuit(true, true, true); - } - #[test] - fn test_embedded_tree_proven_node_for_index_node_with_no_child() { - test_embedded_tree_proven_single_path_node_circuit(false, false, false); - } - #[test] - fn test_embedded_tree_proven_node_for_index_node_with_left_child() { - test_embedded_tree_proven_single_path_node_circuit(false, true, false); - } - #[test] - fn test_embedded_tree_proven_node_for_index_node_with_right_child() { - test_embedded_tree_proven_single_path_node_circuit(false, false, true); - } - #[test] - fn test_embedded_tree_proven_node_for_index_node_with_both_children() { - test_embedded_tree_proven_single_path_node_circuit(false, true, true); - } -} diff --git a/verifiable-db/src/query/aggregation/full_node_index_leaf.rs b/verifiable-db/src/query/aggregation/full_node_index_leaf.rs deleted file mode 100644 index ffe02d5aa..000000000 --- a/verifiable-db/src/query/aggregation/full_node_index_leaf.rs +++ /dev/null @@ -1,246 +0,0 @@ -//! Module handling the leaf full node of the index tree for query aggregation circuits - -use crate::query::public_inputs::PublicInputs; -use alloy::primitives::U256; -use anyhow::Result; -use mp2_common::{ - poseidon::{empty_poseidon_hash, H}, - public_inputs::PublicInputCommon, - types::CBuilder, - u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256}, - utils::ToTargets, - D, F, -}; -use plonky2::{ - iop::{target::Target, witness::PartialWitness}, - plonk::proof::ProofWithPublicInputsTarget, -}; -use recursion_framework::circuit_builder::CircuitLogicWires; -use serde::{Deserialize, Serialize}; -use std::iter; - -/// Leaf wires -/// The constant generic parameter is only used for impl `CircuitLogicWires`. 
-#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct FullNodeIndexLeafWires { - min_query: UInt256Target, - max_query: UInt256Target, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct FullNodeIndexLeafCircuit { - /// Minimum range bound specified in the query for the indexed column - pub(crate) min_query: U256, - /// Maximum range bound specified in the query for the indexed column - pub(crate) max_query: U256, -} - -impl FullNodeIndexLeafCircuit { - pub fn build( - b: &mut CBuilder, - subtree_proof: &PublicInputs, - ) -> FullNodeIndexLeafWires { - let ttrue = b._true(); - let empty_hash = b.constant_hash(*empty_poseidon_hash()); - let empty_hash_targets = empty_hash.to_targets(); - - let [min_query, max_query] = [0; 2].map(|_| b.add_virtual_u256()); - - let index_ids = subtree_proof.index_ids_target(); - let index_value = subtree_proof.index_value_target(); - let index_value_targets = subtree_proof.to_index_value_raw(); - - // Ensure the value of the indexed column for all the records stored in the - // subtree found in this node is within the range specified by the query: - // p.I >= MIN_query AND p.I <= MAX_query - let is_not_less_than_min = b.is_less_or_equal_than_u256(&min_query, &index_value); - let is_not_greater_than_max = b.is_less_or_equal_than_u256(&index_value, &max_query); - let is_in_range = b.and(is_not_less_than_min, is_not_greater_than_max); - b.connect(is_in_range.target, ttrue.target); - - // Compute the node hash: - // node_hash = H(H("") || H("") || p.I || p.I || p.index_ids[0] || p.I || p.H)) - let inputs = empty_hash_targets - .iter() - .chain(empty_hash_targets.iter()) - .chain(index_value_targets) - .chain(index_value_targets) - .chain(iter::once(&index_ids[0])) - .chain(index_value_targets) - .cloned() - .chain(subtree_proof.tree_hash_target().to_targets()) - .collect(); - let node_hash = b.hash_n_to_hash_no_pad::(inputs); - - // Register the public inputs. - PublicInputs::<_, MAX_NUM_RESULTS>::new( - &node_hash.to_targets(), - subtree_proof.to_values_raw(), - &[subtree_proof.num_matching_rows_target()], - subtree_proof.to_ops_raw(), - index_value_targets, - index_value_targets, - index_value_targets, - subtree_proof.to_index_ids_raw(), - &min_query.to_targets(), - &max_query.to_targets(), - &[*subtree_proof.to_overflow_raw()], - subtree_proof.to_computational_hash_raw(), - subtree_proof.to_placeholder_hash_raw(), - ) - .register(b); - - FullNodeIndexLeafWires { - min_query, - max_query, - } - } - - fn assign(&self, pw: &mut PartialWitness, wires: &FullNodeIndexLeafWires) { - pw.set_u256_target(&wires.min_query, self.min_query); - pw.set_u256_target(&wires.max_query, self.max_query); - } -} - -/// Subtree proof number = 1, child proof number = 0 -pub(crate) const NUM_VERIFIED_PROOFS: usize = 1; - -impl CircuitLogicWires - for FullNodeIndexLeafWires -{ - type CircuitBuilderParams = (); - type Inputs = FullNodeIndexLeafCircuit; - - const NUM_PUBLIC_INPUTS: usize = PublicInputs::::total_len(); - - fn circuit_logic( - builder: &mut CBuilder, - verified_proofs: [&ProofWithPublicInputsTarget; NUM_VERIFIED_PROOFS], - _builder_parameters: Self::CircuitBuilderParams, - ) -> Self { - // The first one is the subtree proof. 
- let subtree_proof = PublicInputs::from_slice(&verified_proofs[0].public_inputs); - - Self::Inputs::build(builder, &subtree_proof) - } - - fn assign_input(&self, inputs: Self::Inputs, pw: &mut PartialWitness) -> Result<()> { - inputs.assign(pw, self); - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - query::{aggregation::utils::tests::unify_subtree_proof, pi_len}, - test_utils::{random_aggregation_operations, random_aggregation_public_inputs}, - }; - use mp2_common::{utils::ToFields, C}; - use mp2_test::circuit::{run_circuit, UserCircuit}; - use plonky2::{iop::witness::WitnessWrite, plonk::config::Hasher}; - - const MAX_NUM_RESULTS: usize = 20; - - #[derive(Clone, Debug)] - struct TestFullNodeIndexLeafCircuit<'a> { - c: FullNodeIndexLeafCircuit, - subtree_proof: &'a [F], - } - - impl UserCircuit for TestFullNodeIndexLeafCircuit<'_> { - // Circuit wires + subtree proof - type Wires = (FullNodeIndexLeafWires, Vec); - - fn build(b: &mut CBuilder) -> Self::Wires { - let subtree_proof = b - .add_virtual_target_arr::<{ pi_len::() }>() - .to_vec(); - let subtree_pi = PublicInputs::::from_slice(&subtree_proof); - - let wires = FullNodeIndexLeafCircuit::build(b, &subtree_pi); - - (wires, subtree_proof) - } - - fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { - self.c.assign(pw, &wires.0); - pw.set_target_arr(&wires.1, self.subtree_proof); - } - } - - #[test] - fn test_query_agg_full_node_index_leaf() { - let min_query = U256::from(100); - let max_query = U256::from(200); - - // Generate the subtree proof. - let ops: [_; MAX_NUM_RESULTS] = random_aggregation_operations(); - let [mut subtree_proof] = random_aggregation_public_inputs(&ops); - unify_subtree_proof::(&mut subtree_proof, false, min_query, max_query); - let subtree_pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&subtree_proof); - - let index_value = subtree_pi.index_value(); - let index_value_fields = subtree_pi.to_index_value_raw(); - let index_ids = subtree_pi.index_ids(); - - // Construct the test circuit. - let test_circuit = TestFullNodeIndexLeafCircuit { - c: FullNodeIndexLeafCircuit { - min_query, - max_query, - }, - subtree_proof: &subtree_proof, - }; - - // Prove for the test circuit. - let proof = run_circuit::(test_circuit); - let pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&proof.public_inputs); - - // Check the public inputs. 
- // Tree hash - { - // H(H("") || H("") || p.I || p.I || p.index_ids[0] || p.I || p.H)) - let empty_hash = empty_poseidon_hash(); - let empty_hash_fields = empty_hash.to_fields(); - let inputs: Vec<_> = empty_hash_fields - .iter() - .chain(empty_hash_fields.iter()) - .chain(index_value_fields) - .chain(index_value_fields) - .chain(iter::once(&index_ids[0])) - .chain(index_value_fields) - .chain(subtree_pi.to_hash_raw()) - .cloned() - .collect(); - let exp_hash = H::hash_no_pad(&inputs); - - assert_eq!(pi.tree_hash(), exp_hash); - } - // Output values - assert_eq!(pi.to_values_raw(), subtree_pi.to_values_raw()); - // Count - assert_eq!(pi.num_matching_rows(), subtree_pi.num_matching_rows()); - // Operation IDs - assert_eq!(pi.operation_ids(), subtree_pi.operation_ids()); - // Index value - assert_eq!(pi.index_value(), index_value); - // Minimum value - assert_eq!(pi.min_value(), index_value); - // Maximum value - assert_eq!(pi.max_value(), index_value); - // Index IDs - assert_eq!(pi.index_ids(), index_ids); - // Minimum query - assert_eq!(pi.min_query_value(), min_query); - // Maximum query - assert_eq!(pi.max_query_value(), max_query); - // Overflow flag - assert_eq!(pi.overflow_flag(), subtree_pi.overflow_flag()); - // Computational hash - assert_eq!(pi.computational_hash(), subtree_pi.computational_hash()); - // Placeholder hash - assert_eq!(pi.placeholder_hash(), subtree_pi.placeholder_hash()); - } -} diff --git a/verifiable-db/src/query/aggregation/full_node_with_one_child.rs b/verifiable-db/src/query/aggregation/full_node_with_one_child.rs deleted file mode 100644 index 8ac0b9ef1..000000000 --- a/verifiable-db/src/query/aggregation/full_node_with_one_child.rs +++ /dev/null @@ -1,412 +0,0 @@ -//! Module handling the full node with one child for query aggregation circuits - -use crate::query::{ - aggregation::{output_computation::compute_output_item, utils::constrain_input_proofs}, - public_inputs::PublicInputs, -}; -use alloy::primitives::U256; -use anyhow::Result; -use mp2_common::{ - hash::hash_maybe_first, - poseidon::empty_poseidon_hash, - public_inputs::PublicInputCommon, - serialization::{deserialize, serialize}, - types::CBuilder, - u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256}, - utils::ToTargets, - D, F, -}; -use plonky2::{ - iop::{ - target::{BoolTarget, Target}, - witness::{PartialWitness, WitnessWrite}, - }, - plonk::proof::ProofWithPublicInputsTarget, -}; -use recursion_framework::circuit_builder::CircuitLogicWires; -use serde::{Deserialize, Serialize}; -use std::{iter, slice}; - -/// Full node wires with one child -/// The constant generic parameter is only used for impl `CircuitLogicWires`. 
-#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct FullNodeWithOneChildWires { - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - is_rows_tree_node: BoolTarget, - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - is_left_child: BoolTarget, - min_query: UInt256Target, - max_query: UInt256Target, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct FullNodeWithOneChildCircuit { - /// The flag specified if the proof is generated for a node in a rows tree or - /// for a node in the index tree - pub(crate) is_rows_tree_node: bool, - /// The flag specified if the child node is the left or right child - pub(crate) is_left_child: bool, - /// Minimum range bound specified in the query for the indexed column - /// It's a range bound for the primary indexed column for index tree, - /// and secondary indexed column for rows tree. - pub(crate) min_query: U256, - /// Maximum range bound specified in the query for the indexed column - pub(crate) max_query: U256, -} - -impl FullNodeWithOneChildCircuit { - pub fn build( - b: &mut CBuilder, - subtree_proof: &PublicInputs, - child_proof: &PublicInputs, - ) -> FullNodeWithOneChildWires - where - [(); MAX_NUM_RESULTS - 1]:, - { - let zero = b.zero(); - let empty_hash = b.constant_hash(*empty_poseidon_hash()); - - let is_rows_tree_node = b.add_virtual_bool_target_safe(); - let is_left_child = b.add_virtual_bool_target_unsafe(); - let [min_query, max_query] = [0; 2].map(|_| b.add_virtual_u256_unsafe()); - - // Check the consistency for the subtree proof and child proof. - constrain_input_proofs( - b, - is_rows_tree_node, - &min_query, - &max_query, - subtree_proof, - slice::from_ref(child_proof), - ); - - // Choose the column ID and node value to be hashed depending on which tree - // the current node belongs to. - let index_ids = subtree_proof.index_ids_target(); - let column_id = b.select(is_rows_tree_node, index_ids[1], index_ids[0]); - let index_value = subtree_proof.index_value_target(); - let node_value = b.select_u256( - is_rows_tree_node, - &subtree_proof.min_value_target(), - &index_value, - ); - - let node_min = b.select_u256(is_left_child, &child_proof.min_value_target(), &node_value); - let node_max = b.select_u256(is_left_child, &node_value, &child_proof.max_value_target()); - - // Compute the node hash: - // H(left_child.H || right_child.H || node_min || node_max || column_id || node_value || p.H)) - let rest: Vec<_> = node_min - .to_targets() - .into_iter() - .chain(node_max.to_targets()) - .chain(iter::once(column_id)) - .chain(node_value.to_targets()) - .chain(subtree_proof.tree_hash_target().to_targets()) - .collect(); - let node_hash = hash_maybe_first( - b, - is_left_child, - empty_hash.elements, - child_proof.tree_hash_target().elements, - &rest, - ); - - // Aggregate the output values of children and the overflow number. 
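// For intuition, the loop below accumulates, for every output slot i, the
// corresponding outputs of the subtree proof and of the child proof according
// to ops[i]. An out-of-circuit sketch of that per-slot rule (a sketch only,
// assuming the `SumOp`/`MinOp`/`MaxOp` variants of `AggregationOperation`
// from `computational_hash_ids`; the in-circuit `compute_output_item` gadget
// works over targets and additionally returns the number of overflows):
fn aggregate_slot(op: &AggregationOperation, values: &[U256]) -> (U256, bool) {
    match op {
        // MIN/MAX pick the extremal value and can never overflow.
        AggregationOperation::MinOp => (values.iter().copied().min().unwrap_or(U256::MAX), false),
        AggregationOperation::MaxOp => (values.iter().copied().max().unwrap_or(U256::ZERO), false),
        // Additive accumulators (SUM and friends) add up, tracking the carries.
        _ => values.iter().fold((U256::ZERO, false), |(acc, of), v| {
            let (sum, carry) = acc.overflowing_add(*v);
            (sum, of || carry)
        }),
    }
}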
- let mut num_overflows = zero; - let mut aggregated_values = vec![]; - for i in 0..MAX_NUM_RESULTS { - let (mut output, overflow) = compute_output_item(b, i, &[subtree_proof, child_proof]); - - aggregated_values.append(&mut output); - num_overflows = b.add(num_overflows, overflow); - } - - // count = current.count + child.count - let count = b.add( - subtree_proof.num_matching_rows_target(), - child_proof.num_matching_rows_target(), - ); - - // overflow = (pC.overflow + pR.overflow + num_overflows) != 0 - let overflow = b.add_many([ - subtree_proof.to_overflow_raw(), - child_proof.to_overflow_raw(), - &num_overflows, - ]); - let overflow = b.is_not_equal(overflow, zero); - - // Register the public inputs. - PublicInputs::<_, MAX_NUM_RESULTS>::new( - &node_hash.to_targets(), - aggregated_values.as_slice(), - &[count], - subtree_proof.to_ops_raw(), - subtree_proof.to_index_value_raw(), - &node_min.to_targets(), - &node_max.to_targets(), - subtree_proof.to_index_ids_raw(), - &min_query.to_targets(), - &max_query.to_targets(), - &[overflow.target], - subtree_proof.to_computational_hash_raw(), - subtree_proof.to_placeholder_hash_raw(), - ) - .register(b); - - FullNodeWithOneChildWires { - is_rows_tree_node, - is_left_child, - min_query, - max_query, - } - } - - fn assign( - &self, - pw: &mut PartialWitness, - wires: &FullNodeWithOneChildWires, - ) { - pw.set_bool_target(wires.is_rows_tree_node, self.is_rows_tree_node); - pw.set_bool_target(wires.is_left_child, self.is_left_child); - pw.set_u256_target(&wires.min_query, self.min_query); - pw.set_u256_target(&wires.max_query, self.max_query); - } -} - -/// Subtree proof number = 1, child proof number = 1 -pub(crate) const NUM_VERIFIED_PROOFS: usize = 2; - -impl CircuitLogicWires - for FullNodeWithOneChildWires -where - [(); MAX_NUM_RESULTS - 1]:, -{ - type CircuitBuilderParams = (); - type Inputs = FullNodeWithOneChildCircuit; - - const NUM_PUBLIC_INPUTS: usize = PublicInputs::::total_len(); - - fn circuit_logic( - builder: &mut CBuilder, - verified_proofs: [&ProofWithPublicInputsTarget; NUM_VERIFIED_PROOFS], - _builder_parameters: Self::CircuitBuilderParams, - ) -> Self { - // The first one is the subtree proof, and the second is the child proof. 
- let [subtree_proof, child_proof] = - verified_proofs.map(|p| PublicInputs::from_slice(&p.public_inputs)); - - Self::Inputs::build(builder, &subtree_proof, &child_proof) - } - - fn assign_input(&self, inputs: Self::Inputs, pw: &mut PartialWitness) -> Result<()> { - inputs.assign(pw, self); - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - query::{ - aggregation::{ - tests::compute_output_item_value, - utils::tests::{unify_child_proof, unify_subtree_proof}, - }, - pi_len, - }, - test_utils::{random_aggregation_operations, random_aggregation_public_inputs}, - }; - use mp2_common::{poseidon::H, utils::ToFields, C}; - use mp2_test::circuit::{run_circuit, UserCircuit}; - use plonky2::{iop::witness::WitnessWrite, plonk::config::Hasher}; - use std::array; - - const MAX_NUM_RESULTS: usize = 20; - - #[derive(Clone, Debug)] - struct TestFullNodeWithOneChildCircuit<'a> { - c: FullNodeWithOneChildCircuit, - subtree_proof: &'a [F], - child_proof: &'a [F], - } - - impl UserCircuit for TestFullNodeWithOneChildCircuit<'_> { - // Circuit wires + subtree proof + child proof - type Wires = ( - FullNodeWithOneChildWires, - Vec, - Vec, - ); - - fn build(b: &mut CBuilder) -> Self::Wires { - let proofs = array::from_fn(|_| { - b.add_virtual_target_arr::<{ pi_len::() }>() - .to_vec() - }); - let [subtree_pi, child_pi] = - array::from_fn(|i| PublicInputs::::from_slice(&proofs[i])); - - let wires = FullNodeWithOneChildCircuit::build(b, &subtree_pi, &child_pi); - - let [subtree_proof, child_proof] = proofs; - - (wires, subtree_proof, child_proof) - } - - fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { - self.c.assign(pw, &wires.0); - pw.set_target_arr(&wires.1, self.subtree_proof); - pw.set_target_arr(&wires.2, self.child_proof); - } - } - - fn test_full_node_with_one_child_circuit(is_rows_tree_node: bool, is_left_child: bool) { - let min_query = U256::from(100); - let max_query = U256::from(200); - - // Generate the input proofs. - let ops: [_; MAX_NUM_RESULTS] = random_aggregation_operations(); - let [mut subtree_proof, mut child_proof] = random_aggregation_public_inputs(&ops); - unify_subtree_proof::( - &mut subtree_proof, - is_rows_tree_node, - min_query, - max_query, - ); - let subtree_pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&subtree_proof); - unify_child_proof::( - &mut child_proof, - is_rows_tree_node, - min_query, - max_query, - &subtree_pi, - ); - let child_pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&child_proof); - - // Construct the expected public input values. - let index_ids = subtree_pi.index_ids(); - let index_value = subtree_pi.index_value(); - let node_value = if is_rows_tree_node { - subtree_pi.min_value() - } else { - index_value - }; - let [node_min, node_max] = if is_left_child { - [child_pi.min_value(), node_value] - } else { - [node_value, child_pi.max_value()] - }; - - // Construct the test circuit. - let test_circuit = TestFullNodeWithOneChildCircuit { - c: FullNodeWithOneChildCircuit { - is_rows_tree_node, - is_left_child, - min_query, - max_query, - }, - subtree_proof: &subtree_proof, - child_proof: &child_proof, - }; - - // Prove for the test circuit. - let proof = run_circuit::(test_circuit); - let pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&proof.public_inputs); - - // Check the public inputs. 
- // Tree hash - { - let column_id = if is_rows_tree_node { - index_ids[1] - } else { - index_ids[0] - }; - let empty_hash = empty_poseidon_hash(); - let child_hash = child_pi.tree_hash(); - let [left_child_hash, right_child_hash] = if is_left_child { - [child_hash, *empty_hash] - } else { - [*empty_hash, child_hash] - }; - - // H(left_child.H || right_child.H || node_min || node_max || column_id || node_value || p.H)) - let inputs: Vec<_> = left_child_hash - .to_fields() - .into_iter() - .chain(right_child_hash.to_fields()) - .chain(node_min.to_fields()) - .chain(node_max.to_fields()) - .chain(iter::once(column_id)) - .chain(node_value.to_fields()) - .chain(subtree_pi.tree_hash().to_fields()) - .collect(); - let exp_hash = H::hash_no_pad(&inputs); - - assert_eq!(pi.tree_hash(), exp_hash); - } - // Output values and overflow flag - { - let mut num_overflows = 0; - let mut aggregated_values = vec![]; - - for i in 0..MAX_NUM_RESULTS { - let (mut output, overflow) = - compute_output_item_value(i, &[&subtree_pi, &child_pi]); - - aggregated_values.append(&mut output); - num_overflows += overflow; - } - - assert_eq!(pi.to_values_raw(), aggregated_values); - assert_eq!( - pi.overflow_flag(), - subtree_pi.overflow_flag() || child_pi.overflow_flag() || num_overflows != 0 - ); - } - // Count - assert_eq!( - pi.num_matching_rows(), - subtree_pi.num_matching_rows() + child_pi.num_matching_rows(), - ); - // Operation IDs - assert_eq!(pi.operation_ids(), subtree_pi.operation_ids()); - // Index value - assert_eq!(pi.index_value(), index_value); - // Minimum value - assert_eq!(pi.min_value(), node_min); - // Maximum value - assert_eq!(pi.max_value(), node_max); - // Index IDs - assert_eq!(pi.index_ids(), index_ids); - // Minimum query - assert_eq!(pi.min_query_value(), min_query); - // Maximum query - assert_eq!(pi.max_query_value(), max_query); - // Computational hash - assert_eq!(pi.computational_hash(), subtree_pi.computational_hash()); - // Placeholder hash - assert_eq!(pi.placeholder_hash(), subtree_pi.placeholder_hash()); - } - - #[test] - fn test_query_agg_full_node_with_one_child_for_row_node_with_left_child() { - test_full_node_with_one_child_circuit(true, true); - } - - #[test] - fn test_query_agg_full_node_with_one_child_for_row_node_with_right_child() { - test_full_node_with_one_child_circuit(true, false); - } - - #[test] - fn test_query_agg_full_node_with_one_child_for_index_node_with_left_child() { - test_full_node_with_one_child_circuit(false, true); - } - - #[test] - fn test_query_agg_full_node_with_one_child_for_index_node_with_right_child() { - test_full_node_with_one_child_circuit(false, false); - } -} diff --git a/verifiable-db/src/query/aggregation/full_node_with_two_children.rs b/verifiable-db/src/query/aggregation/full_node_with_two_children.rs deleted file mode 100644 index 1594e2ecb..000000000 --- a/verifiable-db/src/query/aggregation/full_node_with_two_children.rs +++ /dev/null @@ -1,398 +0,0 @@ -//! 
Module handling the full node with two children for query aggregation circuits - -use crate::query::{ - aggregation::{output_computation::compute_output_item, utils::constrain_input_proofs}, - public_inputs::PublicInputs, -}; -use alloy::primitives::U256; -use anyhow::Result; -use mp2_common::{ - poseidon::H, - public_inputs::PublicInputCommon, - serialization::{deserialize, serialize}, - types::CBuilder, - u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256}, - utils::ToTargets, - D, F, -}; -use plonky2::{ - iop::{ - target::{BoolTarget, Target}, - witness::{PartialWitness, WitnessWrite}, - }, - plonk::proof::ProofWithPublicInputsTarget, -}; -use recursion_framework::circuit_builder::CircuitLogicWires; -use serde::{Deserialize, Serialize}; -use std::iter; - -/// Full node wires with two children -/// The constant generic parameter is only used for impl `CircuitLogicWires`. -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct FullNodeWithTwoChildrenWires { - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - is_rows_tree_node: BoolTarget, - min_query: UInt256Target, - max_query: UInt256Target, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct FullNodeWithTwoChildrenCircuit { - /// The flag specified if the proof is generated for a node in a rows tree or - /// for a node in the index tree - pub(crate) is_rows_tree_node: bool, - /// Minimum range bound specified in the query for the indexed column - /// It's a range bound for the primary indexed column for index tree, - /// and secondary indexed column for rows tree. - pub(crate) min_query: U256, - /// Maximum range bound specified in the query for the indexed column - pub(crate) max_query: U256, -} - -impl FullNodeWithTwoChildrenCircuit { - pub fn build( - b: &mut CBuilder, - subtree_proof: &PublicInputs, - child_proofs: &[PublicInputs; 2], - ) -> FullNodeWithTwoChildrenWires - where - [(); MAX_NUM_RESULTS - 1]:, - { - let zero = b.zero(); - - let is_rows_tree_node = b.add_virtual_bool_target_safe(); - let [min_query, max_query] = [0; 2].map(|_| b.add_virtual_u256_unsafe()); - - // Check the consistency for the subtree proof and child proofs. - constrain_input_proofs( - b, - is_rows_tree_node, - &min_query, - &max_query, - subtree_proof, - child_proofs, - ); - - // Choose the column ID and node value to be hashed depending on which tree - // the current node belongs to. - let index_ids = subtree_proof.index_ids_target(); - let column_id = b.select(is_rows_tree_node, index_ids[1], index_ids[0]); - let index_value = subtree_proof.index_value_target(); - let node_value = b.select_u256( - is_rows_tree_node, - &subtree_proof.min_value_target(), - &index_value, - ); - - // Compute the node hash: - // node_hash = H(p1.H || p2.H || p1.min || p2.max || column_id || node_value || p.H) - let [child_proof1, child_proof2] = child_proofs; - let inputs = child_proof1 - .tree_hash_target() - .to_targets() - .into_iter() - .chain(child_proof2.tree_hash_target().to_targets()) - .chain(child_proof1.min_value_target().to_targets()) - .chain(child_proof2.max_value_target().to_targets()) - .chain(iter::once(column_id)) - .chain(node_value.to_targets()) - .chain(subtree_proof.tree_hash_target().to_targets()) - .collect(); - let node_hash = b.hash_n_to_hash_no_pad::(inputs); - - // Aggregate the output values of children and the overflow number. 
-        let mut num_overflows = zero;
-        let mut aggregated_values = vec![];
-        for i in 0..MAX_NUM_RESULTS {
-            let (mut output, overflow) =
-                compute_output_item(b, i, &[subtree_proof, child_proof1, child_proof2]);
-
-            aggregated_values.append(&mut output);
-            num_overflows = b.add(num_overflows, overflow);
-        }
-
-        // count = p1.count + p2.count + p.count
-        let count = b.add(
-            child_proof1.num_matching_rows_target(),
-            child_proof2.num_matching_rows_target(),
-        );
-        let count = b.add(count, subtree_proof.num_matching_rows_target());
-
-        // overflow = (p.overflow + p1.overflow + p2.overflow + num_overflows) != 0
-        let overflow = b.add_many([
-            subtree_proof.to_overflow_raw(),
-            child_proof1.to_overflow_raw(),
-            child_proof2.to_overflow_raw(),
-            &num_overflows,
-        ]);
-        let overflow = b.is_not_equal(overflow, zero);
-
-        // Register the public inputs.
-        PublicInputs::<_, MAX_NUM_RESULTS>::new(
-            &node_hash.to_targets(),
-            aggregated_values.as_slice(),
-            &[count],
-            subtree_proof.to_ops_raw(),
-            subtree_proof.to_index_value_raw(),
-            child_proof1.to_min_value_raw(),
-            child_proof2.to_max_value_raw(),
-            subtree_proof.to_index_ids_raw(),
-            &min_query.to_targets(),
-            &max_query.to_targets(),
-            &[overflow.target],
-            subtree_proof.to_computational_hash_raw(),
-            subtree_proof.to_placeholder_hash_raw(),
-        )
-        .register(b);
-
-        FullNodeWithTwoChildrenWires {
-            is_rows_tree_node,
-            min_query,
-            max_query,
-        }
-    }
-
-    fn assign<const MAX_NUM_RESULTS: usize>(
-        &self,
-        pw: &mut PartialWitness<F>,
-        wires: &FullNodeWithTwoChildrenWires<MAX_NUM_RESULTS>,
-    ) {
-        pw.set_bool_target(wires.is_rows_tree_node, self.is_rows_tree_node);
-        pw.set_u256_target(&wires.min_query, self.min_query);
-        pw.set_u256_target(&wires.max_query, self.max_query);
-    }
-}
-
-/// Subtree proof number = 1, child proof number = 2
-pub(crate) const NUM_VERIFIED_PROOFS: usize = 3;
-
-impl<const MAX_NUM_RESULTS: usize> CircuitLogicWires<F, D, NUM_VERIFIED_PROOFS>
-    for FullNodeWithTwoChildrenWires<MAX_NUM_RESULTS>
-where
-    [(); MAX_NUM_RESULTS - 1]:,
-{
-    type CircuitBuilderParams = ();
-    type Inputs = FullNodeWithTwoChildrenCircuit;
-
-    const NUM_PUBLIC_INPUTS: usize = PublicInputs::<F, MAX_NUM_RESULTS>::total_len();
-
-    fn circuit_logic(
-        builder: &mut CBuilder,
-        verified_proofs: [&ProofWithPublicInputsTarget<D>; NUM_VERIFIED_PROOFS],
-        _builder_parameters: Self::CircuitBuilderParams,
-    ) -> Self {
-        // The first one is the subtree proof, and the remaining ones are the child proofs.
- let [subtree_proof, child_proof1, child_proof2] = - verified_proofs.map(|p| PublicInputs::from_slice(&p.public_inputs)); - - Self::Inputs::build(builder, &subtree_proof, &[child_proof1, child_proof2]) - } - - fn assign_input(&self, inputs: Self::Inputs, pw: &mut PartialWitness) -> Result<()> { - inputs.assign(pw, self); - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - query::{ - aggregation::{ - tests::compute_output_item_value, - utils::tests::{unify_child_proof, unify_subtree_proof}, - }, - pi_len, - }, - test_utils::{random_aggregation_operations, random_aggregation_public_inputs}, - }; - use mp2_common::{utils::ToFields, C}; - use mp2_test::circuit::{run_circuit, UserCircuit}; - use plonky2::{iop::witness::WitnessWrite, plonk::config::Hasher}; - use std::array; - - const MAX_NUM_RESULTS: usize = 20; - - #[derive(Clone, Debug)] - struct TestFullNodeWithTwoChildrenCircuit<'a> { - c: FullNodeWithTwoChildrenCircuit, - subtree_proof: &'a [F], - left_child_proof: &'a [F], - right_child_proof: &'a [F], - } - - impl UserCircuit for TestFullNodeWithTwoChildrenCircuit<'_> { - // Circuit wires + subtree proof + left child proof + right child proof - type Wires = ( - FullNodeWithTwoChildrenWires, - Vec, - Vec, - Vec, - ); - - fn build(b: &mut CBuilder) -> Self::Wires { - let proofs = array::from_fn(|_| { - b.add_virtual_target_arr::<{ pi_len::() }>() - .to_vec() - }); - let [subtree_pi, left_child_pi, right_child_pi] = - array::from_fn(|i| PublicInputs::::from_slice(&proofs[i])); - - let wires = FullNodeWithTwoChildrenCircuit::build( - b, - &subtree_pi, - &[left_child_pi, right_child_pi], - ); - - let [subtree_proof, left_child_proof, right_child_proof] = proofs; - - (wires, subtree_proof, left_child_proof, right_child_proof) - } - - fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { - self.c.assign(pw, &wires.0); - pw.set_target_arr(&wires.1, self.subtree_proof); - pw.set_target_arr(&wires.2, self.left_child_proof); - pw.set_target_arr(&wires.3, self.right_child_proof); - } - } - - fn test_full_node_with_two_children_circuit(is_rows_tree_node: bool) { - let min_query = U256::from(100); - let max_query = U256::from(200); - - // Generate the input proofs. - let ops: [_; MAX_NUM_RESULTS] = random_aggregation_operations(); - let [mut subtree_proof, mut left_child_proof, mut right_child_proof] = - random_aggregation_public_inputs(&ops); - unify_subtree_proof::( - &mut subtree_proof, - is_rows_tree_node, - min_query, - max_query, - ); - let subtree_pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&subtree_proof); - [&mut left_child_proof, &mut right_child_proof] - .iter_mut() - .for_each(|p| { - unify_child_proof::( - p, - is_rows_tree_node, - min_query, - max_query, - &subtree_pi, - ) - }); - let left_child_pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&left_child_proof); - let right_child_pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&right_child_proof); - - // Construct the expected public input values. - let index_ids = subtree_pi.index_ids(); - let index_value = subtree_pi.index_value(); - let node_value = if is_rows_tree_node { - subtree_pi.min_value() - } else { - index_value - }; - - // Construct the test circuit. - let test_circuit = TestFullNodeWithTwoChildrenCircuit { - c: FullNodeWithTwoChildrenCircuit { - is_rows_tree_node, - min_query, - max_query, - }, - subtree_proof: &subtree_proof, - left_child_proof: &left_child_proof, - right_child_proof: &right_child_proof, - }; - - // Prove for the test circuit. 
- let proof = run_circuit::(test_circuit); - let pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&proof.public_inputs); - - // Check the public inputs. - // Tree hash - { - let column_id = if is_rows_tree_node { - index_ids[1] - } else { - index_ids[0] - }; - - // H(p1.H || p2.H || p1.min || p2.max || column_id || node_value || p.H) - let inputs: Vec<_> = left_child_pi - .tree_hash() - .to_fields() - .into_iter() - .chain(right_child_pi.tree_hash().to_fields()) - .chain(left_child_pi.min_value().to_fields()) - .chain(right_child_pi.max_value().to_fields()) - .chain(iter::once(column_id)) - .chain(node_value.to_fields()) - .chain(subtree_pi.tree_hash().to_fields()) - .collect(); - let exp_hash = H::hash_no_pad(&inputs); - - assert_eq!(pi.tree_hash(), exp_hash); - } - // Output values and overflow flag - { - let mut num_overflows = 0; - let mut aggregated_values = vec![]; - - for i in 0..MAX_NUM_RESULTS { - let (mut output, overflow) = - compute_output_item_value(i, &[&subtree_pi, &left_child_pi, &right_child_pi]); - - aggregated_values.append(&mut output); - num_overflows += overflow; - } - - assert_eq!(pi.to_values_raw(), aggregated_values); - assert_eq!( - pi.overflow_flag(), - subtree_pi.overflow_flag() - || left_child_pi.overflow_flag() - || right_child_pi.overflow_flag() - || num_overflows != 0 - ); - } - // Count - assert_eq!( - pi.num_matching_rows(), - subtree_pi.num_matching_rows() - + left_child_pi.num_matching_rows() - + right_child_pi.num_matching_rows(), - ); - // Operation IDs - assert_eq!(pi.operation_ids(), subtree_pi.operation_ids()); - // Index value - assert_eq!(pi.index_value(), index_value); - // Minimum value - assert_eq!(pi.min_value(), left_child_pi.min_value()); - // Maximum value - assert_eq!(pi.max_value(), right_child_pi.max_value()); - // Index IDs - assert_eq!(pi.index_ids(), index_ids); - // Minimum query - assert_eq!(pi.min_query_value(), min_query); - // Maximum query - assert_eq!(pi.max_query_value(), max_query); - // Computational hash - assert_eq!(pi.computational_hash(), subtree_pi.computational_hash()); - // Placeholder hash - assert_eq!(pi.placeholder_hash(), subtree_pi.placeholder_hash()); - } - - #[test] - fn test_query_agg_full_node_with_two_children_for_row_node() { - test_full_node_with_two_children_circuit(true); - } - - #[test] - fn test_query_agg_full_node_with_two_children_for_index_node() { - test_full_node_with_two_children_circuit(false); - } -} diff --git a/verifiable-db/src/query/aggregation/non_existence_inter.rs b/verifiable-db/src/query/aggregation/non_existence_inter.rs deleted file mode 100644 index 7ed147e5d..000000000 --- a/verifiable-db/src/query/aggregation/non_existence_inter.rs +++ /dev/null @@ -1,761 +0,0 @@ -//! 
Module handling the non-existence intermediate node for query aggregation circuits
-
-use crate::query::{
-    aggregation::output_computation::compute_dummy_output_targets,
-    public_inputs::PublicInputs,
-    universal_circuit::universal_query_gadget::{
-        QueryBound, QueryBoundTarget, QueryBoundTargetInputs,
-    },
-};
-use alloy::primitives::U256;
-use anyhow::Result;
-use mp2_common::{
-    poseidon::{empty_poseidon_hash, H},
-    public_inputs::PublicInputCommon,
-    serialization::{
-        deserialize, deserialize_array, deserialize_long_array, serialize, serialize_array,
-        serialize_long_array,
-    },
-    types::CBuilder,
-    u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256},
-    utils::{HashBuilder, ToTargets},
-    D, F,
-};
-use plonky2::{
-    hash::hash_types::{HashOut, HashOutTarget},
-    iop::{
-        target::{BoolTarget, Target},
-        witness::{PartialWitness, WitnessWrite},
-    },
-    plonk::proof::ProofWithPublicInputsTarget,
-};
-use recursion_framework::circuit_builder::CircuitLogicWires;
-use serde::{Deserialize, Serialize};
-use std::{array, iter};
-
-/// Non-existence intermediate node wires
-#[derive(Clone, Debug, Serialize, Deserialize)]
-pub struct NonExistenceInterNodeWires<const MAX_NUM_RESULTS: usize> {
-    #[serde(serialize_with = "serialize", deserialize_with = "deserialize")]
-    is_rows_tree_node: BoolTarget,
-    min_query: QueryBoundTargetInputs,
-    max_query: QueryBoundTargetInputs,
-    value: UInt256Target,
-    index_value: UInt256Target,
-    index_ids: [Target; 2],
-    #[serde(
-        serialize_with = "serialize_long_array",
-        deserialize_with = "deserialize_long_array"
-    )]
-    ops: [Target; MAX_NUM_RESULTS],
-    #[serde(serialize_with = "serialize", deserialize_with = "deserialize")]
-    subtree_hash: HashOutTarget,
-    #[serde(serialize_with = "serialize", deserialize_with = "deserialize")]
-    computational_hash: HashOutTarget,
-    #[serde(serialize_with = "serialize", deserialize_with = "deserialize")]
-    placeholder_hash: HashOutTarget,
-    left_child_min: UInt256Target,
-    left_child_max: UInt256Target,
-    left_child_value: UInt256Target,
-    #[serde(serialize_with = "serialize", deserialize_with = "deserialize")]
-    left_tree_hash: HashOutTarget,
-    #[serde(
-        serialize_with = "serialize_array",
-        deserialize_with = "deserialize_array"
-    )]
-    left_grand_children: [HashOutTarget; 2],
-    right_child_min: UInt256Target,
-    right_child_max: UInt256Target,
-    right_child_value: UInt256Target,
-    #[serde(serialize_with = "serialize", deserialize_with = "deserialize")]
-    right_tree_hash: HashOutTarget,
-    #[serde(
-        serialize_with = "serialize_array",
-        deserialize_with = "deserialize_array"
-    )]
-    right_grand_children: [HashOutTarget; 2],
-    #[serde(serialize_with = "serialize", deserialize_with = "deserialize")]
-    left_child_exists: BoolTarget,
-    #[serde(serialize_with = "serialize", deserialize_with = "deserialize")]
-    right_child_exists: BoolTarget,
-}
-
-#[derive(Clone, Debug, Serialize, Deserialize)]
-pub struct NonExistenceInterNodeCircuit<const MAX_NUM_RESULTS: usize> {
-    /// Flag specifying whether the proof is generated for a node in a rows tree
-    /// or for a node in the index tree
-    pub(crate) is_rows_tree_node: bool,
-    /// Minimum range bound specified in the query for the indexed column.
-    /// It's a range bound for the primary indexed column for index tree nodes,
-    /// and for the secondary indexed column for rows tree nodes.
- pub(crate) min_query: QueryBound, - /// Maximum range bound specified in the query for the indexed column - pub(crate) max_query: QueryBound, - pub(crate) value: U256, - /// Value of the indexed column for the row stored in the current node - /// (meaningful only if the current node belongs to a rows tree, - /// can be equal to `value` if the current node belongs to the index tree) - pub(crate) index_value: U256, - /// Integer identifiers of the indexed columns - pub(crate) index_ids: [F; 2], - /// Set of identifiers of the aggregation operations for each of the `S` items found in `V` - #[serde( - serialize_with = "serialize_long_array", - deserialize_with = "deserialize_long_array" - )] - pub(crate) ops: [F; MAX_NUM_RESULTS], - /// Hash of the tree stored in the current node - pub(crate) subtree_hash: HashOut, - /// Computational hash associated to the processing of single rows of the query - /// (meaningless in this case, we just need to provide it for public input compliance) - pub(crate) computational_hash: HashOut, - /// Placeholder hash associated to the processing of single rows of the query - /// (meaningless in this case, we just need to provide it for public input compliance) - pub(crate) placeholder_hash: HashOut, - /// Minimum value associated to the left child - pub(crate) left_child_min: U256, - /// Maximum value associated to the left child - pub(crate) left_child_max: U256, - /// Value stored in the left child - pub(crate) left_child_value: U256, - /// Hashes of the row/rows tree stored in the left child - pub(crate) left_tree_hash: HashOut, - /// Hashes of the children nodes of the left child - pub(crate) left_grand_children: [HashOut; 2], - /// Minimum value associated to the right child - pub(crate) right_child_min: U256, - /// Maximum value associated to the right child - pub(crate) right_child_max: U256, - /// Value stored in the right child - pub(crate) right_child_value: U256, - /// Hashes of the row/rows tree stored in the right child - pub(crate) right_tree_hash: HashOut, - /// Hashes of the children nodes of the right child - pub(crate) right_grand_children: [HashOut; 2], - /// Boolean flag specifying whether there is a left child for the current node - pub(crate) left_child_exists: bool, - /// Boolean flag specifying whether there is a right child for the current node - pub(crate) right_child_exists: bool, -} - -impl NonExistenceInterNodeCircuit { - pub fn build(b: &mut CBuilder) -> NonExistenceInterNodeWires { - let ttrue = b._true(); - let ffalse = b._false(); - let zero = b.zero(); - let empty_hash = b.constant_hash(*empty_poseidon_hash()); - - let is_rows_tree_node = b.add_virtual_bool_target_safe(); - let left_child_exists = b.add_virtual_bool_target_safe(); - let right_child_exists = b.add_virtual_bool_target_safe(); - // Initialize as unsafe, since all these Uint256s are either exposed as - // public inputs or passed as inputs for hash computation. 
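// Note on the "unsafe" allocation mentioned above, sketched under the
// assumption that a `UInt256Target` is made of 32-bit limbs as in
// `mp2_common::u256`: the safe allocator would additionally range-check every
// limb, roughly
//     for limb in &u256_target.limbs { b.range_check(limb.0, 32); }  // hypothetical field names
// Skipping those checks is acceptable here precisely because each of these
// values is bound by a hash that is recomputed in-circuit or exposed as a
// public input, so an out-of-range witness cannot verify.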
- let [value, index_value, left_child_value, left_child_min, left_child_max, right_child_value, right_child_min, right_child_max] = - b.add_virtual_u256_arr_unsafe(); - // compute min and max query bounds for secondary index - - let index_ids = b.add_virtual_target_arr(); - let ops = b.add_virtual_target_arr(); - let [subtree_hash, computational_hash, placeholder_hash, left_child_subtree_hash, left_grand_child_hash1, left_grand_child_hash2, right_child_subtree_hash, right_grand_child_hash1, right_grand_child_hash2] = - array::from_fn(|_| b.add_virtual_hash()); - - let min_query = QueryBoundTarget::new(b); - let max_query = QueryBoundTarget::new(b); - - let min_query_value = min_query.get_bound_value(); - let max_query_value = max_query.get_bound_value(); - - let [min_query_targets, max_query_targets] = - [&min_query_value, &max_query_value].map(|v| v.to_targets()); - let column_id = b.select(is_rows_tree_node, index_ids[1], index_ids[0]); - - // Enforce that the value associated to the current node is out of the range - // specified by the query: - // value < MIN_query OR value > MAX_query - let is_value_less_than_min = b.is_less_than_u256(&value, min_query_value); - let is_value_greater_than_max = b.is_less_than_u256(max_query_value, &value); - let is_out_of_range = b.or(is_value_less_than_min, is_value_greater_than_max); - b.connect(is_out_of_range.target, ttrue.target); - - // Enforce that the records found in the subtree rooted in the child node - // are all out of the range specified by the query. If left child exists, - // ensure left_child_max < MIN_query; if right child exists, ensure right_child_min > MAX_query. - let is_child_less_than_min = b.is_less_than_u256(&left_child_max, min_query_value); - let is_left_child_out_of_range = b.and(left_child_exists, is_child_less_than_min); - b.connect(is_left_child_out_of_range.target, left_child_exists.target); - let is_child_greater_than_max = b.is_less_than_u256(max_query_value, &right_child_min); - let is_right_child_out_of_range = b.and(right_child_exists, is_child_greater_than_max); - b.connect( - is_right_child_out_of_range.target, - right_child_exists.target, - ); - - // Compute dummy values for each of the `S` values to be returned as output. 
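// The dummy outputs are simply the identity element of each slot's aggregation
// operation, so that aggregating this non-existence proof upward leaves the
// sibling results untouched. An out-of-circuit sketch of what
// `compute_dummy_output_targets` produces, under the assumption that MIN is
// the only operation whose identity is not zero:
fn dummy_output(op: &AggregationOperation) -> U256 {
    match op {
        AggregationOperation::MinOp => U256::MAX, // min(x, U256::MAX) == x
        _ => U256::ZERO,                          // identity for SUM/MAX/COUNT/AVG
    }
}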
- let outputs = compute_dummy_output_targets(b, &ops); - - // Recompute hash of left child node to bind left_child_min and left_child_max inputs: - // H(h1 || h2 || child_min || child_max || column_id || child_value || child_subtree_hash) - let inputs = left_grand_child_hash1 - .to_targets() - .into_iter() - .chain(left_grand_child_hash2.to_targets()) - .chain(left_child_min.to_targets()) - .chain(left_child_max.to_targets()) - .chain(iter::once(column_id)) - .chain(left_child_value.to_targets()) - .chain(left_child_subtree_hash.to_targets()) - .collect(); - let left_child_hash = b.hash_n_to_hash_no_pad::(inputs); - - let left_child_hash = b.select_hash(left_child_exists, &left_child_hash, &empty_hash); - - // Recompute hash of right child node to bind right_child_min and right_child_max inputs: - // H(h1 || h2 || child_min || child_max || column_id || child_value || child_subtree_hash) - let inputs = right_grand_child_hash1 - .to_targets() - .into_iter() - .chain(right_grand_child_hash2.to_targets()) - .chain(right_child_min.to_targets()) - .chain(right_child_max.to_targets()) - .chain(iter::once(column_id)) - .chain(right_child_value.to_targets()) - .chain(right_child_subtree_hash.to_targets()) - .collect(); - let right_child_hash = b.hash_n_to_hash_no_pad::(inputs); - - let right_child_hash = b.select_hash(right_child_exists, &right_child_hash, &empty_hash); - - // node_min = left_child_exists ? left_child_min : value - let node_min = b.select_u256(left_child_exists, &left_child_min, &value); - // node_max = right_child_exists ? right_child_max : value - let node_max = b.select_u256(right_child_exists, &right_child_max, &value); - let [node_min_targets, node_max_targets] = [node_min, node_max].map(|u| u.to_targets()); - - // Compute the node hash: - // H(left_child_hash || right_child_hash || node_min || node_max || column_id || value || subtree_hash) - let inputs = left_child_hash - .to_targets() - .into_iter() - .chain(right_child_hash.to_targets()) - .chain(node_min_targets.clone()) - .chain(node_max_targets.clone()) - .chain(iter::once(column_id)) - .chain(value.to_targets()) - .chain(subtree_hash.to_targets()) - .collect(); - let node_hash = b.hash_n_to_hash_no_pad::(inputs); - - // We add the query bounds to the placeholder hash only if the current node is in a rows tree. - let placeholder_hash_with_query_bounds = - QueryBoundTarget::add_query_bounds_to_placeholder_hash( - b, - &min_query, - &max_query, - &placeholder_hash, - ); - let new_placeholder_hash = b.select_hash( - is_rows_tree_node, - &placeholder_hash_with_query_bounds, - &placeholder_hash, - ); - // We add the query bounds to the computational hash only if the current - // node is in a rows tree. - let computational_hash_with_query_bounds = - QueryBoundTarget::add_query_bounds_to_computational_hash( - b, - &min_query, - &max_query, - &computational_hash, - ); - let new_computational_hash = b.select_hash( - is_rows_tree_node, - &computational_hash_with_query_bounds, - &computational_hash, - ); - - // Register the public inputs. 
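// Layout of the public inputs registered below, in the order expected by
// `PublicInputs::<_, MAX_NUM_RESULTS>::new`: the recomputed node hash, the
// dummy output values computed above, a zero row count (no row matches by
// construction), the operation identifiers, the primary index value, the node
// min/max, the index ids, the query bounds, an always-false overflow flag,
// and finally the computational and placeholder hashes.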
- PublicInputs::<_, MAX_NUM_RESULTS>::new( - &node_hash.to_targets(), - outputs.as_slice(), - &[zero], - &ops, - &index_value.to_targets(), - &node_min_targets, - &node_max_targets, - &index_ids, - &min_query_targets, - &max_query_targets, - &[ffalse.target], - &new_computational_hash.to_targets(), - &new_placeholder_hash.to_targets(), - ) - .register(b); - - let left_grand_children = [left_grand_child_hash1, left_grand_child_hash2]; - let right_grand_children = [right_grand_child_hash1, right_grand_child_hash2]; - - NonExistenceInterNodeWires { - is_rows_tree_node, - left_child_exists, - right_child_exists, - min_query: min_query.into(), - max_query: max_query.into(), - value, - index_value, - left_child_value, - left_child_min, - left_child_max, - right_child_value, - right_child_min, - right_child_max, - index_ids, - ops, - subtree_hash, - computational_hash, - placeholder_hash, - left_tree_hash: left_child_subtree_hash, - left_grand_children, - right_tree_hash: right_child_subtree_hash, - right_grand_children, - } - } - - fn assign( - &self, - pw: &mut PartialWitness, - wires: &NonExistenceInterNodeWires, - ) { - [ - (wires.is_rows_tree_node, self.is_rows_tree_node), - (wires.left_child_exists, self.left_child_exists), - (wires.right_child_exists, self.right_child_exists), - ] - .iter() - .for_each(|(t, v)| pw.set_bool_target(*t, *v)); - [ - (&wires.value, self.value), - (&wires.index_value, self.index_value), - (&wires.left_child_value, self.left_child_value), - (&wires.left_child_min, self.left_child_min), - (&wires.left_child_max, self.left_child_max), - (&wires.right_child_value, self.right_child_value), - (&wires.right_child_min, self.right_child_min), - (&wires.right_child_max, self.right_child_max), - ] - .iter() - .for_each(|(t, v)| pw.set_u256_target(t, *v)); - wires.min_query.assign(pw, &self.min_query); - wires.max_query.assign(pw, &self.max_query); - pw.set_target_arr(&wires.index_ids, &self.index_ids); - pw.set_target_arr(&wires.ops, &self.ops); - [ - (wires.subtree_hash, self.subtree_hash), - (wires.computational_hash, self.computational_hash), - (wires.placeholder_hash, self.placeholder_hash), - (wires.left_tree_hash, self.left_tree_hash), - (wires.right_tree_hash, self.right_tree_hash), - ] - .iter() - .for_each(|(t, v)| pw.set_hash_target(*t, *v)); - wires - .left_grand_children - .iter() - .zip(self.left_grand_children) - .for_each(|(t, v)| pw.set_hash_target(*t, v)); - wires - .right_grand_children - .iter() - .zip(self.right_grand_children) - .for_each(|(t, v)| pw.set_hash_target(*t, v)); - } -} - -/// Verified proof number = 0 -pub(crate) const NUM_VERIFIED_PROOFS: usize = 0; - -impl CircuitLogicWires - for NonExistenceInterNodeWires -{ - type CircuitBuilderParams = (); - type Inputs = NonExistenceInterNodeCircuit; - - const NUM_PUBLIC_INPUTS: usize = PublicInputs::::total_len(); - - fn circuit_logic( - builder: &mut CBuilder, - _verified_proofs: [&ProofWithPublicInputsTarget; NUM_VERIFIED_PROOFS], - _builder_parameters: Self::CircuitBuilderParams, - ) -> Self { - Self::Inputs::build(builder) - } - - fn assign_input(&self, inputs: Self::Inputs, pw: &mut PartialWitness) -> Result<()> { - inputs.assign(pw, self); - Ok(()) - } -} - -#[cfg(test)] -mod tests { - - use super::*; - use crate::{ - query::{ - aggregation::{ - output_computation::tests::compute_dummy_output_values, QueryBoundSource, - QueryBounds, - }, - computational_hash_ids::{AggregationOperation, Identifiers}, - universal_circuit::universal_circuit_inputs::{PlaceholderId, Placeholders}, - }, - 
test_utils::random_aggregation_operations, - }; - use mp2_common::{array::ToField, poseidon::H, utils::ToFields, C}; - use mp2_test::{ - circuit::{run_circuit, UserCircuit}, - utils::gen_random_field_hash, - }; - use plonky2::{ - field::types::{Field, Sample}, - plonk::config::Hasher, - }; - - use rand::{prelude::SliceRandom, thread_rng, Rng}; - - const MAX_NUM_RESULTS: usize = 20; - - impl UserCircuit for NonExistenceInterNodeCircuit { - type Wires = NonExistenceInterNodeWires; - - fn build(b: &mut CBuilder) -> Self::Wires { - NonExistenceInterNodeCircuit::build(b) - } - - fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { - self.assign(pw, wires); - } - } - - fn test_non_existence_inter_circuit( - is_rows_tree_node: bool, - left_child_exists: bool, - right_child_exists: bool, - ops: [F; MAX_NUM_RESULTS], - ) { - let min_query_value = U256::from(1000); - let max_query_value = U256::from(3000); - - let mut rng = &mut thread_rng(); - // value < MIN_query OR value > MAX_query - let value = *[ - min_query_value - U256::from(1), - max_query_value + U256::from(1), - ] - .choose(&mut rng) - .unwrap(); - let [left_child_min, left_child_max] = if left_child_exists { - // left_child_max < MIN_query - [U256::from_limbs(rng.gen()), min_query_value - U256::from(1)] - } else { - // no constraints otherwise - [U256::from_limbs(rng.gen()), U256::from_limbs(rng.gen())] - }; - let [right_child_min, right_child_max] = if right_child_exists { - // right_child_min > MAX_query - [max_query_value + U256::from(1), U256::from_limbs(rng.gen())] - } else { - // no constraints otherwise - [U256::from_limbs(rng.gen()), U256::from_limbs(rng.gen())] - }; - let [index_value, left_child_value, right_child_value] = - array::from_fn(|_| U256::from_limbs(rng.gen())); - let index_ids = F::rand_array(); - let [subtree_hash, computational_hash, placeholder_hash, left_child_subtree_hash, left_grand_child_hash1, left_grand_child_hash2, right_child_subtree_hash, right_grand_child_hash1, right_grand_child_hash2] = - array::from_fn(|_| gen_random_field_hash()); - let left_grand_children = [left_grand_child_hash1, left_grand_child_hash2]; - let right_grand_children = [right_grand_child_hash1, right_grand_child_hash2]; - - let first_placeholder_id = PlaceholderId::Generic(0); - - let (min_query, max_query, _placeholders) = if is_rows_tree_node { - let dummy_min_query_primary = U256::ZERO; //dummy value, circuit will employ only bounds for secondary index - let dummy_max_query_primary = U256::MAX; //dummy value, circuit will employ only bounds for secondary index - let placeholders = Placeholders::from(( - vec![(first_placeholder_id, max_query_value)], - dummy_min_query_primary, - dummy_max_query_primary, - )); - - let query_bounds = QueryBounds::new( - &placeholders, - Some(QueryBoundSource::Constant(min_query_value)), - Some(QueryBoundSource::Placeholder(first_placeholder_id)), - ) - .unwrap(); - ( - QueryBound::new_secondary_index_bound( - &placeholders, - &query_bounds.min_query_secondary, - ) - .unwrap(), - QueryBound::new_secondary_index_bound( - &placeholders, - &query_bounds.max_query_secondary, - ) - .unwrap(), - placeholders, - ) - } else { - // min_query and max_query should be primary index bounds - let placeholders = Placeholders::new_empty(min_query_value, max_query_value); - ( - QueryBound::new_primary_index_bound(&placeholders, true).unwrap(), - QueryBound::new_primary_index_bound(&placeholders, false).unwrap(), - placeholders, - ) - }; - - // Construct the test circuit. 
-        let test_circuit = NonExistenceInterNodeCircuit {
-            is_rows_tree_node,
-            left_child_exists,
-            right_child_exists,
-            min_query: min_query.clone(),
-            max_query: max_query.clone(),
-            value,
-            index_value,
-            left_child_value,
-            left_child_min,
-            left_child_max,
-            index_ids,
-            ops,
-            subtree_hash,
-            computational_hash,
-            placeholder_hash,
-            left_tree_hash: left_child_subtree_hash,
-            left_grand_children,
-            right_child_value,
-            right_child_min,
-            right_child_max,
-            right_tree_hash: right_child_subtree_hash,
-            right_grand_children,
-        };
-
-        // Prove for the test circuit.
-        let proof = run_circuit::<F, C, D, _>(test_circuit);
-        let pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&proof.public_inputs);
-
-        // node_min = left_child_exists ? left_child_min : value
-        // node_max = right_child_exists ? right_child_max : value
-        let node_min = if left_child_exists {
-            left_child_min
-        } else {
-            value
-        };
-        let node_max = if right_child_exists {
-            right_child_max
-        } else {
-            value
-        };
-
-        // Check the public inputs.
-        // Tree hash
-        {
-            let empty_hash = empty_poseidon_hash();
-            let column_id = if is_rows_tree_node {
-                index_ids[1]
-            } else {
-                index_ids[0]
-            };
-
-            // H(h1 || h2 || child_min || child_max || column_id || child_value || child_subtree_hash)
-            let inputs: Vec<_> = left_grand_child_hash1
-                .to_fields()
-                .into_iter()
-                .chain(left_grand_child_hash2.to_fields())
-                .chain(left_child_min.to_fields())
-                .chain(left_child_max.to_fields())
-                .chain(iter::once(column_id))
-                .chain(left_child_value.to_fields())
-                .chain(left_child_subtree_hash.to_fields())
-                .collect();
-            let left_child_hash = H::hash_no_pad(&inputs);
-
-            let left_child_hash = if left_child_exists {
-                left_child_hash
-            } else {
-                *empty_hash
-            };
-
-            // H(h1 || h2 || child_min || child_max || column_id || child_value || child_subtree_hash)
-            let inputs: Vec<_> = right_grand_child_hash1
-                .to_fields()
-                .into_iter()
-                .chain(right_grand_child_hash2.to_fields())
-                .chain(right_child_min.to_fields())
-                .chain(right_child_max.to_fields())
-                .chain(iter::once(column_id))
-                .chain(right_child_value.to_fields())
-                .chain(right_child_subtree_hash.to_fields())
-                .collect();
-            let right_child_hash = H::hash_no_pad(&inputs);
-
-            let right_child_hash = if right_child_exists {
-                right_child_hash
-            } else {
-                *empty_hash
-            };
-
-            // H(left_child_hash || right_child_hash || node_min || node_max || column_id || value || subtree_hash)
-            let inputs: Vec<_> = left_child_hash
-                .to_fields()
-                .into_iter()
-                .chain(right_child_hash.to_fields())
-                .chain(node_min.to_fields())
-                .chain(node_max.to_fields())
-                .chain(iter::once(column_id))
-                .chain(value.to_fields())
-                .chain(subtree_hash.to_fields())
-                .collect();
-            let exp_hash = H::hash_no_pad(&inputs);
-
-            assert_eq!(pi.tree_hash(), exp_hash);
-        }
-        // Output values
-        {
-            let outputs = compute_dummy_output_values(&ops);
-            assert_eq!(pi.to_values_raw(), outputs);
-        }
-        // Count
-        assert_eq!(pi.num_matching_rows(), F::ZERO);
-        // Operation IDs
-        assert_eq!(pi.operation_ids(), ops);
-        // Index value
-        assert_eq!(pi.index_value(), index_value);
-        // Minimum value
-        assert_eq!(pi.min_value(), node_min);
-        // Maximum value
-        assert_eq!(pi.max_value(), node_max);
-        // Index IDs
-        assert_eq!(pi.index_ids(), index_ids);
-        // Minimum query
-        assert_eq!(pi.min_query_value(), min_query_value);
-        // Maximum query
-        assert_eq!(pi.max_query_value(), max_query_value);
-        // Overflow flag
-        assert!(!pi.overflow_flag());
-        // Computational hash
-        {
-            let exp_hash = if is_rows_tree_node {
-
QueryBound::add_secondary_query_bounds_to_computational_hash( - &QueryBoundSource::Constant(min_query_value), - &QueryBoundSource::Placeholder(first_placeholder_id), - &computational_hash, - ) - .unwrap() - } else { - computational_hash - }; - assert_eq!(pi.computational_hash(), exp_hash); - } - // Placeholder hash - { - let exp_hash = if is_rows_tree_node { - QueryBound::add_secondary_query_bounds_to_placeholder_hash( - &min_query, - &max_query, - &placeholder_hash, - ) - } else { - placeholder_hash - }; - - assert_eq!(pi.placeholder_hash(), exp_hash); - } - } - - #[test] - fn test_query_agg_non_existence_inter_for_row_node_and_left_child() { - // Generate the random operations. - let mut ops: [_; MAX_NUM_RESULTS] = random_aggregation_operations(); - - // Set the first operation to ID for testing the digest. - // The condition of the first aggregation operation ID is not associated - // with the `is_rows_tree_node` and `is_left_child` flag. - ops[0] = Identifiers::AggregationOperations(AggregationOperation::IdOp).to_field(); - - test_non_existence_inter_circuit(true, true, false, ops); - } - - #[test] - fn test_query_agg_non_existence_inter_for_row_node_and_right_child() { - // Generate the random operations. - let ops: [_; MAX_NUM_RESULTS] = random_aggregation_operations(); - - test_non_existence_inter_circuit(true, false, true, ops); - } - - #[test] - fn test_query_agg_non_existence_inter_for_index_node_and_left_child() { - // Generate the random operations. - let ops: [_; MAX_NUM_RESULTS] = random_aggregation_operations(); - - test_non_existence_inter_circuit(false, true, false, ops); - } - - #[test] - fn test_query_agg_non_existence_inter_for_index_node_and_right_child() { - // Generate the random operations. - let mut ops: [_; MAX_NUM_RESULTS] = random_aggregation_operations(); - - // Set the first operation to ID for testing the digest. - // The condition of the first aggregation operation ID is not associated - // with the `is_rows_tree_node` and `is_left_child` flag. - ops[0] = Identifiers::AggregationOperations(AggregationOperation::IdOp).to_field(); - - test_non_existence_inter_circuit(false, false, true, ops); - } - - #[test] - fn test_query_agg_non_existence_for_row_tree_leaf_node() { - // Generate the random operations. - let mut ops: [_; MAX_NUM_RESULTS] = random_aggregation_operations(); - - // Set the first operation to ID for testing the digest. - // The condition of the first aggregation operation ID is not associated - // with the `is_rows_tree_node` and `is_left_child` flag. - ops[0] = Identifiers::AggregationOperations(AggregationOperation::IdOp).to_field(); - - test_non_existence_inter_circuit(true, false, false, ops); - } - - #[test] - fn test_query_agg_non_existence_for_index_tree_leaf_node() { - // Generate the random operations. - let ops: [_; MAX_NUM_RESULTS] = random_aggregation_operations(); - - test_non_existence_inter_circuit(false, false, false, ops); - } - - #[test] - fn test_query_agg_non_existence_for_row_tree_full_node() { - // Generate the random operations. - let ops: [_; MAX_NUM_RESULTS] = random_aggregation_operations(); - - test_non_existence_inter_circuit(true, true, true, ops); - } - - #[test] - fn test_query_agg_non_existence_for_index_tree_full_node() { - // Generate the random operations. - let mut ops: [_; MAX_NUM_RESULTS] = random_aggregation_operations(); - - // Set the first operation to ID for testing the digest. 
- // The condition of the first aggregation operation ID is not associated - // with the `is_rows_tree_node` and `is_left_child` flag. - ops[0] = Identifiers::AggregationOperations(AggregationOperation::IdOp).to_field(); - - test_non_existence_inter_circuit(false, true, true, ops); - } -} diff --git a/verifiable-db/src/query/aggregation/partial_node.rs b/verifiable-db/src/query/aggregation/partial_node.rs deleted file mode 100644 index 5e9119e6f..000000000 --- a/verifiable-db/src/query/aggregation/partial_node.rs +++ /dev/null @@ -1,519 +0,0 @@ -//! Module handling the partial node for query aggregation circuits - -use crate::query::{ - aggregation::{output_computation::compute_output_item, utils::constrain_input_proofs}, - public_inputs::PublicInputs, -}; -use alloy::primitives::U256; -use anyhow::Result; -use mp2_common::{ - hash::hash_maybe_first, - poseidon::H, - public_inputs::PublicInputCommon, - serialization::{deserialize, deserialize_array, serialize, serialize_array}, - types::CBuilder, - u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256}, - utils::ToTargets, - D, F, -}; -use plonky2::{ - hash::hash_types::{HashOut, HashOutTarget}, - iop::{ - target::{BoolTarget, Target}, - witness::{PartialWitness, WitnessWrite}, - }, - plonk::proof::ProofWithPublicInputsTarget, -}; -use recursion_framework::circuit_builder::CircuitLogicWires; -use serde::{Deserialize, Serialize}; -use std::{array, iter, slice}; - -/// Partial node wires -/// The constant generic parameter is only used for impl `CircuitLogicWires`. -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct PartialNodeWires { - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - is_rows_tree_node: BoolTarget, - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - is_left_child: BoolTarget, - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - sibling_tree_hash: HashOutTarget, - #[serde( - serialize_with = "serialize_array", - deserialize_with = "deserialize_array" - )] - sibling_child_hashes: [HashOutTarget; 2], - sibling_value: UInt256Target, - sibling_min: UInt256Target, - sibling_max: UInt256Target, - min_query: UInt256Target, - max_query: UInt256Target, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct PartialNodeCircuit { - /// The flag specified if the proof is generated for a node in a rows tree or - /// for a node in the index tree - pub(crate) is_rows_tree_node: bool, - /// The flag indicating if the proven child is the left child or right child - pub(crate) is_left_child: bool, - /// Hash of the rows tree stored in the sibling of the proven child - pub(crate) sibling_tree_hash: HashOut, - /// The child hashes of the proven child's sibling - pub(crate) sibling_child_hashes: [HashOut; 2], - /// Value of the indexed column for the rows tree stored in the sibling of - /// the proven child - pub(crate) sibling_value: U256, - /// Minimum value of the indexed column for the subtree rooted in the sibling - /// of the proven child - pub(crate) sibling_min: U256, - /// Maximum value of the indexed column for the subtree rooted in the sibling - /// of the proven child - pub(crate) sibling_max: U256, - /// Minimum range bound specified in the query for the indexed column - /// It's a range bound for the primary indexed column for index tree, - /// and secondary indexed column for rows tree. 
- pub(crate) min_query: U256, - /// Maximum range bound specified in the query for the indexed column - pub(crate) max_query: U256, -} - -impl PartialNodeCircuit { - pub fn build( - b: &mut CBuilder, - subtree_proof: &PublicInputs, - child_proof: &PublicInputs, - ) -> PartialNodeWires - where - [(); MAX_NUM_RESULTS - 1]:, - { - let ttrue = b._true(); - let zero = b.zero(); - - let is_rows_tree_node = b.add_virtual_bool_target_safe(); - let is_left_child = b.add_virtual_bool_target_unsafe(); - let [sibling_tree_hash, sibling_child_hash1, sibling_child_hash2] = - array::from_fn(|_| b.add_virtual_hash()); - let [sibling_value, sibling_min, sibling_max, min_query, max_query] = - array::from_fn(|_| b.add_virtual_u256_unsafe()); - - // Check the consistency for the subtree proof and child proof. - constrain_input_proofs( - b, - is_rows_tree_node, - &min_query, - &max_query, - subtree_proof, - slice::from_ref(child_proof), - ); - - // Check that the subtree rooted in sibling node contains only leaves with - // indexed columns values outside the query range. - // If the proved child is the left child, ensure sibling_min > MAX_query, - // otherwise sibling_max < MIN_query. - let is_greater_than_max = b.is_less_than_u256(&max_query, &sibling_min); - let is_less_than_min = b.is_less_than_u256(&sibling_max, &min_query); - let is_out_of_range = b.select( - is_left_child, - is_greater_than_max.target, - is_less_than_min.target, - ); - b.connect(is_out_of_range, ttrue.target); - - // Choose the column ID and node value to be hashed depending on which tree - // the current node belongs to. - let index_ids = subtree_proof.index_ids_target(); - let column_id = b.select(is_rows_tree_node, index_ids[1], index_ids[0]); - let index_value = subtree_proof.index_value_target(); - let node_value = b.select_u256( - is_rows_tree_node, - &subtree_proof.min_value_target(), - &index_value, - ); - - // Recompute the tree hash for the sibling node: - // H(h1 || h2 || sibling_min || sibling_max || column_id || sibling_value || sibling_tree_hash) - let inputs = sibling_child_hash1 - .to_targets() - .into_iter() - .chain(sibling_child_hash2.to_targets()) - .chain(sibling_min.to_targets()) - .chain(sibling_max.to_targets()) - .chain(iter::once(column_id)) - .chain(sibling_value.to_targets()) - .chain(sibling_tree_hash.to_targets()) - .collect(); - let sibling_hash = b.hash_n_to_hash_no_pad::(inputs); - - // node_min = is_left_child ? child.min : sibling_min - let node_min = b.select_u256(is_left_child, &child_proof.min_value_target(), &sibling_min); - // node_max = is_left_child ? sibling_max : child.max - let node_max = b.select_u256(is_left_child, &sibling_max, &child_proof.max_value_target()); - - // Compute the node hash: - // H(left_child_hash || right_child_hash || node_min || node_max || column_id || node_value || p.H) - let rest: Vec<_> = node_min - .to_targets() - .into_iter() - .chain(node_max.to_targets()) - .chain(iter::once(column_id)) - .chain(node_value.to_targets()) - .chain(subtree_proof.tree_hash_target().to_targets()) - .collect(); - let node_hash = hash_maybe_first( - b, - is_left_child, - sibling_hash.elements, - child_proof.tree_hash_target().elements, - &rest, - ); - - // Aggregate the output values of children and the overflow number. 
- let mut num_overflows = zero; - let mut aggregated_values = vec![]; - for i in 0..MAX_NUM_RESULTS { - let (mut output, overflow) = compute_output_item(b, i, &[subtree_proof, child_proof]); - - aggregated_values.append(&mut output); - num_overflows = b.add(num_overflows, overflow); - } - - // count = p.count + child.count - let count = b.add( - subtree_proof.num_matching_rows_target(), - child_proof.num_matching_rows_target(), - ); - - // overflow = (pC.overflow + pR.overflow + num_overflows) != 0 - let overflow = b.add_many([ - subtree_proof.to_overflow_raw(), - child_proof.to_overflow_raw(), - &num_overflows, - ]); - let overflow = b.is_not_equal(overflow, zero); - - // Register the public inputs. - PublicInputs::<_, MAX_NUM_RESULTS>::new( - &node_hash.to_targets(), - aggregated_values.as_slice(), - &[count], - subtree_proof.to_ops_raw(), - subtree_proof.to_index_value_raw(), - &node_min.to_targets(), - &node_max.to_targets(), - subtree_proof.to_index_ids_raw(), - &min_query.to_targets(), - &max_query.to_targets(), - &[overflow.target], - subtree_proof.to_computational_hash_raw(), - subtree_proof.to_placeholder_hash_raw(), - ) - .register(b); - - let sibling_child_hashes = [sibling_child_hash1, sibling_child_hash2]; - - PartialNodeWires { - is_rows_tree_node, - is_left_child, - sibling_tree_hash, - sibling_child_hashes, - sibling_value, - sibling_min, - sibling_max, - min_query, - max_query, - } - } - - fn assign(&self, pw: &mut PartialWitness, wires: &PartialNodeWires) { - [ - (wires.is_rows_tree_node, self.is_rows_tree_node), - (wires.is_left_child, self.is_left_child), - ] - .iter() - .for_each(|(t, v)| pw.set_bool_target(*t, *v)); - [ - (&wires.sibling_value, self.sibling_value), - (&wires.sibling_min, self.sibling_min), - (&wires.sibling_max, self.sibling_max), - (&wires.min_query, self.min_query), - (&wires.max_query, self.max_query), - ] - .iter() - .for_each(|(t, v)| pw.set_u256_target(t, *v)); - pw.set_hash_target(wires.sibling_tree_hash, self.sibling_tree_hash); - wires - .sibling_child_hashes - .iter() - .zip(self.sibling_child_hashes) - .for_each(|(t, v)| pw.set_hash_target(*t, v)); - } -} - -/// Subtree proof number = 1, child proof number = 1 -pub(crate) const NUM_VERIFIED_PROOFS: usize = 2; - -impl CircuitLogicWires - for PartialNodeWires -where - [(); MAX_NUM_RESULTS - 1]:, -{ - type CircuitBuilderParams = (); - type Inputs = PartialNodeCircuit; - - const NUM_PUBLIC_INPUTS: usize = PublicInputs::::total_len(); - - fn circuit_logic( - builder: &mut CBuilder, - verified_proofs: [&ProofWithPublicInputsTarget; NUM_VERIFIED_PROOFS], - _builder_parameters: Self::CircuitBuilderParams, - ) -> Self { - // The first one is the subtree proof, and the second is the child proof. 
- let [subtree_proof, child_proof] = - verified_proofs.map(|p| PublicInputs::from_slice(&p.public_inputs)); - - Self::Inputs::build(builder, &subtree_proof, &child_proof) - } - - fn assign_input(&self, inputs: Self::Inputs, pw: &mut PartialWitness) -> Result<()> { - inputs.assign(pw, self); - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - query::{ - aggregation::{ - tests::compute_output_item_value, - utils::tests::{unify_child_proof, unify_subtree_proof}, - }, - pi_len, - }, - test_utils::{random_aggregation_operations, random_aggregation_public_inputs}, - }; - use mp2_common::{poseidon::H, utils::ToFields, C}; - use mp2_test::{ - circuit::{run_circuit, UserCircuit}, - utils::gen_random_field_hash, - }; - use plonky2::{iop::witness::WitnessWrite, plonk::config::Hasher}; - use rand::{thread_rng, Rng}; - use std::array; - - const MAX_NUM_RESULTS: usize = 20; - - #[derive(Clone, Debug)] - struct TestPartialNodeCircuit<'a> { - c: PartialNodeCircuit, - subtree_proof: &'a [F], - child_proof: &'a [F], - } - - impl UserCircuit for TestPartialNodeCircuit<'_> { - // Circuit wires + query proof + child proof - type Wires = (PartialNodeWires, Vec, Vec); - - fn build(b: &mut CBuilder) -> Self::Wires { - let proofs = array::from_fn(|_| { - b.add_virtual_target_arr::<{ pi_len::() }>() - .to_vec() - }); - let [subtree_pi, child_pi] = - array::from_fn(|i| PublicInputs::::from_slice(&proofs[i])); - - let wires = PartialNodeCircuit::build(b, &subtree_pi, &child_pi); - - let [subtree_proof, child_proof] = proofs; - - (wires, subtree_proof, child_proof) - } - - fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { - self.c.assign(pw, &wires.0); - pw.set_target_arr(&wires.1, self.subtree_proof); - pw.set_target_arr(&wires.2, self.child_proof); - } - } - - fn test_partial_node_circuit(is_rows_tree_node: bool, is_left_child: bool) { - let min_query = U256::from(100); - let max_query = U256::from(200); - - let [sibling_tree_hash, sibling_child_hash1, sibling_child_hash2] = - array::from_fn(|_| gen_random_field_hash()); - - let mut rng = thread_rng(); - let sibling_value = U256::from_limbs(rng.gen()); - let [sibling_min, sibling_max] = if is_left_child { - // sibling_min > MAX_query - [max_query + U256::from(1), U256::from_limbs(rng.gen())] - } else { - [U256::from_limbs(rng.gen()), min_query - U256::from(1)] - }; - - // Generate the input proofs. - let ops: [_; MAX_NUM_RESULTS] = random_aggregation_operations(); - let [mut subtree_proof, mut child_proof] = random_aggregation_public_inputs(&ops); - unify_subtree_proof::( - &mut subtree_proof, - is_rows_tree_node, - min_query, - max_query, - ); - let subtree_pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&subtree_proof); - unify_child_proof::( - &mut child_proof, - is_rows_tree_node, - min_query, - max_query, - &subtree_pi, - ); - let child_pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&child_proof); - - // Construct the expected public input values. - let index_ids = subtree_pi.index_ids(); - let index_value = subtree_pi.index_value(); - let node_value = if is_rows_tree_node { - subtree_pi.min_value() - } else { - index_value - }; - let [node_min, node_max] = if is_left_child { - [child_pi.min_value(), sibling_max] - } else { - [sibling_min, child_pi.max_value()] - }; - - // Construct the test circuit. 
-        let sibling_child_hashes = [sibling_child_hash1, sibling_child_hash2];
-        let test_circuit = TestPartialNodeCircuit {
-            c: PartialNodeCircuit {
-                is_rows_tree_node,
-                is_left_child,
-                sibling_tree_hash,
-                sibling_child_hashes,
-                sibling_value,
-                sibling_min,
-                sibling_max,
-                min_query,
-                max_query,
-            },
-            subtree_proof: &subtree_proof,
-            child_proof: &child_proof,
-        };
-
-        // Prove for the test circuit.
-        let proof = run_circuit::<F, C, D, _>(test_circuit);
-        let pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&proof.public_inputs);
-
-        // Check the public inputs.
-        // Tree hash
-        {
-            let column_id = if is_rows_tree_node {
-                index_ids[1]
-            } else {
-                index_ids[0]
-            };
-
-            // H(h1 || h2 || sibling_min || sibling_max || column_id || sibling_value || sibling_tree_hash)
-            let inputs: Vec<_> = sibling_child_hash1
-                .to_fields()
-                .into_iter()
-                .chain(sibling_child_hash2.to_fields())
-                .chain(sibling_min.to_fields())
-                .chain(sibling_max.to_fields())
-                .chain(iter::once(column_id))
-                .chain(sibling_value.to_fields())
-                .chain(sibling_tree_hash.to_fields())
-                .collect();
-            let sibling_hash = H::hash_no_pad(&inputs);
-
-            let child_hash = child_pi.tree_hash();
-            let [left_child_hash, right_child_hash] = if is_left_child {
-                [child_hash, sibling_hash]
-            } else {
-                [sibling_hash, child_hash]
-            };
-
-            // H(left_child_hash || right_child_hash || node_min || node_max || column_id || node_value || p.H)
-            let inputs: Vec<_> = left_child_hash
-                .to_fields()
-                .into_iter()
-                .chain(right_child_hash.to_fields())
-                .chain(node_min.to_fields())
-                .chain(node_max.to_fields())
-                .chain(iter::once(column_id))
-                .chain(node_value.to_fields())
-                .chain(subtree_pi.tree_hash().to_fields())
-                .collect();
-            let exp_hash = H::hash_no_pad(&inputs);
-
-            assert_eq!(pi.tree_hash(), exp_hash);
-        }
-        // Output values and overflow flag
-        {
-            let mut num_overflows = 0;
-            let mut aggregated_values = vec![];
-
-            for i in 0..MAX_NUM_RESULTS {
-                let (mut output, overflow) =
-                    compute_output_item_value(i, &[&subtree_pi, &child_pi]);
-
-                aggregated_values.append(&mut output);
-                num_overflows += overflow;
-            }
-
-            assert_eq!(pi.to_values_raw(), aggregated_values);
-            assert_eq!(
-                pi.overflow_flag(),
-                subtree_pi.overflow_flag() || child_pi.overflow_flag() || num_overflows != 0
-            );
-        }
-        // Count
-        assert_eq!(
-            pi.num_matching_rows(),
-            subtree_pi.num_matching_rows() + child_pi.num_matching_rows(),
-        );
-        // Operation IDs
-        assert_eq!(pi.operation_ids(), subtree_pi.operation_ids());
-        // Index value
-        assert_eq!(pi.index_value(), index_value);
-        // Minimum value
-        assert_eq!(pi.min_value(), node_min);
-        // Maximum value
-        assert_eq!(pi.max_value(), node_max);
-        // Index IDs
-        assert_eq!(pi.index_ids(), index_ids);
-        // Minimum query
-        assert_eq!(pi.min_query_value(), min_query);
-        // Maximum query
-        assert_eq!(pi.max_query_value(), max_query);
-        // Computational hash
-        assert_eq!(pi.computational_hash(), subtree_pi.computational_hash());
-        // Placeholder hash
-        assert_eq!(pi.placeholder_hash(), subtree_pi.placeholder_hash());
-    }
-
-    #[test]
-    fn test_query_agg_partial_node_for_row_node_with_left_child() {
-        test_partial_node_circuit(true, true);
-    }
-
-    #[test]
-    fn test_query_agg_partial_node_for_row_node_with_right_child() {
-        test_partial_node_circuit(true, false);
-    }
-
-    #[test]
-    fn test_query_agg_partial_node_for_index_node_with_left_child() {
-        test_partial_node_circuit(false, true);
-    }
-
-    #[test]
-    fn test_query_agg_partial_node_for_index_node_with_right_child() {
-        test_partial_node_circuit(false, false);
-    }
-}
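For reference, the node hash enforced by the partial-node circuit and recomputed in the test above always follows the single layout H(left_child_hash || right_child_hash || node_min || node_max || column_id || node_value || subtree_hash), where column_id switches between the secondary and the primary index identifier depending on whether the node lives in a rows tree or in the index tree. The sketch below reproduces that layout out of circuit; it is illustrative only, assuming plonky2's Poseidon over the Goldilocks field and a U256 encoded as eight little-endian 32-bit limbs (mirroring mp2_common's ToFields), and the helper names are hypothetical, not part of this patch.

use plonky2::field::{goldilocks_field::GoldilocksField, types::Field};
use plonky2::hash::{hash_types::HashOut, poseidon::PoseidonHash};
use plonky2::plonk::config::Hasher;

type F = GoldilocksField;

// Lift a U256, given as 8 little-endian 32-bit limbs, to field elements
// (hypothetical stand-in for mp2_common's `ToFields` on `U256`).
fn u256_to_fields(limbs: [u32; 8]) -> Vec<F> {
    limbs.into_iter().map(F::from_canonical_u32).collect()
}

// H(left_child_hash || right_child_hash || node_min || node_max || column_id || node_value || subtree_hash)
fn node_hash(
    left_child_hash: HashOut<F>,
    right_child_hash: HashOut<F>,
    node_min: [u32; 8],
    node_max: [u32; 8],
    column_id: F,
    node_value: [u32; 8],
    subtree_hash: HashOut<F>,
) -> HashOut<F> {
    let mut inputs = left_child_hash.elements.to_vec();
    inputs.extend(right_child_hash.elements);
    inputs.extend(u256_to_fields(node_min));
    inputs.extend(u256_to_fields(node_max));
    inputs.push(column_id);
    inputs.extend(u256_to_fields(node_value));
    inputs.extend(subtree_hash.elements);
    PoseidonHash::hash_no_pad(&inputs)
}

The same layout is reused one level down for the sibling hash, with the sibling's two child hashes, min/max bounds, value, and rows-tree hash as inputs, which is why the circuit can recompute the parent node from one proven child plus these sibling witnesses.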
diff --git a/verifiable-db/src/query/aggregation/utils.rs b/verifiable-db/src/query/aggregation/utils.rs
deleted file mode 100644
index eff88308f..000000000
--- a/verifiable-db/src/query/aggregation/utils.rs
+++ /dev/null
@@ -1,153 +0,0 @@
-//! Utility functions for query aggregation circuits
-
-use crate::query::public_inputs::PublicInputs;
-use mp2_common::{
-    array::Array,
-    types::CBuilder,
-    u256::{CircuitBuilderU256, UInt256Target},
-    F,
-};
-use plonky2::{
-    field::types::Field,
-    iop::target::{BoolTarget, Target},
-};
-
-/// Check the consistency for the subtree proof and child proofs.
-pub(crate) fn constrain_input_proofs<const MAX_NUM_RESULTS: usize>(
-    b: &mut CBuilder,
-    is_rows_tree_node: BoolTarget,
-    min_query: &UInt256Target,
-    max_query: &UInt256Target,
-    subtree_proof: &PublicInputs<Target, MAX_NUM_RESULTS>,
-    child_proofs: &[PublicInputs<Target, MAX_NUM_RESULTS>],
-) {
-    let ffalse = b._false();
-
-    let index_ids = subtree_proof.index_ids_target();
-    let index_value = subtree_proof.index_value_target();
-
-    // Ensure the proofs in the same rows tree are employing the same value
-    // of the primary indexed column:
-    // is_rows_tree_node == is_rows_tree_node AND p.I == p1.I AND p.I == p2.I ...
-    let is_equals: Vec<_> = child_proofs
-        .iter()
-        .map(|p| b.is_equal_u256(&index_value, &p.index_value_target()))
-        .collect();
-    let is_equal = is_equals
-        .into_iter()
-        .fold(is_rows_tree_node, |acc, is_equal| b.and(acc, is_equal));
-    b.connect(is_equal.target, is_rows_tree_node.target);
-
-    // Ensure the value of the indexed column for all the records stored in the
-    // rows tree found in this node is within the range specified by the query:
-    // NOT(is_rows_tree_node) == NOT(is_row_tree_node) AND p.I >= MIN_query AND p.I <= MAX_query
-    // And assume: is_out_of_range = p.I < MIN_query OR p.I > MAX_query
-    // => (1 - is_rows_tree_node) * is_out_of_range = 0
-    // => is_out_of_range - is_out_of_range * is_rows_tree_node = 0
-    let is_less_than_min = b.is_less_than_u256(&index_value, min_query);
-    let is_greater_than_max = b.is_less_than_u256(max_query, &index_value);
-    let is_out_of_range = b.or(is_less_than_min, is_greater_than_max);
-    let is_false = b.arithmetic(
-        F::NEG_ONE,
-        F::ONE,
-        is_rows_tree_node.target,
-        is_out_of_range.target,
-        is_out_of_range.target,
-    );
-    b.connect(is_false, ffalse.target);
-
-    // p.index_ids == p1.index_ids == p2.index_ids ...
-    let index_ids = Array::from(index_ids);
-    child_proofs
-        .iter()
-        .for_each(|p| index_ids.enforce_equal(b, &Array::from(p.index_ids_target())));
-
-    // p.C == p1.C == p2.C ...
-    let computational_hash = subtree_proof.computational_hash_target();
-    child_proofs
-        .iter()
-        .for_each(|p| b.connect_hashes(computational_hash, p.computational_hash_target()));
-
-    // p.H_p == p1.H_p == p2.H_p = ...
-    let placeholder_hash = subtree_proof.placeholder_hash_target();
-    child_proofs
-        .iter()
-        .for_each(|p| b.connect_hashes(placeholder_hash, p.placeholder_hash_target()));
-
-    // MIN_query = p1.MIN_I == p2.MIN_I ...
-    child_proofs
-        .iter()
-        .for_each(|p| b.enforce_equal_u256(min_query, &p.min_query_target()));
-
-    // MAX_query = p1.MAX_I == p2.MAX_I ...
- child_proofs - .iter() - .for_each(|p| b.enforce_equal_u256(max_query, &p.max_query_target())); - - // if the subtree proof is generated for a rows tree node, - // the query bounds must be same: - // is_row_tree_node = is_row_tree_node AND MIN_query == p.MIN_I AND MAX_query == p.MAX_I - let is_min_query_equal = b.is_equal_u256(min_query, &subtree_proof.min_query_target()); - let is_max_query_equal = b.is_equal_u256(max_query, &subtree_proof.max_query_target()); - let is_equal = b.and(is_min_query_equal, is_max_query_equal); - let is_equal = b.and(is_equal, is_rows_tree_node); - b.connect(is_equal.target, is_rows_tree_node.target); -} - -#[cfg(test)] -pub(crate) mod tests { - use super::*; - use alloy::primitives::U256; - use mp2_common::utils::ToFields; - - /// Assign the subtree proof to make it consistent. - pub(crate) fn unify_subtree_proof( - proof: &mut [F], - is_rows_tree_node: bool, - min_query: U256, - max_query: U256, - ) { - let [index_value_range, min_query_range, max_query_range] = [ - QueryPublicInputs::IndexValue, - QueryPublicInputs::MinQuery, - QueryPublicInputs::MaxQuery, - ] - .map(PublicInputs::::to_range); - - if is_rows_tree_node { - // p.MIN_I == MIN_query AND p.MAX_I == MAX_query - proof[min_query_range].copy_from_slice(&min_query.to_fields()); - proof[max_query_range].copy_from_slice(&max_query.to_fields()); - } else { - // p.I >= MIN_query AND p.I <= MAX_query - let index_value: U256 = (min_query + max_query) >> 1; - proof[index_value_range].copy_from_slice(&index_value.to_fields()); - } - } - - /// Assign the child proof to make it consistent. - pub(crate) fn unify_child_proof( - proof: &mut [F], - is_rows_tree_node: bool, - min_query: U256, - max_query: U256, - subtree_pi: &PublicInputs, - ) { - let [index_value_range, min_query_range, max_query_range] = [ - QueryPublicInputs::IndexValue, - QueryPublicInputs::MinQuery, - QueryPublicInputs::MaxQuery, - ] - .map(PublicInputs::::to_range); - - // child.MIN_I == MIN_query - // child.MAX_I == MAX_query - proof[min_query_range.clone()].copy_from_slice(&min_query.to_fields()); - proof[max_query_range.clone()].copy_from_slice(&max_query.to_fields()); - - if is_rows_tree_node { - // child.I == p.I - proof[index_value_range.clone()].copy_from_slice(subtree_pi.to_index_value_raw()); - } - } -} diff --git a/verifiable-db/src/query/api.rs b/verifiable-db/src/query/api.rs index 322b90042..6eed9daf3 100644 --- a/verifiable-db/src/query/api.rs +++ b/verifiable-db/src/query/api.rs @@ -10,26 +10,14 @@ use recursion_framework::{ }; use serde::{Deserialize, Serialize}; -#[cfg(feature = "batching_circuits")] -use mp2_common::{default_config, poseidon::H}; -#[cfg(feature = "batching_circuits")] -use plonky2::plonk::config::Hasher; -#[cfg(feature = "batching_circuits")] -use recursion_framework::{ - circuit_builder::CircuitWithUniversalVerifierBuilder, - framework::prepare_recursive_circuit_for_circuit_set, -}; - use crate::query::{ - aggregation::{ChildPosition, NodeInfo, QueryBounds, QueryHashNonExistenceCircuits}, - batching::{ - circuits::{ - chunk_aggregation::{ChunkAggregationCircuit, ChunkAggregationInputs, ChunkAggregationWires}, - non_existence::{NonExistenceCircuit, NonExistenceWires}, - row_chunk_processing::{RowChunkProcessingCircuit, RowChunkProcessingWires}, - }, - row_chunk::row_process_gadget::RowProcessingGadgetInputs, + utils::{ChildPosition, NodeInfo, QueryBounds, QueryHashNonExistenceCircuits}, + circuits::{ + chunk_aggregation::{ChunkAggregationCircuit, ChunkAggregationInputs, ChunkAggregationWires}, + 
non_existence::{NonExistenceCircuit, NonExistenceWires},
+            row_chunk_processing::{RowChunkProcessingCircuit, RowChunkProcessingWires},
         },
+        row_chunk_gadgets::row_process_gadget::RowProcessingGadgetInputs,
     computational_hash_ids::{AggregationOperation, ColumnIDs, Identifiers},
     universal_circuit::{
         output_with_aggregation::Circuit as OutputAggCircuit,
diff --git a/verifiable-db/src/query/batching/circuits/api.rs b/verifiable-db/src/query/batching/circuits/api.rs
deleted file mode 100644
index 54832c99d..000000000
--- a/verifiable-db/src/query/batching/circuits/api.rs
+++ /dev/null
@@ -1,756 +0,0 @@
-use std::iter::{repeat, repeat_with};
-
-use anyhow::{ensure, Result};
-
-use itertools::Itertools;
-use mp2_common::{array::ToField, proof::ProofWithVK, types::HashOutput, C, D, F};
-use plonky2::iop::target::Target;
-use recursion_framework::{
-    circuit_builder::CircuitWithUniversalVerifier, framework::RecursiveCircuits,
-};
-use serde::{Deserialize, Serialize};
-
-#[cfg(feature = "batching_circuits")]
-use mp2_common::{default_config, poseidon::H};
-#[cfg(feature = "batching_circuits")]
-use plonky2::plonk::config::Hasher;
-#[cfg(feature = "batching_circuits")]
-use recursion_framework::{
-    circuit_builder::CircuitWithUniversalVerifierBuilder,
-    framework::prepare_recursive_circuit_for_circuit_set,
-};
-
-use crate::query::{
-    aggregation::{ChildPosition, NodeInfo, QueryBounds, QueryHashNonExistenceCircuits},
-    batching::{
-        circuits::chunk_aggregation::ChunkAggregationCircuit, public_inputs::PublicInputs,
-        row_process_gadget::RowProcessingGadgetInputs,
-    },
-    computational_hash_ids::{AggregationOperation, ColumnIDs, Identifiers},
-    universal_circuit::{
-        output_with_aggregation::Circuit as OutputAggCircuit,
-        universal_circuit_inputs::{BasicOperation, Placeholders, ResultStructure, RowCells},
-    },
-};
-
-use super::{
-    chunk_aggregation::{ChunkAggregationInputs, ChunkAggregationWires},
-    non_existence::{NonExistenceCircuit, NonExistenceWires},
-    row_chunk_processing::{RowChunkProcessingCircuit, RowChunkProcessingWires},
-};
-
-#[derive(Debug, Serialize, Deserialize)]
-pub(crate) struct Parameters<
-    const NUM_CHUNKS: usize,
-    const NUM_ROWS: usize,
-    const ROW_TREE_MAX_DEPTH: usize,
-    const INDEX_TREE_MAX_DEPTH: usize,
-    const MAX_NUM_COLUMNS: usize,
-    const MAX_NUM_PREDICATE_OPS: usize,
-    const MAX_NUM_RESULT_OPS: usize,
-    const MAX_NUM_RESULTS: usize,
-> where
-    [(); ROW_TREE_MAX_DEPTH - 1]:,
-    [(); INDEX_TREE_MAX_DEPTH - 1]:,
-    [(); MAX_NUM_RESULTS - 1]:,
-    [(); MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS]:,
-{
-    row_chunk_agg_circuit: CircuitWithUniversalVerifier<
-        F,
-        C,
-        D,
-        0,
-        RowChunkProcessingWires<
-            NUM_ROWS,
-            ROW_TREE_MAX_DEPTH,
-            INDEX_TREE_MAX_DEPTH,
-            MAX_NUM_COLUMNS,
-            MAX_NUM_PREDICATE_OPS,
-            MAX_NUM_RESULT_OPS,
-            MAX_NUM_RESULTS,
-            OutputAggCircuit,
-        >,
-    >,
-    //ToDo: add row_chunk_circuit for queries without aggregation, once we integrate results tree
-    aggregation_circuit: CircuitWithUniversalVerifier<
-        F,
-        C,
-        D,
-        NUM_CHUNKS,
-        ChunkAggregationWires,
-    >,
-    non_existence_circuit: CircuitWithUniversalVerifier<
-        F,
-        C,
-        D,
-        0,
-        NonExistenceWires,
-    >,
-    circuit_set: RecursiveCircuits<F, C, D>,
-}
-
-pub const fn num_io<const MAX_NUM_RESULTS: usize>() -> usize {
-    PublicInputs::<Target, MAX_NUM_RESULTS>::total_len()
-}
-#[cfg(feature = "batching_circuits")]
-impl<
-        const NUM_CHUNKS: usize,
-        const NUM_ROWS: usize,
-        const ROW_TREE_MAX_DEPTH: usize,
-        const INDEX_TREE_MAX_DEPTH: usize,
-        const MAX_NUM_COLUMNS: usize,
-        const MAX_NUM_PREDICATE_OPS: usize,
-        const MAX_NUM_RESULT_OPS: usize,
-        const MAX_NUM_RESULTS: usize,
-    >
-    Parameters<
-        NUM_CHUNKS,
-        NUM_ROWS,
-        ROW_TREE_MAX_DEPTH,
-        INDEX_TREE_MAX_DEPTH,
-        MAX_NUM_COLUMNS,
-        MAX_NUM_PREDICATE_OPS,
-        MAX_NUM_RESULT_OPS,
-        MAX_NUM_RESULTS,
-    >
-where
-    [(); ROW_TREE_MAX_DEPTH - 1]:,
-    [(); INDEX_TREE_MAX_DEPTH - 1]:,
-    [(); MAX_NUM_RESULTS - 1]:,
-    [(); MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS]:,
-    [(); <H as Hasher<F>>::HASH_SIZE]:,
-    [(); num_io::<MAX_NUM_RESULTS>()]:,
-{
-    const CIRCUIT_SET_SIZE: usize = 3;
-
-    pub(crate) fn build() -> Self {
-        let builder =
-            CircuitWithUniversalVerifierBuilder::<F, D, { num_io::<MAX_NUM_RESULTS>() }>::new::<C>(
-                default_config(),
-                Self::CIRCUIT_SET_SIZE,
-            );
-        let row_chunk_agg_circuit = builder.build_circuit(());
-        let aggregation_circuit = builder.build_circuit(());
-        let non_existence_circuit = builder.build_circuit(());
-
-        let circuits = vec![
-            prepare_recursive_circuit_for_circuit_set(&row_chunk_agg_circuit),
-            prepare_recursive_circuit_for_circuit_set(&aggregation_circuit),
-            prepare_recursive_circuit_for_circuit_set(&non_existence_circuit),
-        ];
-        let circuit_set = RecursiveCircuits::new(circuits);
-
-        Self {
-            row_chunk_agg_circuit,
-            aggregation_circuit,
-            non_existence_circuit,
-            circuit_set,
-        }
-    }
-
-    pub(crate) fn generate_proof(
-        &self,
-        input: CircuitInput<
-            NUM_CHUNKS,
-            NUM_ROWS,
-            ROW_TREE_MAX_DEPTH,
-            INDEX_TREE_MAX_DEPTH,
-            MAX_NUM_COLUMNS,
-            MAX_NUM_PREDICATE_OPS,
-            MAX_NUM_RESULT_OPS,
-            MAX_NUM_RESULTS,
-        >,
-    ) -> Result<Vec<u8>> {
-        let proof: ProofWithVK = match input {
-            CircuitInput::RowChunkWithAggregation(row_chunk_processing_circuit) => (
-                self.circuit_set.generate_proof(
-                    &self.row_chunk_agg_circuit,
-                    [],
-                    [],
-                    row_chunk_processing_circuit,
-                )?,
-                self.row_chunk_agg_circuit
-                    .circuit_data()
-                    .verifier_only
-                    .clone(),
-            )
-                .into(),
-            CircuitInput::ChunkAggregation(chunk_aggregation_inputs) => {
-                let ChunkAggregationInputs {
-                    chunk_proofs,
-                    circuit,
-                } = chunk_aggregation_inputs;
-                let input_vd = chunk_proofs
-                    .iter()
-                    .map(|p| p.verifier_data())
-                    .cloned()
-                    .collect_vec();
-                let input_proofs = chunk_proofs.map(|p| p.proof);
-                (
-                    self.circuit_set.generate_proof(
-                        &self.aggregation_circuit,
-                        input_proofs,
-                        input_vd.iter().collect_vec().try_into().unwrap(),
-                        circuit,
-                    )?,
-                    self.aggregation_circuit
-                        .circuit_data()
-                        .verifier_only
-                        .clone(),
-                )
-                    .into()
-            }
-            CircuitInput::NonExistence(non_existence_circuit) => (
-                self.circuit_set.generate_proof(
-                    &self.non_existence_circuit,
-                    [],
-                    [],
-                    non_existence_circuit,
-                )?,
-                self.non_existence_circuit
-                    .circuit_data()
-                    .verifier_only
-                    .clone(),
-            )
-                .into(),
-        };
-        proof.serialize()
-    }
-
-    pub(crate) fn get_circuit_set(&self) -> &RecursiveCircuits<F, C, D> {
-        &self.circuit_set
-    }
-}
-
-#[cfg(feature = "batching_circuits")]
-#[cfg(test)]
-mod tests {
-    use alloy::primitives::U256;
-    use itertools::Itertools;
-    use mp2_common::{
-        array::ToField,
-        proof::ProofWithVK,
-        utils::{FromFields, ToFields},
-        F,
-    };
-    use mp2_test::utils::{gen_random_u256, random_vector};
-    use rand::thread_rng;
-
-    use crate::query::{
-        aggregation::{
-            output_computation::tests::compute_dummy_output_values, tests::aggregate_output_values,
-            ChildPosition, QueryBoundSource, QueryBounds,
-        },
-        batching::{
-            circuits::{
-                api::{CircuitInput, NodePath, RowInput, TreePathInputs},
-                tests::{build_test_tree, compute_output_values_for_row},
-            },
-            public_inputs::PublicInputs,
-            row_chunk::tests::{BoundaryRowData, BoundaryRowNodeInfo},
-        },
-        computational_hash_ids::{
-            AggregationOperation, ColumnIDs, Identifiers, Operation, PlaceholderIdentifier,
-        },
-        merkle_path::tests::{generate_test_tree, NeighborInfo},
-
universal_circuit::{ - universal_circuit_inputs::{ - BasicOperation, ColumnCell, InputOperand, OutputItem, PlaceholderId, Placeholders, - ResultStructure, RowCells, - }, - universal_query_circuit::placeholder_hash, - universal_query_gadget::CurveOrU256, - ComputationalHash, - }, - }; - - use plonky2::{ - field::types::{Field, PrimeField64}, - plonk::config::GenericHashOut, - }; - - use super::Parameters; - - const NUM_CHUNKS: usize = 4; - const NUM_ROWS: usize = 3; - const ROW_TREE_MAX_DEPTH: usize = 10; - const INDEX_TREE_MAX_DEPTH: usize = 15; - const MAX_NUM_COLUMNS: usize = 30; - const MAX_NUM_PREDICATE_OPS: usize = 20; - const MAX_NUM_RESULT_OPS: usize = 30; - const MAX_NUM_RESULTS: usize = 10; - - #[tokio::test] - async fn test_api() { - const NUM_ACTUAL_COLUMNS: usize = 5; - // generate a proof for the following query: - // SELECT AVG(C1/C2), MIN(C1*(C3-4)), MAX(C5%$1), COUNT(C4) FROM T WHERE (C4 > $2 + 4 XOR C3 < C1*C2) AND C2 >= $3*4 AND C2 <= $4 AND C1 >= 2876 AND C1 <= 7894 - let rng = &mut thread_rng(); - let column_ids = random_vector::(NUM_ACTUAL_COLUMNS); - let primary_index = F::from_canonical_u64(column_ids[0]); - let secondary_index = F::from_canonical_u64(column_ids[1]); - let column_ids = ColumnIDs::new(column_ids[0], column_ids[1], column_ids[2..].to_vec()); - - // query bound values - let min_query_primary = U256::from(2876); - let max_query_primary = U256::from(7894); - let min_query_secondary = U256::from(68); - let max_query_secondary = U256::from(9768443); - - // define placeholders - let first_placeholder_id = PlaceholderId::Generic(0); - let second_placeholder_id = PlaceholderIdentifier::Generic(1); - let mut placeholders = Placeholders::new_empty(min_query_primary, max_query_primary); - [first_placeholder_id, second_placeholder_id] - .iter() - .for_each(|id| placeholders.insert(*id, gen_random_u256(rng))); - let third_placeholder_id = PlaceholderId::Generic(2); - // value of $3 is min_secondary/4 - placeholders.insert(third_placeholder_id, min_query_secondary / U256::from(4)); - let fourth_placeholder_id = PlaceholderId::Generic(3); - // $4 is equal to max_secondary - placeholders.insert(fourth_placeholder_id, max_query_secondary); - let bounds = QueryBounds::new( - &placeholders, - Some(QueryBoundSource::Operation(BasicOperation { - first_operand: InputOperand::Placeholder(third_placeholder_id), - second_operand: Some(InputOperand::Constant(U256::from(4))), - op: Operation::MulOp, - })), - Some(QueryBoundSource::Placeholder(fourth_placeholder_id)), - ) - .unwrap(); - - // build predicate_operations - let mut predicate_operations = vec![]; - // C4 > $2 - let placeholder_cmp = BasicOperation { - first_operand: InputOperand::Column(3), - second_operand: Some(InputOperand::Placeholder(second_placeholder_id)), - op: Operation::GreaterThanOp, - }; - predicate_operations.push(placeholder_cmp); - // C1 * C2 - let column_prod = BasicOperation { - first_operand: InputOperand::Column(0), - second_operand: Some(InputOperand::Column(1)), - op: Operation::MulOp, - }; - predicate_operations.push(column_prod); - // C3 < C1*C2 - let cmp_expr = BasicOperation { - first_operand: InputOperand::Column(2), - second_operand: Some(InputOperand::PreviousValue( - BasicOperation::locate_previous_operation(&predicate_operations, &column_prod) - .unwrap(), - )), - op: Operation::LessThanOp, - }; - predicate_operations.push(cmp_expr); - // C4 > $2 XOR C3 < C1*C2 - let xor_expr = BasicOperation { - first_operand: InputOperand::PreviousValue( - 
BasicOperation::locate_previous_operation(&predicate_operations, &placeholder_cmp) - .unwrap(), - ), - second_operand: Some(InputOperand::PreviousValue( - BasicOperation::locate_previous_operation(&predicate_operations, &cmp_expr) - .unwrap(), - )), - op: Operation::XorOp, - }; - predicate_operations.push(xor_expr); - // build operations to compute results - let mut result_operations = vec![]; - // C1/C2 - let column_div = BasicOperation { - first_operand: InputOperand::Column(0), - second_operand: Some(InputOperand::Column(1)), - op: Operation::DivOp, - }; - result_operations.push(column_div); - // C3 - 4 - let constant_sub = BasicOperation { - first_operand: InputOperand::Column(2), - second_operand: Some(InputOperand::Constant(U256::from(4))), - op: Operation::SubOp, - }; - result_operations.push(constant_sub); - // C1*(C3-4) - let prod_expr = BasicOperation { - first_operand: InputOperand::Column(0), - second_operand: Some(InputOperand::PreviousValue( - BasicOperation::locate_previous_operation(&result_operations, &constant_sub) - .unwrap(), - )), - op: Operation::MulOp, - }; - result_operations.push(prod_expr); - // C5 % $1 - let placeholder_mod = BasicOperation { - first_operand: InputOperand::Column(4), - second_operand: Some(InputOperand::Placeholder(first_placeholder_id)), - op: Operation::ModOp, - }; - result_operations.push(placeholder_mod); - let output_items = vec![ - OutputItem::ComputedValue( - BasicOperation::locate_previous_operation(&result_operations, &column_div).unwrap(), - ), - OutputItem::ComputedValue( - BasicOperation::locate_previous_operation(&result_operations, &prod_expr).unwrap(), - ), - OutputItem::ComputedValue( - BasicOperation::locate_previous_operation(&result_operations, &placeholder_mod) - .unwrap(), - ), - OutputItem::Column(3), - ]; - let output_ops: [F; 4] = [ - AggregationOperation::AvgOp.to_field(), - AggregationOperation::MinOp.to_field(), - AggregationOperation::MaxOp.to_field(), - AggregationOperation::CountOp.to_field(), - ]; - - let results = ResultStructure::new_for_query_with_aggregation( - result_operations, - output_items, - output_ops - .iter() - .map(|op| op.to_canonical_u64()) - .collect_vec(), - ) - .unwrap(); - - let [node_0, node_1, node_2] = build_test_tree(&bounds, &column_ids.to_vec()).await; - - let params = Parameters::< - NUM_CHUNKS, - NUM_ROWS, - ROW_TREE_MAX_DEPTH, - INDEX_TREE_MAX_DEPTH, - MAX_NUM_COLUMNS, - MAX_NUM_PREDICATE_OPS, - MAX_NUM_RESULT_OPS, - MAX_NUM_RESULTS, - >::build(); - - let to_row_cells = |values: &[U256]| { - let column_cells = values - .iter() - .zip(column_ids.to_vec().iter()) - .map(|(&value, &id)| ColumnCell { value, id }) - .collect_vec(); - RowCells::new( - column_cells[0].clone(), - column_cells[1].clone(), - column_cells[2..].to_vec(), - ) - }; - - // we split the rows to be proven in chunks: - // - first chunk with rows 1A and 1C - // - second chunk with rows 2B, 2D and 2A - - // prove first chunk - let [node_1a, node_1b, node_1c, node_1d] = node_1 - .rows_tree - .iter() - .map(|n| n.node) - .collect_vec() - .try_into() - .unwrap(); - let path_1a = vec![]; - - let path_1 = vec![]; - let node_1_children = [Some(node_0.node), Some(node_2.node)]; - - let row_path_1a = NodePath::new( - TreePathInputs::new(node_1a, path_1a, [Some(node_1b), Some(node_1c)]), - TreePathInputs::new(node_1.node, path_1.clone(), node_1_children), - ); - - let row_cells_1a = to_row_cells(&node_1.rows_tree[0].values); - let row_1a = RowInput::new(&row_cells_1a, &row_path_1a); - - let path_1c = vec![(node_1a, 
ChildPosition::Right)]; - let row_path_1c = NodePath::new( - TreePathInputs::new(node_1c, path_1c, [None, Some(node_1d)]), - TreePathInputs::new(node_1.node, path_1, node_1_children), - ); - - let row_cells_1c = to_row_cells(&node_1.rows_tree[2].values); - let row_1c = RowInput::new(&row_cells_1c, &row_path_1c); - - let row_chunk_inputs = CircuitInput::new_row_chunks_input( - &[row_1a, row_1c], - &predicate_operations, - &placeholders, - &bounds, - &results, - ) - .unwrap(); - - let expected_placeholder_hash = - if let CircuitInput::RowChunkWithAggregation(input) = &row_chunk_inputs { - let placeholder_hash_ids = input.ids_for_placeholder_hash(); - placeholder_hash(&placeholder_hash_ids, &placeholders, &bounds).unwrap() - } else { - unreachable!() - }; - - let first_chunk_proof = params.generate_proof(row_chunk_inputs).unwrap(); - - // prove second chunk - let [node_2a, node_2b, node_2c, node_2d] = node_2 - .rows_tree - .iter() - .map(|n| n.node) - .collect_vec() - .try_into() - .unwrap(); - let path_2d = vec![ - (node_2b, ChildPosition::Right), - (node_2a, ChildPosition::Left), - ]; - - let path_2 = vec![(node_1.node, ChildPosition::Right)]; - let node_2_children = [None, None]; - let row_path_2d = NodePath::new( - TreePathInputs::new(node_2d, path_2d, [None, None]), - TreePathInputs::new(node_2.node, path_2.clone(), node_2_children), - ); - - let row_cells_2d = to_row_cells(&node_2.rows_tree[3].values); - - let row_2d = RowInput::new(&row_cells_2d, &row_path_2d); - - let path_2b = vec![(node_2a, ChildPosition::Left)]; - let row_path_2b = NodePath::new( - TreePathInputs::new(node_2b, path_2b, [Some(node_2c), Some(node_2d)]), - TreePathInputs::new(node_2.node, path_2.clone(), node_2_children), - ); - - let row_cells_2b = to_row_cells(&node_2.rows_tree[1].values); - - let row_2b = RowInput::new(&row_cells_2b, &row_path_2b); - - let path_2a = vec![]; - let row_path_2a = NodePath::new( - TreePathInputs::new(node_2a, path_2a, [Some(node_2b), None]), - TreePathInputs::new(node_2.node, path_2, node_2_children), - ); - - let row_cells_2a = to_row_cells(&node_2.rows_tree[0].values); - - let row_2a = RowInput::new(&row_cells_2a, &row_path_2a); - - let second_chunk_inputs = CircuitInput::new_row_chunks_input( - &[row_2b, row_2d, row_2a], - &predicate_operations, - &placeholders, - &bounds, - &results, - ) - .unwrap(); - - let second_chunk_proof = params.generate_proof(second_chunk_inputs).unwrap(); - - // now, aggregate the 2 chunks - let aggregation_input = - CircuitInput::new_chunk_aggregation_input(&[first_chunk_proof, second_chunk_proof]) - .unwrap(); - - let final_proof = params.generate_proof(aggregation_input).unwrap(); - - // check public inputs - let proof = ProofWithVK::deserialize(&final_proof).unwrap(); - let pis = PublicInputs::::from_slice(&proof.proof().public_inputs); - - let (predicate_1a, error_1a, output_1a) = compute_output_values_for_row::( - &row_cells_1a, - &predicate_operations, - &results, - &placeholders, - ); - let (predicate_1c, error_1c, output_1c) = compute_output_values_for_row::( - &row_cells_1c, - &predicate_operations, - &results, - &placeholders, - ); - let (predicate_2a, error_2a, output_2a) = compute_output_values_for_row::( - &row_cells_2a, - &predicate_operations, - &results, - &placeholders, - ); - let (predicate_2b, error_2b, output_2b) = compute_output_values_for_row::( - &row_cells_2b, - &predicate_operations, - &results, - &placeholders, - ); - let (predicate_2d, error_2d, output_2d) = compute_output_values_for_row::( - &row_cells_2d, - 
&predicate_operations, - &results, - &placeholders, - ); - - let (expected_outputs, expected_err) = { - let outputs = [output_1a, output_1c, output_2d, output_2b, output_2a]; - let mut num_overflows = 0; - let outputs = output_ops - .into_iter() - .enumerate() - .map(|(i, op)| { - let (out, overflows) = aggregate_output_values(i, &outputs, op); - num_overflows += overflows; - U256::from_fields(CurveOrU256::::from_slice(&out).to_u256_raw()) - }) - .collect_vec(); - (outputs, num_overflows != 0) - }; - - let computational_hash = ComputationalHash::from_bytes( - (&Identifiers::computational_hash_universal_circuit( - &column_ids, - &predicate_operations, - &results, - Some(bounds.min_query_secondary().into()), - Some(bounds.max_query_secondary().into()), - ) - .unwrap()) - .into(), - ); - - let left_boundary_row = { - // left boundary row should correspond to row 1a - // predecessor of node_1a is node_1b, which is not in the path - let predecessor_1a = NeighborInfo::new(node_1b.value, None); - // successor of node_1a is node_1c, which is not in the path - let successor_1a = NeighborInfo::new(node_1c.value, None); - let row_info_1a = BoundaryRowNodeInfo { - end_node_hash: node_1a.compute_node_hash(secondary_index), - predecessor_info: predecessor_1a, - successor_info: successor_1a, - }; - // predecessor of node_1 is node_0, which is not in the path - let predecessor_1 = NeighborInfo::new(node_0.node.value, None); - // successor of node_1 is node_2, which is not in the path - let successor_1 = NeighborInfo::new(node_2.node.value, None); - let index_info_1a = BoundaryRowNodeInfo { - end_node_hash: node_1.node.compute_node_hash(primary_index), - predecessor_info: predecessor_1, - successor_info: successor_1, - }; - BoundaryRowData { - row_node_info: row_info_1a, - index_node_info: index_info_1a, - } - }; - - let right_boundary_row = { - // right boundary row should correspond to row 2a - // predecessor of node_2a is node_2d, which is not in the path - let predecessor_2a = NeighborInfo::new(node_2d.value, None); - // No successor of node_2a - let successor_2a = NeighborInfo::new_dummy_successor(); - let row_info_2a = BoundaryRowNodeInfo { - end_node_hash: node_2a.compute_node_hash(secondary_index), - predecessor_info: predecessor_2a, - successor_info: successor_2a, - }; - // predecessor of node_2 is node_1, which is in the path - let node_1_hash = node_1.node.compute_node_hash(primary_index); - let predecessor_2 = NeighborInfo::new(node_1.node.value, Some(node_1_hash)); - // No successor of node_2 - let successor_2 = NeighborInfo::new_dummy_successor(); - let index_info_2a = BoundaryRowNodeInfo { - end_node_hash: node_2.node.compute_node_hash(primary_index), - predecessor_info: predecessor_2, - successor_info: successor_2, - }; - BoundaryRowData { - row_node_info: row_info_2a, - index_node_info: index_info_2a, - } - }; - - let root = node_1.node.compute_node_hash(primary_index); - assert_eq!(root, pis.tree_hash(),); - assert_eq!(&pis.operation_ids()[..output_ops.len()], &output_ops); - - assert_eq!( - pis.overflow_flag(), - error_1a | error_1c | error_2d | error_2b | error_2a | expected_err - ); - assert_eq!( - pis.num_matching_rows(), - F::from_canonical_u8( - predicate_1a as u8 - + predicate_1c as u8 - + predicate_2b as u8 - + predicate_2d as u8 - + predicate_2a as u8 - ), - ); - assert_eq!(pis.first_value_as_u256(), expected_outputs[0],); - assert_eq!( - expected_outputs[1..], - pis.values()[..expected_outputs.len() - 1], - ); - assert_eq!(pis.to_left_row_raw(), left_boundary_row.to_fields(),); 
- assert_eq!(pis.to_right_row_raw(), right_boundary_row.to_fields(),); - - assert_eq!(pis.min_primary(), min_query_primary); - assert_eq!(pis.max_primary(), max_query_primary); - assert_eq!(pis.min_secondary(), min_query_secondary); - assert_eq!(pis.max_secondary(), max_query_secondary); - assert_eq!(pis.computational_hash(), computational_hash); - assert_eq!(pis.placeholder_hash(), expected_placeholder_hash); - - // generate an index tree with all nodes out side of primary index range to test non-existence circuit API - let [node_a, node_b, _node_c, node_d, node_e, _node_f, _node_g] = generate_test_tree( - primary_index, - Some((max_query_primary + U256::from(1), U256::MAX)), - ); - // we use node_e to prove non-existence - let path_e = vec![ - (node_d, ChildPosition::Left), - (node_b, ChildPosition::Left), - (node_a, ChildPosition::Left), - ]; - let merkle_path_e = TreePathInputs::new(node_e, path_e, [None, None]); - - let input = CircuitInput::new_non_existence_input( - merkle_path_e, - &column_ids, - &predicate_operations, - &results, - &placeholders, - &bounds, - ) - .unwrap(); - - let proof = params.generate_proof(input).unwrap(); - - // check public inputs - let proof = ProofWithVK::deserialize(&proof).unwrap(); - let pis = PublicInputs::::from_slice(&proof.proof().public_inputs); - - let root = node_a.compute_node_hash(primary_index); - assert_eq!(root, pis.tree_hash(),); - assert_eq!(&pis.operation_ids()[..output_ops.len()], &output_ops); - let expected_outputs = compute_dummy_output_values(&pis.operation_ids()); - assert_eq!(pis.to_values_raw(), &expected_outputs,); - assert_eq!(pis.num_matching_rows(), F::ZERO,); - assert!(!pis.overflow_flag()); - assert_eq!(pis.min_primary(), min_query_primary); - assert_eq!(pis.max_primary(), max_query_primary); - assert_eq!(pis.computational_hash(), computational_hash); - assert_eq!(pis.placeholder_hash(), expected_placeholder_hash); - } -} diff --git a/verifiable-db/src/query/batching/mod.rs b/verifiable-db/src/query/batching/mod.rs deleted file mode 100644 index 015763f6b..000000000 --- a/verifiable-db/src/query/batching/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub(crate) mod circuits; -pub(crate) mod row_chunk; diff --git a/verifiable-db/src/query/batching/circuits/chunk_aggregation.rs b/verifiable-db/src/query/circuits/chunk_aggregation.rs similarity index 99% rename from verifiable-db/src/query/batching/circuits/chunk_aggregation.rs rename to verifiable-db/src/query/circuits/chunk_aggregation.rs index 90a9f0d47..98dd02a86 100644 --- a/verifiable-db/src/query/batching/circuits/chunk_aggregation.rs +++ b/verifiable-db/src/query/circuits/chunk_aggregation.rs @@ -23,7 +23,7 @@ use recursion_framework::circuit_builder::CircuitLogicWires; use serde::{Deserialize, Serialize}; use crate::query::{ - batching::row_chunk::aggregate_chunks::aggregate_chunks, pi_len, public_inputs::PublicInputs + row_chunk_gadgets::aggregate_chunks::aggregate_chunks, pi_len, public_inputs::PublicInputs }; #[derive(Clone, Debug, Serialize, Deserialize)] @@ -198,7 +198,7 @@ mod tests { use crate::{ query::{ - aggregation::tests::aggregate_output_values, + utils::tests::aggregate_output_values, public_inputs::PublicInputs, computational_hash_ids::{AggregationOperation, Identifiers}, universal_circuit::universal_query_gadget::OutputValues, diff --git a/verifiable-db/src/query/batching/circuits/mod.rs b/verifiable-db/src/query/circuits/mod.rs similarity index 97% rename from verifiable-db/src/query/batching/circuits/mod.rs rename to verifiable-db/src/query/circuits/mod.rs 
index a6302b080..aa81d4d5c 100644 --- a/verifiable-db/src/query/batching/circuits/mod.rs +++ b/verifiable-db/src/query/circuits/mod.rs @@ -17,16 +17,12 @@ mod tests { }; use rand::thread_rng; - use crate::query::{ - aggregation::{NodeInfo, QueryBounds}, - public_inputs::tests::gen_values_in_range, - computational_hash_ids::AggregationOperation, - merkle_path::tests::build_node, - universal_circuit::{ + use crate::{query::{ + computational_hash_ids::AggregationOperation, merkle_path::tests::build_node, universal_circuit::{ universal_circuit_inputs::{BasicOperation, Placeholders, ResultStructure, RowCells}, universal_query_gadget::OutputValues, - }, - }; + }, utils::{NodeInfo, QueryBounds} + }, test_utils::gen_values_in_range}; /// Data structure employed to represent a node of a rows tree in the tests #[derive(Clone, Debug)] diff --git a/verifiable-db/src/query/batching/circuits/non_existence.rs b/verifiable-db/src/query/circuits/non_existence.rs similarity index 97% rename from verifiable-db/src/query/batching/circuits/non_existence.rs rename to verifiable-db/src/query/circuits/non_existence.rs index e77139bfd..b35a5a1a9 100644 --- a/verifiable-db/src/query/batching/circuits/non_existence.rs +++ b/verifiable-db/src/query/circuits/non_existence.rs @@ -22,7 +22,7 @@ use recursion_framework::circuit_builder::CircuitLogicWires; use serde::{Deserialize, Serialize}; use crate::query::{ - aggregation::{output_computation::compute_dummy_output_targets, QueryBounds}, api::TreePathInputs, batching::row_chunk::{BoundaryRowDataTarget, BoundaryRowNodeInfoTarget}, merkle_path::{ + utils::QueryBounds, output_computation::compute_dummy_output_targets, api::TreePathInputs, row_chunk_gadgets::{BoundaryRowDataTarget, BoundaryRowNodeInfoTarget}, merkle_path::{ MerklePathWithNeighborsGadget, MerklePathWithNeighborsTargetInputs, NeighborInfoTarget, }, pi_len, public_inputs::PublicInputs, universal_circuit::{ ComputationalHash, ComputationalHashTarget, PlaceholderHash, PlaceholderHashTarget, @@ -281,16 +281,9 @@ mod tests { use crate::{ query::{ - aggregation::{ - output_computation::tests::compute_dummy_output_values, ChildPosition, QueryBounds, - }, - api::TreePathInputs, - batching::row_chunk::tests::{BoundaryRowData, BoundaryRowNodeInfo}, - public_inputs::{tests::gen_values_in_range, PublicInputs}, - merkle_path::tests::{generate_test_tree, NeighborInfo}, - universal_circuit::universal_circuit_inputs::Placeholders, + api::TreePathInputs, merkle_path::{tests::generate_test_tree, NeighborInfo}, output_computation::tests::compute_dummy_output_values, public_inputs::PublicInputs, row_chunk_gadgets::{BoundaryRowData, BoundaryRowNodeInfo}, universal_circuit::universal_circuit_inputs::Placeholders, utils::{ChildPosition, QueryBounds} }, - test_utils::random_aggregation_operations, + test_utils::{gen_values_in_range, random_aggregation_operations}, }; use super::{NonExistenceCircuit, NonExistenceWires}; diff --git a/verifiable-db/src/query/batching/circuits/row_chunk_processing.rs b/verifiable-db/src/query/circuits/row_chunk_processing.rs similarity index 99% rename from verifiable-db/src/query/batching/circuits/row_chunk_processing.rs rename to verifiable-db/src/query/circuits/row_chunk_processing.rs index a752ddbe2..22b9a9162 100644 --- a/verifiable-db/src/query/batching/circuits/row_chunk_processing.rs +++ b/verifiable-db/src/query/circuits/row_chunk_processing.rs @@ -10,7 +10,7 @@ use recursion_framework::circuit_builder::CircuitLogicWires; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use 
crate::query::{ - aggregation::QueryBounds, batching::row_chunk:: + utils::QueryBounds, row_chunk_gadgets:: { row_process_gadget::{RowProcessingGadgetInputWires, RowProcessingGadgetInputs}, aggregate_chunks::aggregate_chunks, RowChunkDataTarget, @@ -374,24 +374,22 @@ mod tests { use rand::thread_rng; use crate::query::{ - aggregation::{ + utils::{ tests::aggregate_output_values, ChildPosition, QueryBoundSource, QueryBounds, }, - batching::{ - circuits::{ - row_chunk_processing::RowChunkProcessingCircuit, - tests::{build_test_tree, compute_output_values_for_row}, - }, - row_chunk::{ - tests::{BoundaryRowData, BoundaryRowNodeInfo}, - row_process_gadget::RowProcessingGadgetInputs - }, + circuits::{ + row_chunk_processing::RowChunkProcessingCircuit, + tests::{build_test_tree, compute_output_values_for_row}, + }, + row_chunk_gadgets::{ + BoundaryRowData, BoundaryRowNodeInfo, + row_process_gadget::RowProcessingGadgetInputs }, public_inputs::PublicInputs, computational_hash_ids::{ AggregationOperation, ColumnIDs, Identifiers, Operation, PlaceholderIdentifier, }, - merkle_path::{tests::NeighborInfo, MerklePathWithNeighborsGadget}, + merkle_path::{NeighborInfo, MerklePathWithNeighborsGadget}, universal_circuit::{ output_no_aggregation::Circuit as NoAggOutputCircuit, output_with_aggregation::Circuit as AggOutputCircuit, diff --git a/verifiable-db/src/query/computational_hash_ids.rs b/verifiable-db/src/query/computational_hash_ids.rs index 00446f8b1..ef55f2870 100644 --- a/verifiable-db/src/query/computational_hash_ids.rs +++ b/verifiable-db/src/query/computational_hash_ids.rs @@ -31,7 +31,7 @@ use serde::{Deserialize, Serialize}; use crate::revelation::placeholders_check::placeholder_ids_hash; use super::{ - aggregation::QueryBoundSource, + utils::QueryBoundSource, universal_circuit::{ universal_circuit_inputs::{ BasicOperation, InputOperand, OutputItem, PlaceholderIdsSet, ResultStructure, diff --git a/verifiable-db/src/query/merkle_path.rs b/verifiable-db/src/query/merkle_path.rs index 4445e5915..a1c0a3545 100644 --- a/verifiable-db/src/query/merkle_path.rs +++ b/verifiable-db/src/query/merkle_path.rs @@ -14,20 +14,23 @@ use mp2_common::{ }, types::{CBuilder, HashOutput}, u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256, NUM_LIMBS}, - utils::{FromTargets, HashBuilder, SelectTarget, ToTargets}, + utils::{FromFields, FromTargets, HashBuilder, SelectTarget, ToFields, ToTargets, TryIntoBool}, D, F, }; +use mp2_test::utils::gen_random_field_hash; use plonky2::{ hash::hash_types::{HashOut, HashOutTarget, NUM_HASH_OUT_ELTS}, iop::{ target::{BoolTarget, Target}, witness::{PartialWitness, WitnessWrite}, }, + field::types::Field, plonk::{circuit_builder::CircuitBuilder, config::GenericHashOut}, }; +use rand::Rng; use serde::{Deserialize, Serialize}; -use super::aggregation::{ChildPosition, NodeInfo, NodeInfoTarget}; +use super::utils::{ChildPosition, NodeInfo, NodeInfoTarget}; #[derive(Clone, Debug, Serialize, Deserialize)] /// Input wires for Merkle path verification gadget @@ -688,6 +691,66 @@ where } } +#[derive(Clone, Debug, Eq, PartialEq)] +pub(crate) struct NeighborInfo { + pub(crate) is_found: bool, + pub(crate) is_in_path: bool, + pub(crate) value: U256, + pub(crate) hash: HashOut, +} + +impl FromFields for NeighborInfo { + fn from_fields(t: &[F]) -> Self { + assert!(t.len() >= NeighborInfoTarget::NUM_TARGETS); + Self { + is_found: t[0].try_into_bool().unwrap(), + is_in_path: t[1].try_into_bool().unwrap(), + value: U256::from_fields(&t[2..2 + NUM_LIMBS]), + hash: HashOut::from_vec(t[2 + 
NUM_LIMBS..NeighborInfoTarget::NUM_TARGETS].to_vec()),
+        }
+    }
+}
+
+impl ToFields for NeighborInfo {
+    fn to_fields(&self) -> Vec<F> {
+        [F::from_bool(self.is_found), F::from_bool(self.is_in_path)]
+            .into_iter()
+            .chain(self.value.to_fields())
+            .chain(self.hash.to_fields())
+            .collect()
+    }
+}
+
+impl NeighborInfo {
+    // Initialize `Self` for the predecessor/successor of a node. `value`
+    // must be the value of the predecessor/successor, while `hash` must
+    // be its hash. If `hash` is `None`, it is assumed that the
+    // predecessor/successor is not located in the path of the node
+    pub(crate) fn new(value: U256, hash: Option<HashOut<F>>) -> Self {
+        Self {
+            is_found: true,
+            is_in_path: hash.is_some(),
+            value,
+            hash: hash.unwrap_or(*empty_poseidon_hash()),
+        }
+    }
+    /// Generate random data about the successor/predecessor of a node. The generated
+    /// predecessor/successor must have the `value` provided as input;
+    /// the existence of the generated predecessor/successor depends on the `is_found` input:
+    /// - if `is_found` is `None`, then the existence of the generated predecessor/successor
+    ///   is chosen at random
+    /// - otherwise, the generated predecessor/successor will be marked as found if and only if
+    ///   the flag wrapped by `is_found` is `true`
+    pub(crate) fn sample<R: Rng>(rng: &mut R, value: U256, is_found: Option<bool>) -> Self {
+        NeighborInfo {
+            is_found: is_found.unwrap_or(rng.gen()),
+            is_in_path: rng.gen(),
+            value,
+            hash: gen_random_field_hash(),
+        }
+    }
+}
+
 #[cfg(test)]
 pub(crate) mod tests {
     use std::array;
@@ -696,8 +759,8 @@ pub(crate) mod tests {
     use mp2_common::{
         poseidon::empty_poseidon_hash,
         types::HashOutput,
-        u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256, NUM_LIMBS},
-        utils::{FromFields, FromTargets, ToFields, ToTargets, TryIntoBool},
+        u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256},
+        utils::{FromFields, FromTargets, ToTargets},
         C, D, F,
     };
     use mp2_test::{
         circuit::{run_circuit, UserCircuit},
         utils::{gen_random_field_hash, gen_random_u256},
     };
     use plonky2::{
-        field::types::{Field, Sample},
+        field::types::Sample,
         hash::hash_types::{HashOut, HashOutTarget, NUM_HASH_OUT_ELTS},
         iop::{
             target::Target,
             witness::{PartialWitness, WitnessWrite},
         },
         plonk::{
             circuit_builder::CircuitBuilder, config::GenericHashOut, proof::ProofWithPublicInputs,
         },
     };
-    use rand::{thread_rng, Rng};
+    use rand::thread_rng;

-    use crate::query::aggregation::{ChildPosition, NodeInfo};
+    use crate::query::utils::{ChildPosition, NodeInfo};

     use super::{
-        MerklePathGadget, MerklePathTargetInputs, MerklePathWithNeighborsGadget,
-        MerklePathWithNeighborsTargetInputs, NeighborInfoTarget,
+        MerklePathGadget, MerklePathTargetInputs, MerklePathWithNeighborsGadget, MerklePathWithNeighborsTargetInputs, NeighborInfo, NeighborInfoTarget
     };

     #[derive(Clone, Debug)]
@@ -757,50 +819,7 @@ pub(crate) mod tests {
         }
     }

-    #[derive(Clone, Debug, Eq, PartialEq)]
-    pub(crate) struct NeighborInfo {
-        pub(crate) is_found: bool,
-        pub(crate) is_in_path: bool,
-        pub(crate) value: U256,
-        pub(crate) hash: HashOut<F>,
-    }
-
-    impl FromFields for NeighborInfo {
-        fn from_fields(t: &[F]) -> Self {
-            assert!(t.len() >= NeighborInfoTarget::NUM_TARGETS);
-            Self {
-                is_found: t[0].try_into_bool().unwrap(),
-                is_in_path: t[1].try_into_bool().unwrap(),
-                value: U256::from_fields(&t[2..2 + NUM_LIMBS]),
-                hash: HashOut::from_vec(t[2 + NUM_LIMBS..NeighborInfoTarget::NUM_TARGETS].to_vec()),
-            }
-        }
-    }
-
-    impl ToFields for NeighborInfo {
-        fn to_fields(&self) -> Vec<F> {
-            [F::from_bool(self.is_found), 
F::from_bool(self.is_in_path)] - .into_iter() - .chain(self.value.to_fields()) - .chain(self.hash.to_fields()) - .collect() - } - } - impl NeighborInfo { - // Initialize `Self` for the predecessor/successor of a node. `value` - // must be the value of the predecessor/successor, while `hash` must - // be its hash. If `hash` is `None`, it is assumed that the - // predecessor/successor is not located in the path of the node - pub(crate) fn new(value: U256, hash: Option>) -> Self { - Self { - is_found: true, - is_in_path: hash.is_some(), - value, - hash: hash.unwrap_or(*empty_poseidon_hash()), - } - } - // Initialize `Self` for a node with no predecessor pub(crate) fn new_dummy_predecessor() -> Self { Self { @@ -820,22 +839,6 @@ pub(crate) mod tests { hash: *empty_poseidon_hash(), } } - - /// Generate at random data about the successor/predecessor of a node. The generated - /// predecessor/successor must have the `value` provided as input; - /// the existence of the generated predecessor/successor depends on the `is_found` input: - /// - if `is_found` is `None`, then the existence of the generated predecessor/successor - /// is chosen at random - /// - otherwise, the generated predecessor/successor will be marked as found if and only if - /// the flag wrapped by `is_found` is `true` - pub(crate) fn sample(rng: &mut R, value: U256, is_found: Option) -> Self { - NeighborInfo { - is_found: is_found.unwrap_or(rng.gen()), - is_in_path: rng.gen(), - value, - hash: gen_random_field_hash(), - } - } } #[derive(Clone, Debug)] diff --git a/verifiable-db/src/query/mod.rs b/verifiable-db/src/query/mod.rs index cee9a91b1..4b69d5497 100644 --- a/verifiable-db/src/query/mod.rs +++ b/verifiable-db/src/query/mod.rs @@ -1,13 +1,15 @@ use plonky2::iop::target::Target; use public_inputs::PublicInputs; -pub mod aggregation; pub mod api; -pub mod batching; pub mod computational_hash_ids; pub mod merkle_path; pub mod public_inputs; pub mod universal_circuit; +pub(crate) mod circuits; +pub(crate) mod row_chunk_gadgets; +pub(crate) mod output_computation; +pub mod utils; pub const fn pi_len() -> usize { PublicInputs::::total_len() diff --git a/verifiable-db/src/query/aggregation/output_computation.rs b/verifiable-db/src/query/output_computation.rs similarity index 99% rename from verifiable-db/src/query/aggregation/output_computation.rs rename to verifiable-db/src/query/output_computation.rs index 04b71f2bc..76e126451 100644 --- a/verifiable-db/src/query/aggregation/output_computation.rs +++ b/verifiable-db/src/query/output_computation.rs @@ -156,7 +156,7 @@ pub(crate) mod tests { use super::*; use crate::{ query::{ - aggregation::tests::compute_output_item_value, pi_len, + utils::tests::compute_output_item_value, pi_len, public_inputs::PublicInputs, universal_circuit::universal_query_gadget::CurveOrU256, }, diff --git a/verifiable-db/src/query/public_inputs.rs b/verifiable-db/src/query/public_inputs.rs index d64ef98b1..d3d1736be 100644 --- a/verifiable-db/src/query/public_inputs.rs +++ b/verifiable-db/src/query/public_inputs.rs @@ -16,13 +16,13 @@ use plonky2::{ use plonky2_ecgfp5::{curve::curve::WeierstrassPoint, gadgets::curve::CurveTarget}; use crate::query::{ - aggregation::output_computation::compute_dummy_output_targets, + output_computation::compute_dummy_output_targets, universal_circuit::universal_query_gadget::{ CurveOrU256Target, OutputValues, OutputValuesTarget, UniversalQueryOutputWires, }, }; -use super::batching::row_chunk::{BoundaryRowDataTarget, RowChunkDataTarget}; +use 
super::row_chunk_gadgets::{BoundaryRowDataTarget, RowChunkDataTarget}; /// Query circuits public inputs pub enum QueryPublicInputs { @@ -590,138 +590,20 @@ impl PublicInputsUniversalCircuit<'_, F, S> { #[cfg(test)] pub(crate) mod tests { - use std::array; - - use alloy::primitives::U256; - use itertools::Itertools; - use mp2_common::{array::ToField, public_inputs::PublicInputCommon, utils::ToFields, C, D, F}; + use mp2_common::{public_inputs::PublicInputCommon, utils::ToFields, C, D, F}; use mp2_test::{ circuit::{run_circuit, UserCircuit}, - utils::{gen_random_field_hash, gen_random_u256, random_vector}, + utils::random_vector, }; use plonky2::{ - field::types::{Field, Sample}, iop::{ target::Target, witness::{PartialWitness, WitnessWrite}, }, plonk::circuit_builder::CircuitBuilder, }; - use plonky2_ecgfp5::curve::curve::Point; - use rand::{thread_rng, Rng}; - - use crate::query::{ - aggregation::{QueryBoundSource, QueryBounds}, - batching::row_chunk::tests::BoundaryRowData, - computational_hash_ids::{AggregationOperation, Identifiers}, - universal_circuit::universal_circuit_inputs::Placeholders, - }; - use super::{OutputValues, PublicInputsFactory, PublicInputs, QueryPublicInputs}; - - /// Generate a set of values in a given range ensuring that the i+1-th generated value is - /// bigger than the i-th generated value - pub(crate) fn gen_values_in_range( - rng: &mut R, - lower: U256, - upper: U256, - ) -> [U256; N] { - assert!(upper >= lower, "{upper} is smaller than {lower}"); - let mut prev_value = lower; - array::from_fn(|_| { - let range = (upper - prev_value).checked_add(U256::from(1)); - let gen_value = match range { - Some(range) => prev_value + gen_random_u256(rng) % range, - None => gen_random_u256(rng), - }; - prev_value = gen_value; - gen_value - }) - } - - impl PublicInputsFactory<'_, F, S, UNIVERSAL_CIRCUIT> { - pub(crate) fn sample_from_ops(ops: &[F; S]) -> [Vec; NUM_INPUTS] - where - [(); S - 1]:, - { - let rng = &mut thread_rng(); - - let tree_hash = gen_random_field_hash(); - let computational_hash = gen_random_field_hash(); - let placeholder_hash = gen_random_field_hash(); - let [min_primary, max_primary] = gen_values_in_range(rng, U256::ZERO, U256::MAX); - let [min_secondary, max_secondary] = gen_values_in_range(rng, U256::ZERO, U256::MAX); - - let query_bounds = { - let placeholders = Placeholders::new_empty(min_primary, max_primary); - QueryBounds::new( - &placeholders, - Some(QueryBoundSource::Constant(min_secondary)), - Some(QueryBoundSource::Constant(max_secondary)), - ) - .unwrap() - }; - - let is_first_op_id = - ops[0] == Identifiers::AggregationOperations(AggregationOperation::IdOp).to_field(); - - let mut previous_row: Option = None; - array::from_fn(|_| { - // generate output values - let output_values = if is_first_op_id { - // generate random curve point - OutputValues::::new_outputs_no_aggregation(&Point::sample(rng)) - } else { - let values = (0..S).map(|_| gen_random_u256(rng)).collect_vec(); - OutputValues::::new_aggregation_outputs(&values) - }; - // generate random count and overflow flag - let count = F::from_canonical_u32(rng.gen()); - let overflow = F::from_bool(rng.gen()); - // generate boundary rows - let left_boundary_row = if let Some(row) = &previous_row { - row.sample_consecutive_row(rng, &query_bounds) - } else { - BoundaryRowData::sample(rng, &query_bounds) - }; - let right_boundary_row = BoundaryRowData::sample(rng, &query_bounds); - assert!( - left_boundary_row.index_node_info.predecessor_info.value >= min_primary - && 
left_boundary_row.index_node_info.predecessor_info.value <= max_primary - ); - assert!( - left_boundary_row.index_node_info.successor_info.value >= min_primary - && left_boundary_row.index_node_info.successor_info.value <= max_primary - ); - assert!( - right_boundary_row.index_node_info.predecessor_info.value >= min_primary - && right_boundary_row.index_node_info.predecessor_info.value <= max_primary - ); - assert!( - right_boundary_row.index_node_info.successor_info.value >= min_primary - && right_boundary_row.index_node_info.successor_info.value <= max_primary - ); - previous_row = Some(right_boundary_row.clone()); - - PublicInputs::::new( - &tree_hash.to_fields(), - &output_values.to_fields(), - &[count], - ops, - &left_boundary_row.to_fields(), - &right_boundary_row.to_fields(), - &min_primary.to_fields(), - &max_primary.to_fields(), - &min_secondary.to_fields(), - &max_secondary.to_fields(), - &[overflow], - &computational_hash.to_fields(), - &placeholder_hash.to_fields(), - ) - .to_vec() - }) - } - } + use super::{PublicInputs, QueryPublicInputs}; const S: usize = 10; #[derive(Clone, Debug)] diff --git a/verifiable-db/src/query/batching/row_chunk/aggregate_chunks.rs b/verifiable-db/src/query/row_chunk_gadgets/aggregate_chunks.rs similarity index 98% rename from verifiable-db/src/query/batching/row_chunk/aggregate_chunks.rs rename to verifiable-db/src/query/row_chunk_gadgets/aggregate_chunks.rs index e8d09fa3a..c3a64296b 100644 --- a/verifiable-db/src/query/batching/row_chunk/aggregate_chunks.rs +++ b/verifiable-db/src/query/row_chunk_gadgets/aggregate_chunks.rs @@ -124,16 +124,15 @@ mod tests { use crate::{ query::{ - aggregation::{tests::aggregate_output_values, ChildPosition, NodeInfo}, - batching::row_chunk::{ - tests::{BoundaryRowData, BoundaryRowNodeInfo, RowChunkData}, + utils::{tests::aggregate_output_values, ChildPosition, NodeInfo}, + row_chunk_gadgets::{BoundaryRowData, BoundaryRowNodeInfo, tests::RowChunkData, BoundaryRowDataTarget, BoundaryRowNodeInfoTarget, RowChunkDataTarget, }, public_inputs::PublicInputs, computational_hash_ids::{AggregationOperation, Identifiers}, merkle_path::{ - tests::{build_node, generate_test_tree, NeighborInfo}, - MerklePathWithNeighborsGadget, MerklePathWithNeighborsTargetInputs, + tests::{build_node, generate_test_tree}, + NeighborInfo, MerklePathWithNeighborsGadget, MerklePathWithNeighborsTargetInputs, }, universal_circuit::universal_query_gadget::{ OutputValues, OutputValuesTarget, UniversalQueryOutputWires, diff --git a/verifiable-db/src/query/batching/row_chunk/consecutive_rows.rs b/verifiable-db/src/query/row_chunk_gadgets/consecutive_rows.rs similarity index 99% rename from verifiable-db/src/query/batching/row_chunk/consecutive_rows.rs rename to verifiable-db/src/query/row_chunk_gadgets/consecutive_rows.rs index 6471386ad..800f6c846 100644 --- a/verifiable-db/src/query/batching/row_chunk/consecutive_rows.rs +++ b/verifiable-db/src/query/row_chunk_gadgets/consecutive_rows.rs @@ -235,7 +235,7 @@ mod tests { use rand::thread_rng; use crate::query::{ - aggregation::{ChildPosition, NodeInfo}, + utils::{ChildPosition, NodeInfo}, merkle_path::{ tests::{build_node, generate_test_tree}, MerklePathWithNeighborsGadget, MerklePathWithNeighborsTargetInputs, diff --git a/verifiable-db/src/query/batching/row_chunk/mod.rs b/verifiable-db/src/query/row_chunk_gadgets/mod.rs similarity index 71% rename from verifiable-db/src/query/batching/row_chunk/mod.rs rename to verifiable-db/src/query/row_chunk_gadgets/mod.rs index 9a3628e83..ee1f745e1 100644 --- 
a/verifiable-db/src/query/batching/row_chunk/mod.rs +++ b/verifiable-db/src/query/row_chunk_gadgets/mod.rs @@ -4,20 +4,25 @@ //! the chunk are labelled as the `left_boundary_row` and the `right_boundary_row`, //! respectively, and are the rows employed to aggregate 2 different chunks. +use alloy::primitives::U256; use mp2_common::{ serialization::circuit_data_serialization::SerializableRichField, - utils::{FromTargets, HashBuilder, SelectTarget, ToTargets}, + utils::{FromFields, FromTargets, HashBuilder, SelectTarget, ToFields, ToTargets}, F, }; +use mp2_test::utils::gen_random_field_hash; use plonky2::{ - hash::hash_types::{HashOutTarget, NUM_HASH_OUT_ELTS}, + hash::hash_types::{HashOut, HashOutTarget, NUM_HASH_OUT_ELTS}, iop::target::{BoolTarget, Target}, plonk::circuit_builder::CircuitBuilder, }; +use rand::Rng; -use crate::query::{ +use crate::{query::{ merkle_path::{MerklePathWithNeighborsTarget, NeighborInfoTarget}, universal_circuit::universal_query_gadget::UniversalQueryOutputWires, -}; +}, test_utils::gen_values_in_range}; + +use super::{merkle_path::NeighborInfo, utils::QueryBounds}; /// This module contains gadgets to aggregate 2 different row chunks pub(crate) mod aggregate_chunks; @@ -208,38 +213,43 @@ where } } -#[cfg(test)] -pub(crate) mod tests { - use alloy::primitives::U256; - use mp2_common::{ - utils::{FromFields, FromTargets, ToFields}, - F, - }; - use mp2_test::utils::gen_random_field_hash; - use plonky2::{ - field::types::Field, - hash::hash_types::{HashOut, NUM_HASH_OUT_ELTS}, - }; - use rand::Rng; - - use crate::query::{ - aggregation::QueryBounds, - batching::row_chunk::BoundaryRowDataTarget, - merkle_path::{tests::NeighborInfo, NeighborInfoTarget}, - universal_circuit::universal_query_gadget::OutputValues, - public_inputs::tests::gen_values_in_range, - }; - use super::BoundaryRowNodeInfoTarget; +#[derive(Clone, Debug)] +pub(crate) struct BoundaryRowNodeInfo { + pub(crate) end_node_hash: HashOut, + pub(crate) predecessor_info: NeighborInfo, + pub(crate) successor_info: NeighborInfo, +} - #[derive(Clone, Debug)] - pub(crate) struct BoundaryRowNodeInfo { - pub(crate) end_node_hash: HashOut, - pub(crate) predecessor_info: NeighborInfo, - pub(crate) successor_info: NeighborInfo, +impl ToFields for BoundaryRowNodeInfo { + fn to_fields(&self) -> Vec { + self.end_node_hash + .to_fields() + .into_iter() + .chain(self.predecessor_info.to_fields()) + .chain(self.successor_info.to_fields()) + .collect() + } +} + +impl FromFields for BoundaryRowNodeInfo { + fn from_fields(t: &[F]) -> Self { + assert!(t.len() >= BoundaryRowNodeInfoTarget::NUM_TARGETS); + let end_node_hash = HashOut::from_partial(&t[..NUM_HASH_OUT_ELTS]); + let predecessor_info = NeighborInfo::from_fields(&t[NUM_HASH_OUT_ELTS..]); + let successor_info = NeighborInfo::from_fields( + &t[NUM_HASH_OUT_ELTS + NeighborInfoTarget::NUM_TARGETS..], + ); + + Self { + end_node_hash, + predecessor_info, + successor_info, + } } +} - impl BoundaryRowNodeInfo { +impl BoundaryRowNodeInfo { /// Generate an instance of `Self` representing a random node, given the `query_bounds` /// provided as input and a flag `is_index_tree` specifying whether the random node /// should be part of an index tree or of a rows tree. 
It is used to generate test data @@ -369,132 +379,121 @@ pub(crate) mod tests { successor_info, } } - } +} - impl ToFields for BoundaryRowNodeInfo { - fn to_fields(&self) -> Vec { - self.end_node_hash - .to_fields() - .into_iter() - .chain(self.predecessor_info.to_fields()) - .chain(self.successor_info.to_fields()) - .collect() - } +#[derive(Clone, Debug)] +pub(crate) struct BoundaryRowData { + pub(crate) row_node_info: BoundaryRowNodeInfo, + pub(crate) index_node_info: BoundaryRowNodeInfo, +} + +impl ToFields for BoundaryRowData { + fn to_fields(&self) -> Vec { + self.row_node_info + .to_fields() + .into_iter() + .chain(self.index_node_info.to_fields()) + .collect() } +} - impl FromFields for BoundaryRowNodeInfo { - fn from_fields(t: &[F]) -> Self { - assert!(t.len() >= BoundaryRowNodeInfoTarget::NUM_TARGETS); - let end_node_hash = HashOut::from_partial(&t[..NUM_HASH_OUT_ELTS]); - let predecessor_info = NeighborInfo::from_fields(&t[NUM_HASH_OUT_ELTS..]); - let successor_info = NeighborInfo::from_fields( - &t[NUM_HASH_OUT_ELTS + NeighborInfoTarget::NUM_TARGETS..], - ); +impl FromFields for BoundaryRowData { + fn from_fields(t: &[F]) -> Self { + assert!(t.len() >= BoundaryRowDataTarget::NUM_TARGETS); + let row_node_info = BoundaryRowNodeInfo::from_fields(t); + let index_node_info = + BoundaryRowNodeInfo::from_fields(&t[BoundaryRowNodeInfoTarget::NUM_TARGETS..]); - Self { - end_node_hash, - predecessor_info, - successor_info, - } + Self { + row_node_info, + index_node_info, } } - #[derive(Clone, Debug)] - pub(crate) struct BoundaryRowData { - pub(crate) row_node_info: BoundaryRowNodeInfo, - pub(crate) index_node_info: BoundaryRowNodeInfo, - } +} - impl ToFields for BoundaryRowData { - fn to_fields(&self) -> Vec { - self.row_node_info - .to_fields() - .into_iter() - .chain(self.index_node_info.to_fields()) - .collect() +impl BoundaryRowData { + /// Generate a random instance of `Self`, given the `query_bounds` provided as inputs. + /// It is employed to generate test data without the need to build an actual test tree + pub(crate) fn sample(rng: &mut R, query_bounds: &QueryBounds) -> Self { + Self { + row_node_info: BoundaryRowNodeInfo::sample(rng, query_bounds, false), + index_node_info: BoundaryRowNodeInfo::sample(rng, query_bounds, true), } } - impl FromFields for BoundaryRowData { - fn from_fields(t: &[F]) -> Self { - assert!(t.len() >= BoundaryRowDataTarget::NUM_TARGETS); - let row_node_info = BoundaryRowNodeInfo::from_fields(t); + /// Given the boundary row `self`, generates at random the data of the consecutive row of + /// `self`, given the `query_bounds` provided as input. 
It is employed to generate test data + /// without the need to build an actual test tree + pub(crate) fn sample_consecutive_row( + &self, + rng: &mut R, + query_bounds: &QueryBounds, + ) -> Self { + if self.row_node_info.successor_info.is_found + && self.row_node_info.successor_info.value + <= *query_bounds.max_query_secondary().value() + { + // the successor must be in the same rows tree + let row_node_info = + self.row_node_info + .sample_successor_in_tree(rng, query_bounds, false); + Self { + row_node_info, + index_node_info: self.index_node_info.clone(), + } + } else { + // the successor must be in a different rows tree + let end_node_hash = gen_random_field_hash(); + // predecessor value must be out of range in this case + let [predecessor_value] = gen_values_in_range( + rng, + U256::ZERO, + query_bounds + .min_query_secondary() + .value() + .checked_sub(U256::from(1)) + .unwrap_or(U256::ZERO), + ); + let predecessor_info = NeighborInfo::sample(rng, predecessor_value, None); + let [successor_value] = gen_values_in_range( + rng, + predecessor_value.max(*query_bounds.min_query_secondary().value()), // successor value must + // always be greater than min_secondary in circuit + U256::MAX, + ); + let successor_info = NeighborInfo::sample(rng, successor_value, None); + let row_node_info = BoundaryRowNodeInfo { + end_node_hash, + predecessor_info, + successor_info, + }; + // index tree node must be a successor of `self.index_node` let index_node_info = - BoundaryRowNodeInfo::from_fields(&t[BoundaryRowNodeInfoTarget::NUM_TARGETS..]); - + self.index_node_info + .sample_successor_in_tree(rng, query_bounds, true); Self { row_node_info, index_node_info, } } } +} - impl BoundaryRowData { - /// Generate a random instance of `Self`, given the `query_bounds` provided as inputs. - /// It is employed to generate test data without the need to build an actual test tree - pub(crate) fn sample(rng: &mut R, query_bounds: &QueryBounds) -> Self { - Self { - row_node_info: BoundaryRowNodeInfo::sample(rng, query_bounds, false), - index_node_info: BoundaryRowNodeInfo::sample(rng, query_bounds, true), - } - } +#[cfg(test)] +pub(crate) mod tests { + use mp2_common::{ + utils::ToFields, + F, + }; + use plonky2::{ + field::types::Field, + hash::hash_types::HashOut, + }; + + use crate::query::universal_circuit::universal_query_gadget::OutputValues; + + use super::BoundaryRowData; - /// Given the boundary row `self`, generates at random the data of the consecutive row of - /// `self`, given the `query_bounds` provided as input. 
It is employed to generate test data - /// without the need to build an actual test tree - pub(crate) fn sample_consecutive_row( - &self, - rng: &mut R, - query_bounds: &QueryBounds, - ) -> Self { - if self.row_node_info.successor_info.is_found - && self.row_node_info.successor_info.value - <= *query_bounds.max_query_secondary().value() - { - // the successor must be in the same rows tree - let row_node_info = - self.row_node_info - .sample_successor_in_tree(rng, query_bounds, false); - Self { - row_node_info, - index_node_info: self.index_node_info.clone(), - } - } else { - // the successor must be in a different rows tree - let end_node_hash = gen_random_field_hash(); - // predecessor value must be out of range in this case - let [predecessor_value] = gen_values_in_range( - rng, - U256::ZERO, - query_bounds - .min_query_secondary() - .value() - .checked_sub(U256::from(1)) - .unwrap_or(U256::ZERO), - ); - let predecessor_info = NeighborInfo::sample(rng, predecessor_value, None); - let [successor_value] = gen_values_in_range( - rng, - predecessor_value.max(*query_bounds.min_query_secondary().value()), // successor value must - // always be greater than min_secondary in circuit - U256::MAX, - ); - let successor_info = NeighborInfo::sample(rng, successor_value, None); - let row_node_info = BoundaryRowNodeInfo { - end_node_hash, - predecessor_info, - successor_info, - }; - // index tree node must be a successor of `self.index_node` - let index_node_info = - self.index_node_info - .sample_successor_in_tree(rng, query_bounds, true); - Self { - row_node_info, - index_node_info, - } - } - } - } #[derive(Clone, Debug)] pub(crate) struct RowChunkData where diff --git a/verifiable-db/src/query/batching/row_chunk/row_process_gadget.rs b/verifiable-db/src/query/row_chunk_gadgets/row_process_gadget.rs similarity index 100% rename from verifiable-db/src/query/batching/row_chunk/row_process_gadget.rs rename to verifiable-db/src/query/row_chunk_gadgets/row_process_gadget.rs diff --git a/verifiable-db/src/query/universal_circuit/universal_query_circuit.rs b/verifiable-db/src/query/universal_circuit/universal_query_circuit.rs index 46fbd3f89..b56aaf68d 100644 --- a/verifiable-db/src/query/universal_circuit/universal_query_circuit.rs +++ b/verifiable-db/src/query/universal_circuit/universal_query_circuit.rs @@ -1,7 +1,7 @@ use std::iter::once; use crate::query::{ - aggregation::QueryBounds, public_inputs::PublicInputsUniversalCircuit, batching::row_chunk::BoundaryRowDataTarget, computational_hash_ids::{Output, PlaceholderIdentifier}, pi_len + computational_hash_ids::{Output, PlaceholderIdentifier}, pi_len, public_inputs::PublicInputsUniversalCircuit, row_chunk_gadgets::BoundaryRowDataTarget, utils::QueryBounds }; use anyhow::Result; use itertools::Itertools; @@ -311,7 +311,7 @@ where } #[derive(Debug, Serialize, Deserialize)] -pub(crate) struct UniversalQueryCircuitParams< +pub struct UniversalQueryCircuitParams< const MAX_NUM_COLUMNS: usize, const MAX_NUM_PREDICATE_OPS: usize, const MAX_NUM_RESULT_OPS: usize, @@ -499,7 +499,7 @@ mod tests { use rand::{thread_rng, Rng}; use crate::query::{ - aggregation::{QueryBoundSource, QueryBounds}, + utils::{QueryBoundSource, QueryBounds}, computational_hash_ids::{ AggregationOperation, ColumnIDs, HashPermutation, Identifiers, Operation, PlaceholderIdentifier, diff --git a/verifiable-db/src/query/universal_circuit/universal_query_gadget.rs b/verifiable-db/src/query/universal_circuit/universal_query_gadget.rs index 0fd2a9988..ef3df751e 100644 --- 
a/verifiable-db/src/query/universal_circuit/universal_query_gadget.rs +++ b/verifiable-db/src/query/universal_circuit/universal_query_gadget.rs @@ -28,7 +28,7 @@ use plonky2_ecgfp5::{curve::curve::WeierstrassPoint, gadgets::curve::CurveTarget use serde::{Deserialize, Serialize}; use crate::query::{ - aggregation::{QueryBoundSecondary, QueryBoundSource, QueryBounds}, + utils::{QueryBoundSecondary, QueryBoundSource, QueryBounds}, computational_hash_ids::{ ColumnIDs, ComputationalHashCache, HashPermutation, Operation, Output, PlaceholderIdentifier, @@ -1094,7 +1094,6 @@ impl OutputValues where [(); MAX_NUM_RESULTS - 1]:, { - #[cfg(test)] // used only in test for now pub(crate) fn new_aggregation_outputs(values: &[U256]) -> Self { let first_output = CurveOrU256::::from_slice(&values[0].to_fields()); let other_outputs = values[1..] @@ -1110,7 +1109,6 @@ where } } - #[cfg(test)] // used only in test for now pub(crate) fn new_outputs_no_aggregation(point: &plonky2_ecgfp5::curve::curve::Point) -> Self { let first_output = CurveOrU256::::from_slice(&point.to_fields()); Self { diff --git a/verifiable-db/src/query/aggregation/mod.rs b/verifiable-db/src/query/utils.rs similarity index 81% rename from verifiable-db/src/query/aggregation/mod.rs rename to verifiable-db/src/query/utils.rs index 3afa18042..bf85c0c93 100644 --- a/verifiable-db/src/query/aggregation/mod.rs +++ b/verifiable-db/src/query/utils.rs @@ -5,7 +5,6 @@ use anyhow::Result; use itertools::Itertools; use mp2_common::{ poseidon::{empty_poseidon_hash, HashPermutation}, - proof::ProofWithVK, serialization::{ deserialize, deserialize_array, deserialize_long_array, serialize, serialize_array, serialize_long_array, @@ -28,7 +27,6 @@ use plonky2::{ }; use serde::{Deserialize, Serialize}; -pub(crate) mod output_computation; use super::{ computational_hash_ids::{ColumnIDs, Identifiers, PlaceholderIdentifier}, @@ -307,119 +305,6 @@ pub enum ChildPosition { Right, } -impl ChildPosition { - // convert `self` to a flag specifying whether a node is the left child of another node or not - pub(crate) fn to_flag(self) -> bool { - match self { - ChildPosition::Left => true, - ChildPosition::Right => false, - } - } -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub(crate) struct CommonInputs { - pub(crate) is_rows_tree_node: bool, - pub(crate) min_query: U256, - pub(crate) max_query: U256, -} - -impl CommonInputs { - pub(crate) fn new(is_rows_tree_node: bool, query_bounds: &QueryBounds) -> Self { - Self { - is_rows_tree_node, - min_query: if is_rows_tree_node { - query_bounds.min_query_secondary.value - } else { - query_bounds.min_query_primary - }, - max_query: if is_rows_tree_node { - query_bounds.max_query_secondary.value - } else { - query_bounds.max_query_primary - }, - } - } -} -/// Input data structure for circuits employed for nodes where both the children and the embedded tree are proven -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct TwoProvenChildNodeInput { - /// Proof for the left child of the node being proven - pub(crate) left_child_proof: ProofWithVK, - /// Proof for the right child of the node being proven - pub(crate) right_child_proof: ProofWithVK, - /// Proof for the embedded tree stored in the current node - pub(crate) embedded_tree_proof: ProofWithVK, - /// Common inputs shared across all the circuits - pub(crate) common: CommonInputs, -} -/// Input data structure for circuits employed for nodes where one child and the embedded tree are proven -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct 
OneProvenChildNodeInput { - /// Data related to the child not associated with a proof, if any - pub(crate) unproven_child: Option, - /// Proof for the proven child - pub(crate) proven_child_proof: ChildProof, - /// Proof for the embedded tree stored in the current node - pub(crate) embedded_tree_proof: ProofWithVK, - /// Common inputs shared across all the circuits - pub(crate) common: CommonInputs, -} -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -/// Data structure representing a proof for a child node -pub struct ChildProof { - /// Actual proof - pub(crate) proof: ProofWithVK, - /// Flag specifying whether the child associated with `proof` is the left or right child of its parent - pub(crate) child_position: ChildPosition, -} - -impl ChildProof { - pub fn new(proof: Vec, child_position: ChildPosition) -> Result { - Ok(Self { - proof: ProofWithVK::deserialize(&proof)?, - child_position, - }) - } -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -/// Enum employed to specify whether a proof refers to a child node or the embedded tree stored in a node -pub enum SubProof { - /// Proof refer to a child - Child(ChildProof), - /// Proof refer to the embedded tree stored in the node: can be either the proof for a single row - /// (if proving a rows tree node) of the proof for the root node of a rows tree (if proving an index tree node) - Embedded(ProofWithVK), -} - -impl SubProof { - /// Initialize a new `SubProof::Child` - pub fn new_child_proof(proof: Vec, child_position: ChildPosition) -> Result { - Ok(SubProof::Child(ChildProof::new(proof, child_position)?)) - } - - /// Initialize a new `SubProof::Embedded` - pub fn new_embedded_tree_proof(proof: Vec) -> Result { - Ok(SubProof::Embedded(ProofWithVK::deserialize(&proof)?)) - } -} - -/// Input data structure for circuits employed for nodes where only one among children node and embedded tree is proven -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct SinglePathInput { - /// Data about the left child of the node being proven, if any - pub(crate) left_child: Option, - /// Data about the right child of the node being proven, if any - pub(crate) right_child: Option, - /// Data about the node being proven - pub(crate) node_info: NodeInfo, - /// Proof of either a child node or of the embedded tree stored in the current node - pub(crate) subtree_proof: SubProof, - /// Common inputs shared across all the circuits - pub(crate) common: CommonInputs, -} - /// Data structure containing the computational hash and placeholder hash to be provided as input to /// non-existence circuits. 
These hashes are computed from the query specific data provided as input /// to the initialization method of this data structure diff --git a/verifiable-db/src/revelation/api.rs b/verifiable-db/src/revelation/api.rs index 04cbd27fb..ec2264d47 100644 --- a/verifiable-db/src/revelation/api.rs +++ b/verifiable-db/src/revelation/api.rs @@ -24,7 +24,7 @@ use serde::{Deserialize, Serialize}; use crate::{ query::{ - aggregation::QueryBounds, computational_hash_ids::ColumnIDs, pi_len as query_pi_len, universal_circuit::{output_no_aggregation::Circuit as OutputNoAggCircuit, universal_circuit_inputs::{ + utils::QueryBounds, computational_hash_ids::ColumnIDs, pi_len as query_pi_len, universal_circuit::{output_no_aggregation::Circuit as OutputNoAggCircuit, universal_circuit_inputs::{ BasicOperation, Placeholders, ResultStructure, }, universal_query_circuit::{UniversalCircuitInput, UniversalQueryCircuitParams}} }, @@ -301,7 +301,6 @@ where placeholder_hash_ids, )?, }; - println!("{:?}", revelation_circuit); Ok(CircuitInput::NoResultsTree { query_proof, @@ -477,7 +476,7 @@ where } } - pub(crate) fn generate_proof( + pub fn generate_proof( &self, input: CircuitInput< ROW_TREE_MAX_DEPTH, diff --git a/verifiable-db/src/revelation/placeholders_check.rs b/verifiable-db/src/revelation/placeholders_check.rs index 83c393d4e..e5df6a249 100644 --- a/verifiable-db/src/revelation/placeholders_check.rs +++ b/verifiable-db/src/revelation/placeholders_check.rs @@ -2,7 +2,7 @@ //! compute and return the `num_placeholders` and the `placeholder_ids_hash`. use crate::query::{ - aggregation::QueryBounds, + utils::QueryBounds, computational_hash_ids::PlaceholderIdentifier, universal_circuit::{ universal_circuit_inputs::{PlaceholderId, Placeholders}, diff --git a/verifiable-db/src/revelation/revelation_unproven_offset.rs b/verifiable-db/src/revelation/revelation_unproven_offset.rs index e1c68c5ab..55ab42c0a 100644 --- a/verifiable-db/src/revelation/revelation_unproven_offset.rs +++ b/verifiable-db/src/revelation/revelation_unproven_offset.rs @@ -49,7 +49,7 @@ use serde::{Deserialize, Serialize}; use crate::{ ivc::PublicInputs as OriginalTreePublicInputs, query::{ - aggregation::{ChildPosition, NodeInfo, QueryBounds}, public_inputs::PublicInputsUniversalCircuit as QueryProofPublicInputs, computational_hash_ids::{ColumnIDs, ResultIdentifier}, merkle_path::{MerklePathGadget, MerklePathTargetInputs}, universal_circuit::{ + utils::{ChildPosition, NodeInfo, QueryBounds}, public_inputs::PublicInputsUniversalCircuit as QueryProofPublicInputs, computational_hash_ids::{ColumnIDs, ResultIdentifier}, merkle_path::{MerklePathGadget, MerklePathTargetInputs}, universal_circuit::{ build_cells_tree, universal_circuit_inputs::{BasicOperation, ColumnCell, Placeholders, ResultStructure, RowCells}, universal_query_circuit::{UniversalCircuitInput, UniversalQueryCircuitInputs}, } @@ -844,7 +844,7 @@ mod tests { PublicInputs as OriginalTreePublicInputs, }, query::{ - aggregation::{ChildPosition, NodeInfo}, + utils::{ChildPosition, NodeInfo}, public_inputs::{PublicInputsUniversalCircuit as QueryProofPublicInputs, QueryPublicInputsUniversalCircuit}, pi_len as query_pi_len, }, revelation::{ diff --git a/verifiable-db/src/revelation/revelation_without_results_tree.rs b/verifiable-db/src/revelation/revelation_without_results_tree.rs index c37e8c0c1..282a9d75c 100644 --- a/verifiable-db/src/revelation/revelation_without_results_tree.rs +++ b/verifiable-db/src/revelation/revelation_without_results_tree.rs @@ -393,7 +393,7 @@ mod tests { use crate::{ 
        ivc::PublicInputs as OriginalTreePublicInputs,
         query::{
-            aggregation::{QueryBoundSource, QueryBounds},
+            utils::{QueryBoundSource, QueryBounds},
             public_inputs::{
                 PublicInputs as QueryProofPublicInputs, QueryPublicInputs,
diff --git a/verifiable-db/src/test_utils.rs b/verifiable-db/src/test_utils.rs
index 3fa940683..5b508b3e1 100644
--- a/verifiable-db/src/test_utils.rs
+++ b/verifiable-db/src/test_utils.rs
@@ -3,12 +3,11 @@ use crate::{
     ivc::public_inputs::H_RANGE as ORIGINAL_TREE_H_RANGE,
     query::{
-        aggregation::{QueryBounds, QueryHashNonExistenceCircuits}, batching::row_chunk::tests::BoundaryRowData, computational_hash_ids::{
+        computational_hash_ids::{
             AggregationOperation, ColumnIDs, Identifiers, Operation, PlaceholderIdentifier,
-        }, universal_circuit::universal_circuit_inputs::{
+        }, public_inputs::{PublicInputs as QueryPI, PublicInputsFactory, QueryPublicInputs}, row_chunk_gadgets::BoundaryRowData, universal_circuit::{universal_circuit_inputs::{
             BasicOperation, ColumnCell, InputOperand, OutputItem, Placeholders, ResultStructure,
-        },
-        public_inputs::{tests::gen_values_in_range, PublicInputs as QueryPI, QueryPublicInputs}
+        }, universal_query_gadget::OutputValues}, utils::{QueryBoundSource, QueryBounds, QueryHashNonExistenceCircuits}
     },
     revelation::NUM_PREPROCESSING_IO,
 };
@@ -19,11 +18,13 @@ use mp2_common::{
     utils::{Fieldable, ToFields},
     F,
 };
+use mp2_test::utils::{gen_random_field_hash, gen_random_u256};
 use plonky2::{
-    field::types::PrimeField64,
+    field::types::{PrimeField64, Sample, Field},
     hash::hash_types::HashOut,
     plonk::config::GenericHashOut,
 };
+use plonky2_ecgfp5::curve::curve::Point;
 use rand::{prelude::SliceRandom, thread_rng, Rng};
 use std::array;
@@ -43,6 +44,27 @@ pub const ROW_TREE_MAX_DEPTH: usize = 10;
 pub const INDEX_TREE_MAX_DEPTH: usize = 15;
 pub const NUM_COLUMNS: usize = 4;
+
+/// Generate a set of values in a given range ensuring that the i+1-th generated value is
+/// no smaller than the i-th generated value
+pub(crate) fn gen_values_in_range<const N: usize, R: Rng>(
+    rng: &mut R,
+    lower: U256,
+    upper: U256,
+) -> [U256; N] {
+    assert!(upper >= lower, "{upper} is smaller than {lower}");
+    let mut prev_value = lower;
+    array::from_fn(|_| {
+        let range = (upper - prev_value).checked_add(U256::from(1));
+        let gen_value = match range {
+            Some(range) => prev_value + gen_random_u256(rng) % range,
+            None => gen_random_u256(rng),
+        };
+        prev_value = gen_value;
+        gen_value
+    })
+}
+
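As a quick illustration (a sketch, not part of the patch), `gen_values_in_range` can be called like this; the example assumes `alloy::primitives::U256` and `rand::thread_rng`, and relies on the non-decreasing guarantee documented above:

    // Draw three ordered values in [0, 100]: a <= b <= c always holds,
    // because each draw uses the previous value as its lower bound.
    let mut rng = thread_rng();
    let [a, b, c] = gen_values_in_range(&mut rng, U256::ZERO, U256::from(100));
    assert!(a <= b && b <= c && c <= U256::from(100));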
 /// Generate a random original tree proof for testing.
 pub fn random_original_tree_proof(tree_hash: HashOut<F>) -> Vec<F> {
     let mut rng = thread_rng();
@@ -74,6 +96,90 @@ pub fn random_aggregation_operations<const S: usize>() -> [F; S] {
     })
 }

+impl<const S: usize, const UNIVERSAL_CIRCUIT: bool> PublicInputsFactory<'_, F, S, UNIVERSAL_CIRCUIT> {
+    pub(crate) fn sample_from_ops<const NUM_INPUTS: usize>(ops: &[F; S]) -> [Vec<F>; NUM_INPUTS]
+    where
+        [(); S - 1]:,
+    {
+        let rng = &mut thread_rng();
+
+        let tree_hash = gen_random_field_hash();
+        let computational_hash = gen_random_field_hash();
+        let placeholder_hash = gen_random_field_hash();
+        let [min_primary, max_primary] = gen_values_in_range(rng, U256::ZERO, U256::MAX);
+        let [min_secondary, max_secondary] = gen_values_in_range(rng, U256::ZERO, U256::MAX);
+
+        let query_bounds = {
+            let placeholders = Placeholders::new_empty(min_primary, max_primary);
+            QueryBounds::new(
+                &placeholders,
+                Some(QueryBoundSource::Constant(min_secondary)),
+                Some(QueryBoundSource::Constant(max_secondary)),
+            )
+            .unwrap()
+        };
+
+        let is_first_op_id =
+            ops[0] == Identifiers::AggregationOperations(AggregationOperation::IdOp).to_field();
+
+        let mut previous_row: Option<BoundaryRowData> = None;
+        array::from_fn(|_| {
+            // generate output values
+            let output_values = if is_first_op_id {
+                // generate random curve point
+                OutputValues::<S>::new_outputs_no_aggregation(&Point::sample(rng))
+            } else {
+                let values = (0..S).map(|_| gen_random_u256(rng)).collect_vec();
+                OutputValues::<S>::new_aggregation_outputs(&values)
+            };
+            // generate random count and overflow flag
+            let count = F::from_canonical_u32(rng.gen());
+            let overflow = F::from_bool(rng.gen());
+            // generate boundary rows
+            let left_boundary_row = if let Some(row) = &previous_row {
+                row.sample_consecutive_row(rng, &query_bounds)
+            } else {
+                BoundaryRowData::sample(rng, &query_bounds)
+            };
+            let right_boundary_row = BoundaryRowData::sample(rng, &query_bounds);
+            assert!(
+                left_boundary_row.index_node_info.predecessor_info.value >= min_primary
+                    && left_boundary_row.index_node_info.predecessor_info.value <= max_primary
+            );
+            assert!(
+                left_boundary_row.index_node_info.successor_info.value >= min_primary
+                    && left_boundary_row.index_node_info.successor_info.value <= max_primary
+            );
+            assert!(
+                right_boundary_row.index_node_info.predecessor_info.value >= min_primary
+                    && right_boundary_row.index_node_info.predecessor_info.value <= max_primary
+            );
+            assert!(
+                right_boundary_row.index_node_info.successor_info.value >= min_primary
+                    && right_boundary_row.index_node_info.successor_info.value <= max_primary
+            );
+            previous_row = Some(right_boundary_row.clone());
+
+            PublicInputsFactory::<F, S, UNIVERSAL_CIRCUIT>::new(
+                &tree_hash.to_fields(),
+                &output_values.to_fields(),
+                &[count],
+                ops,
+                &left_boundary_row.to_fields(),
+                &right_boundary_row.to_fields(),
+                &min_primary.to_fields(),
+                &max_primary.to_fields(),
+                &min_secondary.to_fields(),
+                &max_secondary.to_fields(),
+                &[overflow],
+                &computational_hash.to_fields(),
+                &placeholder_hash.to_fields(),
+            )
+            .to_vec()
+        })
+    }
+}
+
 /// Revelation related data used for testing
 #[derive(Debug)]
 pub struct TestRevelationData {

From 349f9d21171dd332a3472e8fb36cdd322004a368 Mon Sep 17 00:00:00 2001
From: nicholas-mainardi
Date: Fri, 6 Dec 2024 12:39:19 +0100
Subject: [PATCH 03/12] clippy + fmt

---
 groth16-framework/tests/common/context.rs    | 12 +-
 groth16-framework/tests/common/query.rs      |  6 +-
 mp2-v1/Cargo.toml                            |  1 -
 mp2-v1/src/query/planner.rs                  |  3 +-
 .../common/cases/query/aggregated_queries.rs | 63 ++--
 mp2-v1/tests/common/cases/query/mod.rs       | 18 +-
 .../cases/query/simple_select_queries.rs     | 37 +--
 parsil/src/assembler.rs                      |  2 +-
parsil/src/queries.rs | 2 +- verifiable-db/Cargo.toml | 1 - verifiable-db/src/api.rs | 10 +- verifiable-db/src/query/api.rs | 161 +++++----- .../src/query/circuits/chunk_aggregation.rs | 6 +- verifiable-db/src/query/circuits/mod.rs | 20 +- .../src/query/circuits/non_existence.rs | 21 +- .../query/circuits/row_chunk_processing.rs | 36 ++- .../src/query/computational_hash_ids.rs | 2 +- verifiable-db/src/query/merkle_path.rs | 5 +- verifiable-db/src/query/mod.rs | 6 +- verifiable-db/src/query/output_computation.rs | 4 +- verifiable-db/src/query/public_inputs.rs | 58 ++-- .../row_chunk_gadgets/aggregate_chunks.rs | 13 +- .../row_chunk_gadgets/consecutive_rows.rs | 2 +- .../src/query/row_chunk_gadgets/mod.rs | 274 +++++++++--------- .../row_chunk_gadgets/row_process_gadget.rs | 2 +- .../universal_query_circuit.rs | 160 ++++++---- .../universal_query_gadget.rs | 25 +- verifiable-db/src/query/utils.rs | 5 +- .../results_tree/binding/binding_results.rs | 11 +- verifiable-db/src/results_tree/mod.rs | 12 +- .../src/results_tree/old_public_inputs.rs | 2 + verifiable-db/src/revelation/api.rs | 51 ++-- verifiable-db/src/revelation/mod.rs | 1 - .../src/revelation/placeholders_check.rs | 2 +- .../revelation/revelation_unproven_offset.rs | 147 +++++----- .../revelation_without_results_tree.rs | 65 +++-- verifiable-db/src/test_utils.rs | 160 +++++----- 37 files changed, 777 insertions(+), 629 deletions(-) diff --git a/groth16-framework/tests/common/context.rs b/groth16-framework/tests/common/context.rs index dc38470bf..ffb617c81 100644 --- a/groth16-framework/tests/common/context.rs +++ b/groth16-framework/tests/common/context.rs @@ -6,10 +6,13 @@ use mp2_common::{C, D, F}; use mp2_test::circuit::TestDummyCircuit; use recursion_framework::framework_testing::TestingRecursiveCircuits; use verifiable_db::{ - api::WrapCircuitParams, query::pi_len, revelation::api::Parameters as RevelationParameters, test_utils::{ + api::WrapCircuitParams, + query::pi_len, + revelation::api::Parameters as RevelationParameters, + test_utils::{ INDEX_TREE_MAX_DEPTH, MAX_NUM_COLUMNS, MAX_NUM_ITEMS_PER_OUTPUT, MAX_NUM_OUTPUTS, MAX_NUM_PLACEHOLDERS, MAX_NUM_PREDICATE_OPS, MAX_NUM_RESULT_OPS, ROW_TREE_MAX_DEPTH, - } + }, }; /// Test context @@ -39,8 +42,9 @@ impl TestContext { // Generate a fake query circuit set. let query_circuits = TestingRecursiveCircuits::::default(); - let dummy_universal_circuit = TestDummyCircuit::<{pi_len::()}>::build(); - + let dummy_universal_circuit = + TestDummyCircuit::<{ pi_len::() }>::build(); + // Create the revelation parameters. 
let revelation_params = RevelationParameters::< ROW_TREE_MAX_DEPTH, diff --git a/groth16-framework/tests/common/query.rs b/groth16-framework/tests/common/query.rs index 75a7db8bd..4fef29965 100644 --- a/groth16-framework/tests/common/query.rs +++ b/groth16-framework/tests/common/query.rs @@ -56,11 +56,7 @@ impl TestContext { .unwrap(); let revelation_proof = self .revelation_params - .generate_proof( - input, - self.query_circuits.get_recursive_circuit_set(), - None, - ) + .generate_proof(input, self.query_circuits.get_recursive_circuit_set(), None) .unwrap(); let revelation_proof = ProofWithVK::deserialize(&revelation_proof).unwrap(); let (revelation_proof_with_pi, _) = revelation_proof.clone().into(); diff --git a/mp2-v1/Cargo.toml b/mp2-v1/Cargo.toml index f15c4d2cf..d7b9b3856 100644 --- a/mp2-v1/Cargo.toml +++ b/mp2-v1/Cargo.toml @@ -59,4 +59,3 @@ parsil = { path = "../parsil" } [features] original_poseidon = ["mp2_common/original_poseidon"] -batching_circuits = ["verifiable-db/batching_circuits"] diff --git a/mp2-v1/src/query/planner.rs b/mp2-v1/src/query/planner.rs index 305d1a848..54734abd4 100644 --- a/mp2-v1/src/query/planner.rs +++ b/mp2-v1/src/query/planner.rs @@ -19,8 +19,8 @@ use ryhope::{ use std::{fmt::Debug, future::Future}; use tokio_postgres::{row::Row as PsqlRow, types::ToSql, NoTls}; use verifiable_db::query::{ - utils::{ChildPosition, NodeInfo, QueryBounds}, api::TreePathInputs, + utils::{ChildPosition, NodeInfo, QueryBounds}, }; use crate::indexing::{ @@ -375,7 +375,6 @@ impl< } } - /// Fetch a key `k` from a tree, assuming that the key is in the /// tree. Therefore, it handles differently the case when `k` is not found: /// - If `T::WIDE_LINEAGE` is true, then `k` might not be found because the diff --git a/mp2-v1/tests/common/cases/query/aggregated_queries.rs b/mp2-v1/tests/common/cases/query/aggregated_queries.rs index fe8daf45d..8aad454f1 100644 --- a/mp2-v1/tests/common/cases/query/aggregated_queries.rs +++ b/mp2-v1/tests/common/cases/query/aggregated_queries.rs @@ -36,7 +36,10 @@ use mp2_v1::{ cell::MerkleCell, row::{Row, RowPayload, RowTreeKey}, }, - query::{batching_planner::{generate_chunks_and_update_tree, UTForChunkProofs, UTKey}, planner::{execute_row_query, NonExistenceInput, TreeFetcher}}, + query::{ + batching_planner::{generate_chunks_and_update_tree, UTForChunkProofs, UTKey}, + planner::{execute_row_query, NonExistenceInput, TreeFetcher}, + }, }; use parsil::{ assembler::{DynamicCircuitPis, StaticCircuitPis}, @@ -56,15 +59,14 @@ use verifiable_db::{ ivc::PublicInputs as IndexingPIS, query::{ computational_hash_ids::{ColumnIDs, Identifiers}, - universal_circuit::universal_circuit_inputs::{ - ColumnCell, PlaceholderId, Placeholders, - }, + universal_circuit::universal_circuit_inputs::{ColumnCell, PlaceholderId, Placeholders}, }, revelation::PublicInputs, }; use super::{ - GlobalCircuitInput, QueryCircuitInput, QueryPlanner, RevelationCircuitInput, MAX_NUM_ITEMS_PER_OUTPUT, MAX_NUM_OUTPUTS, MAX_NUM_PLACEHOLDERS + GlobalCircuitInput, QueryCircuitInput, QueryPlanner, RevelationCircuitInput, + MAX_NUM_ITEMS_PER_OUTPUT, MAX_NUM_OUTPUTS, MAX_NUM_PLACEHOLDERS, }; pub type RevelationPublicInputs<'a> = @@ -77,12 +79,21 @@ pub(crate) async fn prove_query( metadata: MetadataHash, planner: &mut QueryPlanner<'_>, ) -> Result<()> { - let row_cache = planner.table + let row_cache = planner + .table .row .wide_lineage_between( planner.table.row.current_epoch(), - &core_keys_for_row_tree(&planner.query.query, planner.settings, &planner.pis.bounds, 
&planner.query.placeholders)?, - (planner.query.min_block as Epoch, planner.query.max_block as Epoch), + &core_keys_for_row_tree( + &planner.query.query, + planner.settings, + &planner.pis.bounds, + &planner.query.placeholders, + )?, + ( + planner.query.min_block as Epoch, + planner.query.max_block as Epoch, + ), ) .await?; // prove the index tree, on a single version. Both path can be taken depending if we do have @@ -113,14 +124,14 @@ pub(crate) async fn prove_query( planner.query.max_block ); } as BlockPrimaryIndex; - let index_path = planner.table + let index_path = planner + .table .index .compute_path(&to_be_proven_node, current_epoch as Epoch) .await - .expect( - format!("Compute path for index node with key {to_be_proven_node} failed") - .as_str(), - ); + .unwrap_or_else(|| { + panic!("Compute path for index node with key {to_be_proven_node} failed") + }); let input = QueryCircuitInput::new_non_existence_input( index_path, &column_ids, @@ -129,16 +140,18 @@ pub(crate) async fn prove_query( &planner.query.placeholders, &planner.pis.bounds, )?; - let query_proof = planner.ctx.run_query_proof( - "batching::non_existence", - GlobalCircuitInput::Query(input), - )?; + let query_proof = planner + .ctx + .run_query_proof("batching::non_existence", GlobalCircuitInput::Query(input))?; let proof_key = ProofKey::QueryAggregate(( planner.query.query.clone(), planner.query.placeholders.placeholder_values(), UTKey::default(), )); - planner.ctx.storage.store_proof(proof_key.clone(), query_proof)?; + planner + .ctx + .storage + .store_proof(proof_key.clone(), query_proof)?; proof_key } else { info!("Running INDEX tree proving from cache"); @@ -147,7 +160,8 @@ pub(crate) async fn prove_query( current_epoch as Epoch, (planner.query.min_block, planner.query.max_block), )?; - let big_index_cache = planner.table + let big_index_cache = planner + .table .index // The bounds here means between which versions of the tree should we look. For index tree, // we only look at _one_ version of the tree. 
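The `expect(format!(..).as_str())` to `unwrap_or_else(|| panic!(..))` rewrites in the hunk above follow clippy's `expect_fun_call` lint: the closure defers building the panic message to the failure path instead of formatting it on every call. A minimal sketch of the pattern (illustrative only, with a hypothetical `lookup` map and key `k`):

    // Before (lint fires): the message is formatted even on the happy path.
    // let v = lookup.get(&k).expect(format!("missing key {k}").as_str());

    // After: format! only runs if the value is absent.
    let v = lookup.get(&k).unwrap_or_else(|| panic!("missing key {k}"));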
@@ -185,9 +199,9 @@ pub(crate) async fn prove_query( // this is a row chunk to be proven let to_be_proven_chunk = proven_chunks .get(k) - .expect(format!("chunk for key {:?} not found", k).as_str()); + .unwrap_or_else(|| panic!("chunk for key {:?} not found", k)); let input = QueryCircuitInput::new_row_chunks_input( - &to_be_proven_chunk, + to_be_proven_chunk, &planner.pis.predication_operations, &planner.query.placeholders, &planner.pis.bounds, @@ -199,7 +213,7 @@ pub(crate) async fn prove_query( GlobalCircuitInput::Query(input), ) } else { - let children_keys = workplan.t.get_children_keys(&k); + let children_keys = workplan.t.get_children_keys(k); info!("children keys: {:?}", children_keys); // fetch the proof for each child from the storage let child_proofs = children_keys @@ -213,8 +227,7 @@ pub(crate) async fn prove_query( planner.ctx.storage.get_proof_exact(&proof_key) }) .collect::>>()?; - let input = - QueryCircuitInput::new_chunk_aggregation_input(&child_proofs)?; + let input = QueryCircuitInput::new_chunk_aggregation_input(&child_proofs)?; info!("Aggregating chunk {:?}", k); planner.ctx.run_query_proof( "batching::chunk_aggregation", @@ -238,7 +251,7 @@ pub(crate) async fn prove_query( let proof = prove_revelation( planner.ctx, &planner.query, - &planner.pis, + planner.pis, planner.table.index.current_epoch(), &query_proof_id, ) diff --git a/mp2-v1/tests/common/cases/query/mod.rs b/mp2-v1/tests/common/cases/query/mod.rs index d5d0aad4e..7887e635d 100644 --- a/mp2-v1/tests/common/cases/query/mod.rs +++ b/mp2-v1/tests/common/cases/query/mod.rs @@ -11,7 +11,10 @@ use log::info; use mp2_v1::{ api::MetadataHash, indexing::block::BlockPrimaryIndex, query::planner::execute_row_query, }; -use parsil::{assembler::DynamicCircuitPis, parse_and_validate, utils::ParsilSettingsBuilder, ParsilSettings, PlaceholderSettings}; +use parsil::{ + assembler::DynamicCircuitPis, parse_and_validate, utils::ParsilSettingsBuilder, ParsilSettings, + PlaceholderSettings, +}; use simple_select_queries::{ cook_query_no_matching_rows, cook_query_too_big_offset, cook_query_with_distinct, cook_query_with_matching_rows, cook_query_with_max_num_matching_rows, @@ -23,7 +26,10 @@ use verifiable_db::query::{ computational_hash_ids::Output, universal_circuit::universal_circuit_inputs::Placeholders, }; -use crate::common::{table::{Table, TableColumns}, TableInfo, TestContext}; +use crate::common::{ + table::{Table, TableColumns}, + TableInfo, TestContext, +}; use super::table_source::TableSource; @@ -217,13 +223,7 @@ async fn test_query_mapping( match pis.result.query_variant() { Output::Aggregation => { - prove_aggregation_query( - parsed, - res, - *table_hash, - &mut planner, - ) - .await + prove_aggregation_query(parsed, res, *table_hash, &mut planner).await } Output::NoAggregation => { prove_no_aggregation_query(parsed, table_hash, &mut planner, res).await diff --git a/mp2-v1/tests/common/cases/query/simple_select_queries.rs b/mp2-v1/tests/common/cases/query/simple_select_queries.rs index 370839426..df01b78f7 100644 --- a/mp2-v1/tests/common/cases/query/simple_select_queries.rs +++ b/mp2-v1/tests/common/cases/query/simple_select_queries.rs @@ -5,11 +5,17 @@ use log::info; use mp2_common::types::HashOutput; use mp2_v1::{ api::MetadataHash, - indexing::{block::BlockPrimaryIndex, row::{RowPayload, RowTreeKey}, LagrangeNode}, - query::planner::{execute_row_query, get_node_info, TreeFetcher}, values_extraction::identifier_block_column, + indexing::{ + block::BlockPrimaryIndex, + row::{RowPayload, RowTreeKey}, + 
LagrangeNode, + }, + query::planner::{execute_row_query, get_node_info, TreeFetcher}, + values_extraction::identifier_block_column, }; use parsil::{ - assembler::DynamicCircuitPis, executor::generate_query_execution_with_keys, DEFAULT_MAX_BLOCK_PLACEHOLDER, DEFAULT_MIN_BLOCK_PLACEHOLDER + assembler::DynamicCircuitPis, executor::generate_query_execution_with_keys, + DEFAULT_MAX_BLOCK_PLACEHOLDER, DEFAULT_MIN_BLOCK_PLACEHOLDER, }; use ryhope::{ storage::{pgsql::ToFromBytea, RoEpochKvStorage}, @@ -20,7 +26,11 @@ use std::{fmt::Debug, hash::Hash}; use tokio_postgres::Row as PgSqlRow; use verifiable_db::{ query::{ - computational_hash_ids::ColumnIDs, universal_circuit::universal_circuit_inputs::{ColumnCell, PlaceholderId, Placeholders, RowCells}, utils::{ChildPosition, NodeInfo} + computational_hash_ids::ColumnIDs, + universal_circuit::universal_circuit_inputs::{ + ColumnCell, PlaceholderId, Placeholders, RowCells, + }, + utils::{ChildPosition, NodeInfo}, }, revelation::{api::MatchingRow, RowPath}, test_utils::MAX_NUM_OUTPUTS, @@ -30,10 +40,8 @@ use crate::common::{ cases::{ indexing::BLOCK_COLUMN_NAME, query::{ - aggregated_queries::{ - check_final_outputs, find_longest_lived_key, - }, - GlobalCircuitInput, RevelationCircuitInput, SqlReturn, SqlType, QueryPlanner, + aggregated_queries::{check_final_outputs, find_longest_lived_key}, + GlobalCircuitInput, QueryPlanner, RevelationCircuitInput, SqlReturn, SqlType, }, }, proof_storage::{ProofKey, ProofStorage}, @@ -90,20 +98,13 @@ pub(crate) async fn prove_query( ) .await?; let (row_node_info, _, _) = get_node_info(&planner.table.row, &key, epoch).await; - let (row_tree_path, row_tree_siblings) = get_path_info( - &key, - &planner.table.row, - epoch) - .await?; + let (row_tree_path, row_tree_siblings) = + get_path_info(&key, &planner.table.row, epoch).await?; let index_node_key = epoch as BlockPrimaryIndex; let (index_node_info, _, _) = get_node_info(&planner.table.index, &index_node_key, current_epoch).await; let (index_tree_path, index_tree_siblings) = - get_path_info( - &index_node_key, - &planner.table.index, - current_epoch - ).await?; + get_path_info(&index_node_key, &planner.table.index, current_epoch).await?; let path = RowPath::new( row_node_info, row_tree_path, diff --git a/parsil/src/assembler.rs b/parsil/src/assembler.rs index 128385e9a..e847f9c0a 100644 --- a/parsil/src/assembler.rs +++ b/parsil/src/assembler.rs @@ -15,11 +15,11 @@ use sqlparser::ast::{ SelectItem, SetExpr, TableAlias, TableFactor, UnaryOperator, Value, }; use verifiable_db::query::{ - utils::{QueryBoundSource, QueryBounds}, computational_hash_ids::{AggregationOperation, Operation, PlaceholderIdentifier}, universal_circuit::universal_circuit_inputs::{ BasicOperation, InputOperand, OutputItem, Placeholders, ResultStructure, }, + utils::{QueryBoundSource, QueryBounds}, }; use crate::{ diff --git a/parsil/src/queries.rs b/parsil/src/queries.rs index 92b6d7b29..506fdb731 100644 --- a/parsil/src/queries.rs +++ b/parsil/src/queries.rs @@ -5,7 +5,7 @@ use crate::{keys_in_index_boundaries, symbols::ContextProvider, ParsilSettings}; use anyhow::*; use ryhope::{tree::sbbst::NodeIdx, Epoch, EPOCH, KEY, VALID_FROM, VALID_UNTIL}; use verifiable_db::query::{ - utils::QueryBounds, universal_circuit::universal_circuit_inputs::Placeholders, + universal_circuit::universal_circuit_inputs::Placeholders, utils::QueryBounds, }; /// Return a query read to be injected in the wide lineage computation for the diff --git a/verifiable-db/Cargo.toml b/verifiable-db/Cargo.toml index 
3339e8d36..3ab92c430 100644 --- a/verifiable-db/Cargo.toml +++ b/verifiable-db/Cargo.toml @@ -30,4 +30,3 @@ tokio.workspace = true [features] original_poseidon = ["mp2_common/original_poseidon"] -batching_circuits = [] \ No newline at end of file diff --git a/verifiable-db/src/api.rs b/verifiable-db/src/api.rs index bd7556d53..9c1ca5324 100644 --- a/verifiable-db/src/api.rs +++ b/verifiable-db/src/api.rs @@ -4,14 +4,8 @@ use crate::{ block_tree, cells_tree, extraction::{ExtractionPI, ExtractionPIWrap}, ivc, - query::{ - self, api::Parameters as QueryParams, - pi_len as query_pi_len, - }, - revelation::{ - self, api::Parameters as RevelationParams, - pi_len as revelation_pi_len, - }, + query::{self, api::Parameters as QueryParams, pi_len as query_pi_len}, + revelation::{self, api::Parameters as RevelationParams, pi_len as revelation_pi_len}, row_tree::{self}, }; use anyhow::Result; diff --git a/verifiable-db/src/query/api.rs b/verifiable-db/src/query/api.rs index 6eed9daf3..18d5d7e0e 100644 --- a/verifiable-db/src/query/api.rs +++ b/verifiable-db/src/query/api.rs @@ -3,30 +3,53 @@ use std::iter::{repeat, repeat_with}; use anyhow::{bail, ensure, Result}; use itertools::Itertools; -use mp2_common::{array::ToField, default_config, poseidon::{HashPermutation, H}, proof::{serialize_proof, ProofWithVK}, types::HashOutput, utils::ToFields, C, D, F}; -use plonky2::{hash::hashing::hash_n_to_hash_no_pad, plonk::config::{GenericHashOut, Hasher}}; +use mp2_common::{ + array::ToField, + default_config, + poseidon::{HashPermutation, H}, + proof::{serialize_proof, ProofWithVK}, + types::HashOutput, + utils::ToFields, + C, D, F, +}; +use plonky2::{ + hash::hashing::hash_n_to_hash_no_pad, + plonk::config::{GenericHashOut, Hasher}, +}; use recursion_framework::{ - circuit_builder::{CircuitWithUniversalVerifier, CircuitWithUniversalVerifierBuilder}, framework::{prepare_recursive_circuit_for_circuit_set, RecursiveCircuits}, + circuit_builder::{CircuitWithUniversalVerifier, CircuitWithUniversalVerifierBuilder}, + framework::{prepare_recursive_circuit_for_circuit_set, RecursiveCircuits}, }; use serde::{Deserialize, Serialize}; use crate::query::{ - utils::{ChildPosition, NodeInfo, QueryBounds, QueryHashNonExistenceCircuits}, circuits::{ - chunk_aggregation::{ChunkAggregationCircuit, ChunkAggregationInputs, ChunkAggregationWires}, - non_existence::{NonExistenceCircuit, NonExistenceWires}, - row_chunk_processing::{RowChunkProcessingCircuit, RowChunkProcessingWires}, + chunk_aggregation::{ + ChunkAggregationCircuit, ChunkAggregationInputs, ChunkAggregationWires, + }, + non_existence::{NonExistenceCircuit, NonExistenceWires}, + row_chunk_processing::{RowChunkProcessingCircuit, RowChunkProcessingWires}, }, - row_chunk_gadgets::row_process_gadget::RowProcessingGadgetInputs, computational_hash_ids::{AggregationOperation, ColumnIDs, Identifiers}, + row_chunk_gadgets::row_process_gadget::RowProcessingGadgetInputs, universal_circuit::{ - output_with_aggregation::Circuit as OutputAggCircuit, output_no_aggregation::Circuit as OutputNoAggCircuit, + output_with_aggregation::Circuit as OutputAggCircuit, universal_circuit_inputs::{BasicOperation, Placeholders, ResultStructure, RowCells}, }, + utils::{ChildPosition, NodeInfo, QueryBounds, QueryHashNonExistenceCircuits}, }; -use super::{computational_hash_ids::Output, pi_len, universal_circuit::{universal_circuit_inputs::PlaceholderId, universal_query_circuit::{placeholder_hash, UniversalCircuitInput, UniversalQueryCircuitParams}}}; +use super::{ + computational_hash_ids::Output, 
+ pi_len, + universal_circuit::{ + universal_circuit_inputs::PlaceholderId, + universal_query_circuit::{ + placeholder_hash, UniversalCircuitInput, UniversalQueryCircuitParams, + }, + }, +}; /// Data structure containing all the information needed to verify the membership of /// a node in a tree and to compute info about its predecessor/successor @@ -362,7 +385,9 @@ where MAX_NUM_PREDICATE_OPS, MAX_NUM_RESULT_OPS, MAX_NUM_RESULTS, - >::ids_for_placeholder_hash(predicate_operations, results, placeholders, query_bounds) + >::ids_for_placeholder_hash( + predicate_operations, results, placeholders, query_bounds + ) } /// Compute the `placeholder_hash` associated to a query @@ -442,9 +467,9 @@ pub(crate) struct Parameters< NonExistenceWires, >, universal_circuit: UniversalQueryCircuitParams< - MAX_NUM_COLUMNS, - MAX_NUM_PREDICATE_OPS, - MAX_NUM_RESULT_OPS, + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, MAX_NUM_RESULTS, OutputNoAggCircuit, >, @@ -452,15 +477,15 @@ pub(crate) struct Parameters< } impl< - const NUM_CHUNKS: usize, - const NUM_ROWS: usize, - const ROW_TREE_MAX_DEPTH: usize, - const INDEX_TREE_MAX_DEPTH: usize, - const MAX_NUM_COLUMNS: usize, - const MAX_NUM_PREDICATE_OPS: usize, - const MAX_NUM_RESULT_OPS: usize, - const MAX_NUM_RESULTS: usize, -> + const NUM_CHUNKS: usize, + const NUM_ROWS: usize, + const ROW_TREE_MAX_DEPTH: usize, + const INDEX_TREE_MAX_DEPTH: usize, + const MAX_NUM_COLUMNS: usize, + const MAX_NUM_PREDICATE_OPS: usize, + const MAX_NUM_RESULT_OPS: usize, + const MAX_NUM_RESULTS: usize, + > Parameters< NUM_CHUNKS, NUM_ROWS, @@ -523,20 +548,23 @@ where >, ) -> Result> { match input { - CircuitInput::RowChunkWithAggregation(row_chunk_processing_circuit) => - ProofWithVK::serialize( - &( - self.circuit_set.generate_proof( - &self.row_chunk_agg_circuit, - [], - [], - row_chunk_processing_circuit, - )?, - self.row_chunk_agg_circuit - .circuit_data() - .verifier_only - .clone(), - ).into()), + CircuitInput::RowChunkWithAggregation(row_chunk_processing_circuit) => { + ProofWithVK::serialize( + &( + self.circuit_set.generate_proof( + &self.row_chunk_agg_circuit, + [], + [], + row_chunk_processing_circuit, + )?, + self.row_chunk_agg_circuit + .circuit_data() + .verifier_only + .clone(), + ) + .into(), + ) + } CircuitInput::ChunkAggregation(chunk_aggregation_inputs) => { let ChunkAggregationInputs { chunk_proofs, @@ -550,41 +578,42 @@ where let input_proofs = chunk_proofs.map(|p| p.proof); ProofWithVK::serialize( &( + self.circuit_set.generate_proof( + &self.aggregation_circuit, + input_proofs, + input_vd.iter().collect_vec().try_into().unwrap(), + circuit, + )?, + self.aggregation_circuit + .circuit_data() + .verifier_only + .clone(), + ) + .into(), + ) + } + CircuitInput::NonExistence(non_existence_circuit) => ProofWithVK::serialize( + &( self.circuit_set.generate_proof( - &self.aggregation_circuit, - input_proofs, - input_vd.iter().collect_vec().try_into().unwrap(), - circuit, + &self.non_existence_circuit, + [], + [], + non_existence_circuit, )?, - self.aggregation_circuit + self.non_existence_circuit .circuit_data() .verifier_only .clone(), ) - .into()) - } - CircuitInput::NonExistence(non_existence_circuit) => - ProofWithVK::serialize( - &( - self.circuit_set.generate_proof( - &self.non_existence_circuit, - [], - [], - non_existence_circuit, - )?, - self.non_existence_circuit - .circuit_data() - .verifier_only - .clone(), - ) - .into()), - CircuitInput::UniversalCircuit(universal_circuit_input) => + .into(), + ), + 
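            // Each arm of this match follows the same recipe, shown here as a
            // sketch (signatures assumed from the surrounding calls): the proof
            // produced by the recursion framework is paired with the verifier-only
            // data of the circuit that generated it, so the serialized
            // `ProofWithVK` carries everything needed to verify it against the
            // circuit set:
            //
            //     let proof = self.circuit_set.generate_proof(&circuit, [], [], inputs)?;
            //     let vk = circuit.circuit_data().verifier_only.clone();
            //     ProofWithVK::serialize(&(proof, vk).into())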
CircuitInput::UniversalCircuit(universal_circuit_input) => { if let UniversalCircuitInput::QueryNoAgg(input) = universal_circuit_input { serialize_proof(&self.universal_circuit.generate_proof(&input)?) } else { unreachable!("Universal circuit should only be used for queries with no aggregation operations") } - , + } } } @@ -592,10 +621,12 @@ where &self.circuit_set } - pub(crate) fn get_universal_circuit(&self) -> &UniversalQueryCircuitParams< - MAX_NUM_COLUMNS, - MAX_NUM_PREDICATE_OPS, - MAX_NUM_RESULT_OPS, + pub(crate) fn get_universal_circuit( + &self, + ) -> &UniversalQueryCircuitParams< + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, MAX_NUM_RESULTS, OutputNoAggCircuit, > { diff --git a/verifiable-db/src/query/circuits/chunk_aggregation.rs b/verifiable-db/src/query/circuits/chunk_aggregation.rs index 98dd02a86..93da64d6d 100644 --- a/verifiable-db/src/query/circuits/chunk_aggregation.rs +++ b/verifiable-db/src/query/circuits/chunk_aggregation.rs @@ -23,7 +23,7 @@ use recursion_framework::circuit_builder::CircuitLogicWires; use serde::{Deserialize, Serialize}; use crate::query::{ - row_chunk_gadgets::aggregate_chunks::aggregate_chunks, pi_len, public_inputs::PublicInputs + pi_len, public_inputs::PublicInputs, row_chunk_gadgets::aggregate_chunks::aggregate_chunks, }; #[derive(Clone, Debug, Serialize, Deserialize)] @@ -198,10 +198,10 @@ mod tests { use crate::{ query::{ - utils::tests::aggregate_output_values, - public_inputs::PublicInputs, computational_hash_ids::{AggregationOperation, Identifiers}, + public_inputs::PublicInputs, universal_circuit::universal_query_gadget::OutputValues, + utils::tests::aggregate_output_values, }, test_utils::random_aggregation_operations, }; diff --git a/verifiable-db/src/query/circuits/mod.rs b/verifiable-db/src/query/circuits/mod.rs index aa81d4d5c..5dddff4df 100644 --- a/verifiable-db/src/query/circuits/mod.rs +++ b/verifiable-db/src/query/circuits/mod.rs @@ -17,12 +17,20 @@ mod tests { }; use rand::thread_rng; - use crate::{query::{ - computational_hash_ids::AggregationOperation, merkle_path::tests::build_node, universal_circuit::{ - universal_circuit_inputs::{BasicOperation, Placeholders, ResultStructure, RowCells}, - universal_query_gadget::OutputValues, - }, utils::{NodeInfo, QueryBounds} - }, test_utils::gen_values_in_range}; + use crate::{ + query::{ + computational_hash_ids::AggregationOperation, + merkle_path::tests::build_node, + universal_circuit::{ + universal_circuit_inputs::{ + BasicOperation, Placeholders, ResultStructure, RowCells, + }, + universal_query_gadget::OutputValues, + }, + utils::{NodeInfo, QueryBounds}, + }, + test_utils::gen_values_in_range, + }; /// Data structure employed to represent a node of a rows tree in the tests #[derive(Clone, Debug)] diff --git a/verifiable-db/src/query/circuits/non_existence.rs b/verifiable-db/src/query/circuits/non_existence.rs index b35a5a1a9..68f51b3e1 100644 --- a/verifiable-db/src/query/circuits/non_existence.rs +++ b/verifiable-db/src/query/circuits/non_existence.rs @@ -22,11 +22,18 @@ use recursion_framework::circuit_builder::CircuitLogicWires; use serde::{Deserialize, Serialize}; use crate::query::{ - utils::QueryBounds, output_computation::compute_dummy_output_targets, api::TreePathInputs, row_chunk_gadgets::{BoundaryRowDataTarget, BoundaryRowNodeInfoTarget}, merkle_path::{ + api::TreePathInputs, + merkle_path::{ MerklePathWithNeighborsGadget, MerklePathWithNeighborsTargetInputs, NeighborInfoTarget, - }, pi_len, public_inputs::PublicInputs, universal_circuit::{ + }, + 
output_computation::compute_dummy_output_targets, + pi_len, + public_inputs::PublicInputs, + row_chunk_gadgets::{BoundaryRowDataTarget, BoundaryRowNodeInfoTarget}, + universal_circuit::{ ComputationalHash, ComputationalHashTarget, PlaceholderHash, PlaceholderHashTarget, - } + }, + utils::QueryBounds, }; #[derive(Clone, Debug, Serialize, Deserialize)] @@ -281,7 +288,13 @@ mod tests { use crate::{ query::{ - api::TreePathInputs, merkle_path::{tests::generate_test_tree, NeighborInfo}, output_computation::tests::compute_dummy_output_values, public_inputs::PublicInputs, row_chunk_gadgets::{BoundaryRowData, BoundaryRowNodeInfo}, universal_circuit::universal_circuit_inputs::Placeholders, utils::{ChildPosition, QueryBounds} + api::TreePathInputs, + merkle_path::{tests::generate_test_tree, NeighborInfo}, + output_computation::tests::compute_dummy_output_values, + public_inputs::PublicInputs, + row_chunk_gadgets::{BoundaryRowData, BoundaryRowNodeInfo}, + universal_circuit::universal_circuit_inputs::Placeholders, + utils::{ChildPosition, QueryBounds}, }, test_utils::{gen_values_in_range, random_aggregation_operations}, }; diff --git a/verifiable-db/src/query/circuits/row_chunk_processing.rs b/verifiable-db/src/query/circuits/row_chunk_processing.rs index 22b9a9162..56a9a75ec 100644 --- a/verifiable-db/src/query/circuits/row_chunk_processing.rs +++ b/verifiable-db/src/query/circuits/row_chunk_processing.rs @@ -10,17 +10,21 @@ use recursion_framework::circuit_builder::CircuitLogicWires; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use crate::query::{ - utils::QueryBounds, row_chunk_gadgets:: - { + computational_hash_ids::ColumnIDs, + pi_len, + public_inputs::PublicInputs, + row_chunk_gadgets::{ + aggregate_chunks::aggregate_chunks, row_process_gadget::{RowProcessingGadgetInputWires, RowProcessingGadgetInputs}, - aggregate_chunks::aggregate_chunks, RowChunkDataTarget, - }, - computational_hash_ids::ColumnIDs, pi_len, public_inputs::PublicInputs, universal_circuit::{ + RowChunkDataTarget, + }, + universal_circuit::{ universal_circuit_inputs::{BasicOperation, Placeholders, ResultStructure}, universal_query_gadget::{ OutputComponent, UniversalQueryHashInputWires, UniversalQueryHashInputs, }, - } + }, + utils::QueryBounds, }; use mp2_common::{ @@ -217,7 +221,10 @@ where b, &chunk, ¤t_chunk, - (&query_input_wires.input_wires.min_query_primary, &query_input_wires.input_wires.max_query_primary), + ( + &query_input_wires.input_wires.min_query_primary, + &query_input_wires.input_wires.max_query_primary, + ), ( &query_input_wires.min_secondary, &query_input_wires.max_secondary, @@ -374,22 +381,18 @@ mod tests { use rand::thread_rng; use crate::query::{ - utils::{ - tests::aggregate_output_values, ChildPosition, QueryBoundSource, QueryBounds, - }, circuits::{ row_chunk_processing::RowChunkProcessingCircuit, tests::{build_test_tree, compute_output_values_for_row}, }, - row_chunk_gadgets::{ - BoundaryRowData, BoundaryRowNodeInfo, - row_process_gadget::RowProcessingGadgetInputs - }, - public_inputs::PublicInputs, computational_hash_ids::{ AggregationOperation, ColumnIDs, Identifiers, Operation, PlaceholderIdentifier, }, - merkle_path::{NeighborInfo, MerklePathWithNeighborsGadget}, + merkle_path::{MerklePathWithNeighborsGadget, NeighborInfo}, + public_inputs::PublicInputs, + row_chunk_gadgets::{ + row_process_gadget::RowProcessingGadgetInputs, BoundaryRowData, BoundaryRowNodeInfo, + }, universal_circuit::{ output_no_aggregation::Circuit as NoAggOutputCircuit, output_with_aggregation::Circuit as 
AggOutputCircuit, @@ -401,6 +404,7 @@ mod tests { universal_query_gadget::CurveOrU256, ComputationalHash, }, + utils::{tests::aggregate_output_values, ChildPosition, QueryBoundSource, QueryBounds}, }; use super::{OutputComponent, RowChunkProcessingWires}; diff --git a/verifiable-db/src/query/computational_hash_ids.rs b/verifiable-db/src/query/computational_hash_ids.rs index ef55f2870..a672f8da7 100644 --- a/verifiable-db/src/query/computational_hash_ids.rs +++ b/verifiable-db/src/query/computational_hash_ids.rs @@ -31,7 +31,6 @@ use serde::{Deserialize, Serialize}; use crate::revelation::placeholders_check::placeholder_ids_hash; use super::{ - utils::QueryBoundSource, universal_circuit::{ universal_circuit_inputs::{ BasicOperation, InputOperand, OutputItem, PlaceholderIdsSet, ResultStructure, @@ -39,6 +38,7 @@ use super::{ universal_query_gadget::QueryBound, ComputationalHash, ComputationalHashTarget, }, + utils::QueryBoundSource, }; pub enum Identifiers { diff --git a/verifiable-db/src/query/merkle_path.rs b/verifiable-db/src/query/merkle_path.rs index a1c0a3545..571cc8bb2 100644 --- a/verifiable-db/src/query/merkle_path.rs +++ b/verifiable-db/src/query/merkle_path.rs @@ -19,12 +19,12 @@ use mp2_common::{ }; use mp2_test::utils::gen_random_field_hash; use plonky2::{ + field::types::Field, hash::hash_types::{HashOut, HashOutTarget, NUM_HASH_OUT_ELTS}, iop::{ target::{BoolTarget, Target}, witness::{PartialWitness, WitnessWrite}, }, - field::types::Field, plonk::{circuit_builder::CircuitBuilder, config::GenericHashOut}, }; use rand::Rng; @@ -783,7 +783,8 @@ pub(crate) mod tests { use crate::query::utils::{ChildPosition, NodeInfo}; use super::{ - MerklePathGadget, MerklePathTargetInputs, MerklePathWithNeighborsGadget, MerklePathWithNeighborsTargetInputs, NeighborInfo, NeighborInfoTarget + MerklePathGadget, MerklePathTargetInputs, MerklePathWithNeighborsGadget, + MerklePathWithNeighborsTargetInputs, NeighborInfo, NeighborInfoTarget, }; #[derive(Clone, Debug)] diff --git a/verifiable-db/src/query/mod.rs b/verifiable-db/src/query/mod.rs index 4b69d5497..a99fbfba8 100644 --- a/verifiable-db/src/query/mod.rs +++ b/verifiable-db/src/query/mod.rs @@ -2,13 +2,13 @@ use plonky2::iop::target::Target; use public_inputs::PublicInputs; pub mod api; +pub(crate) mod circuits; pub mod computational_hash_ids; pub mod merkle_path; +pub(crate) mod output_computation; pub mod public_inputs; -pub mod universal_circuit; -pub(crate) mod circuits; pub(crate) mod row_chunk_gadgets; -pub(crate) mod output_computation; +pub mod universal_circuit; pub mod utils; pub const fn pi_len() -> usize { diff --git a/verifiable-db/src/query/output_computation.rs b/verifiable-db/src/query/output_computation.rs index 76e126451..5c7b4d91e 100644 --- a/verifiable-db/src/query/output_computation.rs +++ b/verifiable-db/src/query/output_computation.rs @@ -156,9 +156,9 @@ pub(crate) mod tests { use super::*; use crate::{ query::{ - utils::tests::compute_output_item_value, pi_len, - public_inputs::PublicInputs, + pi_len, public_inputs::PublicInputs, universal_circuit::universal_query_gadget::CurveOrU256, + utils::tests::compute_output_item_value, }, test_utils::random_aggregation_operations, }; diff --git a/verifiable-db/src/query/public_inputs.rs b/verifiable-db/src/query/public_inputs.rs index d3d1736be..a3b22f50e 100644 --- a/verifiable-db/src/query/public_inputs.rs +++ b/verifiable-db/src/query/public_inputs.rs @@ -74,10 +74,10 @@ pub enum QueryPublicInputsUniversalCircuit { /// `ops` : `[F; S]` Set of identifiers of the aggregation 
operations for each of the `S` items found in `V`
    /// (like "SUM", "MIN", "MAX", "COUNT" operations)
    OpIds,
-    /// Data associated to the left boundary row of the row chunk being proven; it is dummy in case of universal query 
+    /// Data associated to the left boundary row of the row chunk being proven; it is dummy in case of universal query
    /// circuit, it is just employed to re-use the same public inputs
    LeftBoundaryRow,
-    /// Data associated to the right boundary row of the row chunk being proven; it is dummy in case of universal query 
+    /// Data associated to the right boundary row of the row chunk being proven; it is dummy in case of universal query
    /// circuit, it is just employed to re-use the same public inputs
    RightBoundaryRow,
    /// `MIN_primary`: `u256` Lower bound of the range of primary indexed column values specified in the query
@@ -103,18 +103,28 @@ impl From<QueryPublicInputsUniversalCircuit> for QueryPublicInputs {
            QueryPublicInputsUniversalCircuit::OutputValues => QueryPublicInputs::OutputValues,
            QueryPublicInputsUniversalCircuit::NumMatching => QueryPublicInputs::NumMatching,
            QueryPublicInputsUniversalCircuit::OpIds => QueryPublicInputs::NumMatching,
-            QueryPublicInputsUniversalCircuit::LeftBoundaryRow => QueryPublicInputs::LeftBoundaryRow,
-            QueryPublicInputsUniversalCircuit::RightBoundaryRow => QueryPublicInputs::RightBoundaryRow,
+            QueryPublicInputsUniversalCircuit::LeftBoundaryRow => {
+                QueryPublicInputs::LeftBoundaryRow
+            }
+            QueryPublicInputsUniversalCircuit::RightBoundaryRow => {
+                QueryPublicInputs::RightBoundaryRow
+            }
            QueryPublicInputsUniversalCircuit::MinPrimary => QueryPublicInputs::MinPrimary,
            QueryPublicInputsUniversalCircuit::MaxPrimary => QueryPublicInputs::MaxPrimary,
-            QueryPublicInputsUniversalCircuit::SecondaryIndexValue => QueryPublicInputs::MinSecondary,
+            QueryPublicInputsUniversalCircuit::SecondaryIndexValue => {
+                QueryPublicInputs::MinSecondary
+            }
            QueryPublicInputsUniversalCircuit::PrimaryIndexValue => QueryPublicInputs::MaxSecondary,
            QueryPublicInputsUniversalCircuit::Overflow => QueryPublicInputs::Overflow,
-            QueryPublicInputsUniversalCircuit::ComputationalHash => QueryPublicInputs::ComputationalHash,
-            QueryPublicInputsUniversalCircuit::PlaceholderHash => QueryPublicInputs::PlaceholderHash,
+            QueryPublicInputsUniversalCircuit::ComputationalHash => {
+                QueryPublicInputs::ComputationalHash
+            }
+            QueryPublicInputsUniversalCircuit::PlaceholderHash => {
+                QueryPublicInputs::PlaceholderHash
+            }
        }
    }
-} 
+}
 /// Public inputs for generic query circuits
 pub type PublicInputs<'a, T, const S: usize> = PublicInputsFactory<'a, T, S, false>;
 /// Public inputs for universal query circuit
 pub type PublicInputsUniversalCircuit<'a, T, const S: usize> = PublicInputsFactory<'a, T, S, true>;
 /// public inputs of universal query circuit. The methods being common between the
 /// 2 public inputs are implemented for this data structure, while the methods that
 /// are specific to each public input type are implemented for the corresponding alias.
-/// In this way, the methods implemented for the type alias define the correct semantics 
-/// of each of the items in both types of public inputs. 
+/// In this way, the methods implemented for the type alias define the correct semantics
+/// of each of the items in both types of public inputs.
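// A minimal sketch of the aliasing pattern described above, with simplified,
// hypothetical names (the real struct has more slots and a const size `S`):
//
//     struct PiFactory<'a, T, const UNIVERSAL_CIRCUIT: bool> { min_s: &'a [T] }
//     type Pi<'a, T> = PiFactory<'a, T, false>;
//     type PiUniversal<'a, T> = PiFactory<'a, T, true>;
//
//     impl<T> Pi<'_, T> {
//         // for the row-chunk circuits this slot is the secondary range lower bound
//         fn min_secondary(&self) -> &[T] { self.min_s }
//     }
//     impl<T> PiUniversal<'_, T> {
//         // the universal circuit re-uses the same slot for the row's index value
//         fn secondary_index_value(&self) -> &[T] { self.min_s }
//     }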
#[derive(Clone, Debug)] pub struct PublicInputsFactory<'a, T, const S: usize, const UNIVERSAL_CIRCUIT: bool> { h: &'a [T], @@ -149,12 +159,9 @@ pub struct PublicInputsFactory<'a, T, const S: usize, const UNIVERSAL_CIRCUIT: b const NUM_PUBLIC_INPUTS: usize = QueryPublicInputs::PlaceholderHash as usize + 1; -impl< - 'a, - T: Clone, - const S: usize, - const UNIVERSAL_CIRCUIT: bool, -> PublicInputsFactory<'a, T, S, UNIVERSAL_CIRCUIT> { +impl<'a, T: Clone, const S: usize, const UNIVERSAL_CIRCUIT: bool> + PublicInputsFactory<'a, T, S, UNIVERSAL_CIRCUIT> +{ const PI_RANGES: [PublicInputRange; NUM_PUBLIC_INPUTS] = [ Self::to_range_internal(QueryPublicInputs::TreeHash), Self::to_range_internal(QueryPublicInputs::OutputValues), @@ -211,8 +218,7 @@ impl< offset..offset + Self::SIZES[pi_pos] } - pub fn to_range>(query_pi: Q) -> PublicInputRange - { + pub fn to_range>(query_pi: Q) -> PublicInputRange { Self::to_range_internal(query_pi.into()) } @@ -348,7 +354,9 @@ impl< } } -impl PublicInputCommon for PublicInputsFactory<'_, Target, S, UNIVERSAL_CIRCUIT> { +impl PublicInputCommon + for PublicInputsFactory<'_, Target, S, UNIVERSAL_CIRCUIT> +{ const RANGES: &'static [PublicInputRange] = &Self::PI_RANGES; fn register_args(&self, cb: &mut CBuilder) { @@ -368,7 +376,9 @@ impl PublicInputCommon for Public } } -impl PublicInputsFactory<'_, Target, S, UNIVERSAL_CIRCUIT> { +impl + PublicInputsFactory<'_, Target, S, UNIVERSAL_CIRCUIT> +{ pub fn tree_hash_target(&self) -> HashOutTarget { HashOutTarget::try_from(self.to_hash_raw()).unwrap() // safe to unwrap as we know the slice has correct length } @@ -493,13 +503,13 @@ impl PublicInputs<'_, Target, S> { impl PublicInputsUniversalCircuit<'_, Target, S> { pub fn secondary_index_value_target(&self) -> UInt256Target { - // secondary index value is found in `self.min_s` for + // secondary index value is found in `self.min_s` for // `PublicInputsUniversalCircuit` UInt256Target::from_targets(self.min_s) } pub fn primary_index_value_target(&self) -> UInt256Target { - // primary index value is found in `self.max_s` for + // primary index value is found in `self.max_s` for // `PublicInputsUniversalCircuit` UInt256Target::from_targets(self.max_s) } @@ -576,13 +586,13 @@ impl PublicInputs<'_, F, S> { impl PublicInputsUniversalCircuit<'_, F, S> { pub fn secondary_index_value(&self) -> U256 { - // secondary index value is found in `self.min_s` for + // secondary index value is found in `self.min_s` for // `PublicInputsUniversalCircuit` U256::from_fields(self.min_s) } pub fn primary_index_value(&self) -> U256 { - // primary index value is found in `self.max_s` for + // primary index value is found in `self.max_s` for // `PublicInputsUniversalCircuit` U256::from_fields(self.max_s) } diff --git a/verifiable-db/src/query/row_chunk_gadgets/aggregate_chunks.rs b/verifiable-db/src/query/row_chunk_gadgets/aggregate_chunks.rs index c3a64296b..b942968e2 100644 --- a/verifiable-db/src/query/row_chunk_gadgets/aggregate_chunks.rs +++ b/verifiable-db/src/query/row_chunk_gadgets/aggregate_chunks.rs @@ -124,19 +124,20 @@ mod tests { use crate::{ query::{ - utils::{tests::aggregate_output_values, ChildPosition, NodeInfo}, - row_chunk_gadgets::{BoundaryRowData, BoundaryRowNodeInfo, tests::RowChunkData, - BoundaryRowDataTarget, BoundaryRowNodeInfoTarget, RowChunkDataTarget, - }, - public_inputs::PublicInputs, computational_hash_ids::{AggregationOperation, Identifiers}, merkle_path::{ tests::{build_node, generate_test_tree}, - NeighborInfo, MerklePathWithNeighborsGadget, 
MerklePathWithNeighborsTargetInputs, + MerklePathWithNeighborsGadget, MerklePathWithNeighborsTargetInputs, NeighborInfo, + }, + public_inputs::PublicInputs, + row_chunk_gadgets::{ + tests::RowChunkData, BoundaryRowData, BoundaryRowDataTarget, BoundaryRowNodeInfo, + BoundaryRowNodeInfoTarget, RowChunkDataTarget, }, universal_circuit::universal_query_gadget::{ OutputValues, OutputValuesTarget, UniversalQueryOutputWires, }, + utils::{tests::aggregate_output_values, ChildPosition, NodeInfo}, }, test_utils::random_aggregation_operations, }; diff --git a/verifiable-db/src/query/row_chunk_gadgets/consecutive_rows.rs b/verifiable-db/src/query/row_chunk_gadgets/consecutive_rows.rs index 800f6c846..984f6828d 100644 --- a/verifiable-db/src/query/row_chunk_gadgets/consecutive_rows.rs +++ b/verifiable-db/src/query/row_chunk_gadgets/consecutive_rows.rs @@ -235,11 +235,11 @@ mod tests { use rand::thread_rng; use crate::query::{ - utils::{ChildPosition, NodeInfo}, merkle_path::{ tests::{build_node, generate_test_tree}, MerklePathWithNeighborsGadget, MerklePathWithNeighborsTargetInputs, }, + utils::{ChildPosition, NodeInfo}, }; use super::{ diff --git a/verifiable-db/src/query/row_chunk_gadgets/mod.rs b/verifiable-db/src/query/row_chunk_gadgets/mod.rs index ee1f745e1..08c23d9a0 100644 --- a/verifiable-db/src/query/row_chunk_gadgets/mod.rs +++ b/verifiable-db/src/query/row_chunk_gadgets/mod.rs @@ -7,7 +7,8 @@ use alloy::primitives::U256; use mp2_common::{ serialization::circuit_data_serialization::SerializableRichField, - utils::{FromFields, FromTargets, HashBuilder, SelectTarget, ToFields, ToTargets}, F, + utils::{FromFields, FromTargets, HashBuilder, SelectTarget, ToFields, ToTargets}, + F, }; use mp2_test::utils::gen_random_field_hash; use plonky2::{ @@ -17,10 +18,13 @@ use plonky2::{ }; use rand::Rng; -use crate::{query::{ - merkle_path::{MerklePathWithNeighborsTarget, NeighborInfoTarget}, - universal_circuit::universal_query_gadget::UniversalQueryOutputWires, -}, test_utils::gen_values_in_range}; +use crate::{ + query::{ + merkle_path::{MerklePathWithNeighborsTarget, NeighborInfoTarget}, + universal_circuit::universal_query_gadget::UniversalQueryOutputWires, + }, + test_utils::gen_values_in_range, +}; use super::{merkle_path::NeighborInfo, utils::QueryBounds}; @@ -213,7 +217,6 @@ where } } - #[derive(Clone, Debug)] pub(crate) struct BoundaryRowNodeInfo { pub(crate) end_node_hash: HashOut, @@ -237,9 +240,8 @@ impl FromFields for BoundaryRowNodeInfo { assert!(t.len() >= BoundaryRowNodeInfoTarget::NUM_TARGETS); let end_node_hash = HashOut::from_partial(&t[..NUM_HASH_OUT_ELTS]); let predecessor_info = NeighborInfo::from_fields(&t[NUM_HASH_OUT_ELTS..]); - let successor_info = NeighborInfo::from_fields( - &t[NUM_HASH_OUT_ELTS + NeighborInfoTarget::NUM_TARGETS..], - ); + let successor_info = + NeighborInfo::from_fields(&t[NUM_HASH_OUT_ELTS + NeighborInfoTarget::NUM_TARGETS..]); Self { end_node_hash, @@ -250,135 +252,135 @@ impl FromFields for BoundaryRowNodeInfo { } impl BoundaryRowNodeInfo { - /// Generate an instance of `Self` representing a random node, given the `query_bounds` - /// provided as input and a flag `is_index_tree` specifying whether the random node - /// should be part of an index tree or of a rows tree. 
It is used to generate test data - /// without the need to generate an actual tree - pub(crate) fn sample( - rng: &mut R, - query_bounds: &QueryBounds, - is_index_tree: bool, - ) -> Self { - let (min_query_bound, max_query_bound) = if is_index_tree { - ( - query_bounds.min_query_primary(), - query_bounds.max_query_primary(), - ) + /// Generate an instance of `Self` representing a random node, given the `query_bounds` + /// provided as input and a flag `is_index_tree` specifying whether the random node + /// should be part of an index tree or of a rows tree. It is used to generate test data + /// without the need to generate an actual tree + pub(crate) fn sample( + rng: &mut R, + query_bounds: &QueryBounds, + is_index_tree: bool, + ) -> Self { + let (min_query_bound, max_query_bound) = if is_index_tree { + ( + query_bounds.min_query_primary(), + query_bounds.max_query_primary(), + ) + } else { + ( + *query_bounds.min_query_secondary().value(), + *query_bounds.max_query_secondary().value(), + ) + }; + let end_node_hash = gen_random_field_hash(); + let [predecessor_value] = gen_values_in_range( + rng, + if is_index_tree { + min_query_bound // predecessor in index tree must always be in range } else { - ( - *query_bounds.min_query_secondary().value(), - *query_bounds.max_query_secondary().value(), - ) - }; - let end_node_hash = gen_random_field_hash(); - let [predecessor_value] = gen_values_in_range( - rng, - if is_index_tree { - min_query_bound // predecessor in index tree must always be in range - } else { - U256::ZERO - }, - max_query_bound, // predecessor value must always be smaller than max_secondary in circuit - ); - let predecessor_info = NeighborInfo::sample( - rng, - predecessor_value, - if is_index_tree { - // in index tree, there must always be a predecessor for boundary rows - Some(true) - } else { - None - }, - ); - let [successor_value] = gen_values_in_range( - rng, - predecessor_value.max(min_query_bound), // successor value must - // always be greater than min_secondary in circuit, and it must be also - // greater than predecessor value since we are in a BST - if is_index_tree { - max_query_bound // successor in index tree must always be in range - } else { - U256::MAX - }, - ); - let successor_info = NeighborInfo::sample( - rng, - successor_value, - if is_index_tree { - // in index tree, there must always be a successor for boundary rows - Some(true) - } else { - None - }, - ); + U256::ZERO + }, + max_query_bound, // predecessor value must always be smaller than max_secondary in circuit + ); + let predecessor_info = NeighborInfo::sample( + rng, + predecessor_value, + if is_index_tree { + // in index tree, there must always be a predecessor for boundary rows + Some(true) + } else { + None + }, + ); + let [successor_value] = gen_values_in_range( + rng, + predecessor_value.max(min_query_bound), // successor value must + // always be greater than min_secondary in circuit, and it must be also + // greater than predecessor value since we are in a BST + if is_index_tree { + max_query_bound // successor in index tree must always be in range + } else { + U256::MAX + }, + ); + let successor_info = NeighborInfo::sample( + rng, + successor_value, + if is_index_tree { + // in index tree, there must always be a successor for boundary rows + Some(true) + } else { + None + }, + ); - Self { - end_node_hash, - predecessor_info, - successor_info, - } + Self { + end_node_hash, + predecessor_info, + successor_info, } + } - /// Given a boundary node with info stored in `self`, this method 
generates at random the - /// information about a node that can be the successor of `self` in a BST. This method - /// requires as additional inputs the `query_bounds` and a flag `is_index_tree`, which - /// specifies whether `self` and the generated node should be part of an index tree or - /// of a rows tree - pub(crate) fn sample_successor_in_tree( - &self, - rng: &mut R, - query_bounds: &QueryBounds, - is_index_tree: bool, - ) -> Self { - let (min_query_bound, max_query_bound) = if is_index_tree { - ( - query_bounds.min_query_primary(), - query_bounds.max_query_primary(), - ) + /// Given a boundary node with info stored in `self`, this method generates at random the + /// information about a node that can be the successor of `self` in a BST. This method + /// requires as additional inputs the `query_bounds` and a flag `is_index_tree`, which + /// specifies whether `self` and the generated node should be part of an index tree or + /// of a rows tree + pub(crate) fn sample_successor_in_tree( + &self, + rng: &mut R, + query_bounds: &QueryBounds, + is_index_tree: bool, + ) -> Self { + let (min_query_bound, max_query_bound) = if is_index_tree { + ( + query_bounds.min_query_primary(), + query_bounds.max_query_primary(), + ) + } else { + ( + *query_bounds.min_query_secondary().value(), + *query_bounds.max_query_secondary().value(), + ) + }; + let end_node_hash = self.successor_info.hash; + // value of predecessor must be in query range and between the predecessor and successor value + // of `self` + let [predecessor_value] = gen_values_in_range( + rng, + min_query_bound.max(self.predecessor_info.value), + self.successor_info.value.min(max_query_bound), + ); + let predecessor_info = if self.successor_info.is_in_path { + NeighborInfo::new(predecessor_value, None) + } else { + NeighborInfo::new(predecessor_value, Some(self.end_node_hash)) + }; + let [successor_value] = gen_values_in_range( + rng, + predecessor_value.max(min_query_bound), + if is_index_tree { + max_query_bound // successor must always be in range in index tree } else { - ( - *query_bounds.min_query_secondary().value(), - *query_bounds.max_query_secondary().value(), - ) - }; - let end_node_hash = self.successor_info.hash; - // value of predecessor must be in query range and between the predecessor and successor value - // of `self` - let [predecessor_value] = gen_values_in_range( - rng, - min_query_bound.max(self.predecessor_info.value), - self.successor_info.value.min(max_query_bound), - ); - let predecessor_info = if self.successor_info.is_in_path { - NeighborInfo::new(predecessor_value, None) + U256::MAX + }, + ); + let successor_info = NeighborInfo::sample( + rng, + successor_value, + if is_index_tree { + // in index tree, there must always be a successor for boundary rows + Some(true) } else { - NeighborInfo::new(predecessor_value, Some(self.end_node_hash)) - }; - let [successor_value] = gen_values_in_range( - rng, - predecessor_value.max(min_query_bound), - if is_index_tree { - max_query_bound // successor must always be in range in index tree - } else { - U256::MAX - }, - ); - let successor_info = NeighborInfo::sample( - rng, - successor_value, - if is_index_tree { - // in index tree, there must always be a successor for boundary rows - Some(true) - } else { - None - }, - ); - BoundaryRowNodeInfo { - end_node_hash, - predecessor_info, - successor_info, - } + None + }, + ); + BoundaryRowNodeInfo { + end_node_hash, + predecessor_info, + successor_info, } + } } #[derive(Clone, Debug)] @@ -481,14 +483,8 @@ impl BoundaryRowData 
{ #[cfg(test)] pub(crate) mod tests { - use mp2_common::{ - utils::ToFields, - F, - }; - use plonky2::{ - field::types::Field, - hash::hash_types::HashOut, - }; + use mp2_common::{utils::ToFields, F}; + use plonky2::{field::types::Field, hash::hash_types::HashOut}; use crate::query::universal_circuit::universal_query_gadget::OutputValues; diff --git a/verifiable-db/src/query/row_chunk_gadgets/row_process_gadget.rs b/verifiable-db/src/query/row_chunk_gadgets/row_process_gadget.rs index b821f50df..30b9c84b6 100644 --- a/verifiable-db/src/query/row_chunk_gadgets/row_process_gadget.rs +++ b/verifiable-db/src/query/row_chunk_gadgets/row_process_gadget.rs @@ -6,6 +6,7 @@ use plonky2::iop::witness::PartialWitness; use serde::{Deserialize, Serialize}; use crate::query::{ + api::RowInput, merkle_path::{MerklePathWithNeighborsGadget, MerklePathWithNeighborsTargetInputs}, universal_circuit::{ universal_circuit_inputs::RowCells, @@ -14,7 +15,6 @@ use crate::query::{ UniversalQueryValueInputs, UniversalQueryValueWires, }, }, - api::RowInput, }; use super::{BoundaryRowDataTarget, BoundaryRowNodeInfoTarget, RowChunkDataTarget}; diff --git a/verifiable-db/src/query/universal_circuit/universal_query_circuit.rs b/verifiable-db/src/query/universal_circuit/universal_query_circuit.rs index b56aaf68d..edecea6b6 100644 --- a/verifiable-db/src/query/universal_circuit/universal_query_circuit.rs +++ b/verifiable-db/src/query/universal_circuit/universal_query_circuit.rs @@ -1,18 +1,35 @@ use std::iter::once; use crate::query::{ - computational_hash_ids::{Output, PlaceholderIdentifier}, pi_len, public_inputs::PublicInputsUniversalCircuit, row_chunk_gadgets::BoundaryRowDataTarget, utils::QueryBounds + computational_hash_ids::{Output, PlaceholderIdentifier}, + pi_len, + public_inputs::PublicInputsUniversalCircuit, + row_chunk_gadgets::BoundaryRowDataTarget, + utils::QueryBounds, }; use anyhow::Result; use itertools::Itertools; use mp2_common::{ - array::ToField, poseidon::{empty_poseidon_hash, HashPermutation}, public_inputs::PublicInputCommon, serialization::{deserialize, serialize}, types::CBuilder, utils::{FromTargets, HashBuilder, ToFields, ToTargets}, CHasher, C, D, F + array::ToField, + poseidon::{empty_poseidon_hash, HashPermutation}, + public_inputs::PublicInputCommon, + serialization::{deserialize, serialize}, + types::CBuilder, + utils::{FromTargets, HashBuilder, ToFields, ToTargets}, + CHasher, C, D, F, }; use plonky2::{ - field::types::Field, hash::hashing::hash_n_to_hash_no_pad, iop::{ + field::types::Field, + hash::hashing::hash_n_to_hash_no_pad, + iop::{ target::{BoolTarget, Target}, witness::{PartialWitness, WitnessWrite}, - }, plonk::{circuit_builder::CircuitBuilder, circuit_data::{CircuitConfig, CircuitData}, proof::{ProofWithPublicInputs, ProofWithPublicInputsTarget}} + }, + plonk::{ + circuit_builder::CircuitBuilder, + circuit_data::{CircuitConfig, CircuitData}, + proof::{ProofWithPublicInputs, ProofWithPublicInputsTarget}, + }, }; use recursion_framework::circuit_builder::CircuitLogicWires; use serde::{de::DeserializeOwned, Deserialize, Serialize}; @@ -157,7 +174,9 @@ where .chain(empty_hash.elements.iter()) .chain(node_min.to_targets().iter()) .chain(node_max.to_targets().iter()) - .chain(once(&hash_wires.input_wires.column_extraction_wires.column_ids[1])) + .chain(once( + &hash_wires.input_wires.column_extraction_wires.column_ids[1], + )) .chain(node_min.to_targets().iter()) .chain(value_wires.output_wires.tree_hash.elements.iter()) .cloned() @@ -173,7 +192,8 @@ where // compute dummy left 
boundary and right boundary rows to be exposed as public inputs; // they are ignored by the circuits processing this proof, so it's ok to use dummy // values - let dummy_boundary_row_targets = b.constants(&vec![F::ZERO; BoundaryRowDataTarget::NUM_TARGETS]); + let dummy_boundary_row_targets = + b.constants(&vec![F::ZERO; BoundaryRowDataTarget::NUM_TARGETS]); let primary_index_value = &value_wires.input_wires.column_values[0]; PublicInputsUniversalCircuit::::new( &tree_hash.to_targets(), @@ -318,19 +338,32 @@ pub struct UniversalQueryCircuitParams< const MAX_NUM_RESULTS: usize, T: OutputComponent + Serialize, > { - #[serde(serialize_with="serialize", deserialize_with="deserialize")] + #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] pub(crate) data: CircuitData, - wires: UniversalQueryCircuitWires, + wires: UniversalQueryCircuitWires< + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + T, + >, } impl< - const MAX_NUM_COLUMNS: usize, - const MAX_NUM_PREDICATE_OPS: usize, - const MAX_NUM_RESULT_OPS: usize, - const MAX_NUM_RESULTS: usize, - T: OutputComponent + Serialize + DeserializeOwned, -> UniversalQueryCircuitParams -where + const MAX_NUM_COLUMNS: usize, + const MAX_NUM_PREDICATE_OPS: usize, + const MAX_NUM_RESULT_OPS: usize, + const MAX_NUM_RESULTS: usize, + T: OutputComponent + Serialize + DeserializeOwned, + > + UniversalQueryCircuitParams< + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + T, + > +where [(); MAX_NUM_RESULTS - 1]:, [(); MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS]:, { @@ -338,19 +371,18 @@ where let mut builder = CBuilder::new(config); let wires = UniversalQueryCircuitInputs::build(&mut builder); let data = builder.build(); - Self { - data, - wires, - } + Self { data, wires } } - pub(crate) fn generate_proof(&self, input: &UniversalQueryCircuitInputs< + pub(crate) fn generate_proof( + &self, + input: &UniversalQueryCircuitInputs< MAX_NUM_COLUMNS, MAX_NUM_PREDICATE_OPS, MAX_NUM_RESULT_OPS, MAX_NUM_RESULTS, T, - > + >, ) -> Result> { let mut pw = PartialWitness::::new(); input.assign(&mut pw, &self.wires); @@ -481,7 +513,12 @@ mod tests { use alloy::primitives::U256; use itertools::Itertools; use mp2_common::{ - array::ToField, default_config, group_hashing::map_to_curve_point, poseidon::empty_poseidon_hash, utils::{FromFields, ToFields, TryIntoBool}, C, D, F + array::ToField, + default_config, + group_hashing::map_to_curve_point, + poseidon::empty_poseidon_hash, + utils::{FromFields, ToFields, TryIntoBool}, + C, D, F, }; use mp2_test::{ cells_tree::{compute_cells_tree_hash, TestCell}, @@ -499,24 +536,25 @@ mod tests { use rand::{thread_rng, Rng}; use crate::query::{ - utils::{QueryBoundSource, QueryBounds}, computational_hash_ids::{ AggregationOperation, ColumnIDs, HashPermutation, Identifiers, Operation, PlaceholderIdentifier, }, public_inputs::PublicInputsUniversalCircuit, universal_circuit::{ - output_no_aggregation::Circuit as OutputNoAggCircuit, output_with_aggregation::Circuit as OutputAggCircuit, universal_circuit_inputs::{ + output_no_aggregation::Circuit as OutputNoAggCircuit, + output_with_aggregation::Circuit as OutputAggCircuit, + universal_circuit_inputs::{ BasicOperation, ColumnCell, InputOperand, OutputItem, PlaceholderId, Placeholders, ResultStructure, RowCells, - }, universal_query_circuit::{placeholder_hash, UniversalQueryCircuitParams}, ComputationalHash + }, + universal_query_circuit::{placeholder_hash, UniversalQueryCircuitParams}, + ComputationalHash, }, + 
utils::{QueryBoundSource, QueryBounds}, }; - use super::{ - OutputComponent, UniversalQueryCircuitInputs, - UniversalQueryCircuitWires, - }; + use super::{OutputComponent, UniversalQueryCircuitInputs, UniversalQueryCircuitWires}; impl< const MAX_NUM_COLUMNS: usize, @@ -574,17 +612,21 @@ mod tests { // ensure that primary index column value is in the range specified by the query: // we sample a random u256 in range [0, max_query - min_query) and then we // add min_query - gen_random_u256(rng).div_rem(max_query_primary - min_query_primary + U256::from(1)).1 + min_query_primary - }, + gen_random_u256(rng) + .div_rem(max_query_primary - min_query_primary + U256::from(1)) + .1 + + min_query_primary + } 1 => { // ensure that second column value is in the range specified by the query: // we sample a random u256 in range [0, max_query - min_query) and then we // add min_query - gen_random_u256(rng).div_rem(max_query_secondary - min_query_secondary + U256::from(1)).1 + min_query_secondary - }, - _ => { gen_random_u256(rng) - }, + .div_rem(max_query_secondary - min_query_secondary + U256::from(1)) + .1 + + min_query_secondary + } + _ => gen_random_u256(rng), } }) .collect_vec(); @@ -602,10 +644,7 @@ mod tests { // define placeholders let first_placeholder_id = PlaceholderId::Generic(0); let second_placeholder_id = PlaceholderIdentifier::Generic(1); - let mut placeholders = Placeholders::new_empty( - min_query_primary, - max_query_primary, - ); + let mut placeholders = Placeholders::new_empty(min_query_primary, max_query_primary); [first_placeholder_id, second_placeholder_id] .iter() .for_each(|id| placeholders.insert(*id, gen_random_u256(rng))); @@ -875,17 +914,14 @@ mod tests { .into(), ); let proof = if build_parameters { - let params = UniversalQueryCircuitParams::build( - default_config() - ); - params - .generate_proof(&circuit) - .unwrap() + let params = UniversalQueryCircuitParams::build(default_config()); + params.generate_proof(&circuit).unwrap() } else { run_circuit::(circuit.clone()) }; - let pi = PublicInputsUniversalCircuit::<_, MAX_NUM_RESULTS>::from_slice(&proof.public_inputs); + let pi = + PublicInputsUniversalCircuit::<_, MAX_NUM_RESULTS>::from_slice(&proof.public_inputs); assert_eq!(tree_hash, pi.tree_hash()); assert_eq!(output_values[0], pi.first_value_as_u256()); assert_eq!(output_values[1..], pi.values()[..output_values.len() - 1]); @@ -934,17 +970,21 @@ mod tests { // ensure that primary index column value is in the range specified by the query: // we sample a random u256 in range [0, max_query - min_query) and then we // add min_query - gen_random_u256(rng).div_rem(max_query_primary - min_query_primary + U256::from(1)).1 + min_query_primary - }, + gen_random_u256(rng) + .div_rem(max_query_primary - min_query_primary + U256::from(1)) + .1 + + min_query_primary + } 1 => { // ensure that second column value is in the range specified by the query: // we sample a random u256 in range [0, max_query - min_query) and then we // add min_query - gen_random_u256(rng).div_rem(max_query_secondary - min_query_secondary + U256::from(1)).1 + min_query_secondary - }, - _ => { gen_random_u256(rng) - }, + .div_rem(max_query_secondary - min_query_secondary + U256::from(1)) + .1 + + min_query_secondary + } + _ => gen_random_u256(rng), } }) .collect_vec(); @@ -962,10 +1002,7 @@ mod tests { // define placeholders let first_placeholder_id = PlaceholderId::Generic(0); let second_placeholder_id = PlaceholderIdentifier::Generic(1); - let mut placeholders = Placeholders::new_empty( - min_query_primary, - 
max_query_primary, - ); + let mut placeholders = Placeholders::new_empty(min_query_primary, max_query_primary); [first_placeholder_id, second_placeholder_id] .iter() .for_each(|id| placeholders.insert(*id, gen_random_u256(rng))); @@ -1271,17 +1308,14 @@ mod tests { ); let proof = if build_parameters { - let params = UniversalQueryCircuitParams::build( - default_config() - ); - params - .generate_proof(&circuit) - .unwrap() + let params = UniversalQueryCircuitParams::build(default_config()); + params.generate_proof(&circuit).unwrap() } else { run_circuit::(circuit.clone()) }; - let pi = PublicInputsUniversalCircuit::<_, MAX_NUM_RESULTS>::from_slice(&proof.public_inputs); + let pi = + PublicInputsUniversalCircuit::<_, MAX_NUM_RESULTS>::from_slice(&proof.public_inputs); assert_eq!(tree_hash, pi.tree_hash()); assert_eq!(output_acc.to_weierstrass(), pi.first_value_as_curve_point()); // The other MAX_NUM_RESULTS -1 output values are dummy ones, as in queries diff --git a/verifiable-db/src/query/universal_circuit/universal_query_gadget.rs b/verifiable-db/src/query/universal_circuit/universal_query_gadget.rs index ef3df751e..97eeedaf2 100644 --- a/verifiable-db/src/query/universal_circuit/universal_query_gadget.rs +++ b/verifiable-db/src/query/universal_circuit/universal_query_gadget.rs @@ -28,7 +28,6 @@ use plonky2_ecgfp5::{curve::curve::WeierstrassPoint, gadgets::curve::CurveTarget use serde::{Deserialize, Serialize}; use crate::query::{ - utils::{QueryBoundSecondary, QueryBoundSource, QueryBounds}, computational_hash_ids::{ ColumnIDs, ComputationalHashCache, HashPermutation, Operation, Output, PlaceholderIdentifier, @@ -37,6 +36,7 @@ use crate::query::{ basic_operation::BasicOperationInputs, column_extraction::ColumnExtractionValueWires, universal_circuit_inputs::OutputItem, }, + utils::{QueryBoundSecondary, QueryBoundSource, QueryBounds}, }; use super::{ @@ -163,6 +163,7 @@ impl QueryBound { /// Initialize a query bound for the primary index, from the set of `placeholders` employed in the query, /// which include also the primary index bounds by construction. 
The flag `is_min_bound` /// must be true iff the bound to be initialized is a lower bound in the range specified in the query + #[allow(dead_code)] // unused for now, but it could be useful to keep it pub(crate) fn new_primary_index_bound( placeholders: &Placeholders, is_min_bound: bool, @@ -765,8 +766,11 @@ where let min_secondary = min_query_secondary.get_bound_value().clone(); let max_secondary = max_query_secondary.get_bound_value().clone(); - let num_bound_overflows = - QueryBoundTarget::num_overflows_for_query_bound_operations(b, &min_query_secondary, &max_query_secondary); + let num_bound_overflows = QueryBoundTarget::num_overflows_for_query_bound_operations( + b, + &min_query_secondary, + &max_query_secondary, + ); UniversalQueryHashWires { input_wires: UniversalQueryHashInputWires { column_extraction_wires: column_extraction_wires.input_wires, @@ -806,8 +810,12 @@ where .assign(pw, &wires.column_extraction_wires); pw.set_u256_target(&wires.min_query_primary, self.min_query_primary); pw.set_u256_target(&wires.max_query_primary, self.max_query_primary); - wires.min_query_secondary.assign(pw, &self.min_query_secondary); - wires.max_query_secondary.assign(pw, &self.max_query_secondary); + wires + .min_query_secondary + .assign(pw, &self.min_query_secondary); + wires + .max_query_secondary + .assign(pw, &self.max_query_secondary); self.filtering_predicate_inputs .iter() .chain(self.result_values_inputs.iter()) @@ -1309,11 +1317,12 @@ where // Enforce that the value of primary index for the current row is in the range given by these bounds let index_value = &column_values[0]; - let less_than_max = b.is_less_or_equal_than_u256(index_value, &hash_input_wires.max_query_primary); - let greater_than_min = b.is_less_or_equal_than_u256(&hash_input_wires.min_query_primary, index_value); + let less_than_max = + b.is_less_or_equal_than_u256(index_value, &hash_input_wires.max_query_primary); + let greater_than_min = + b.is_less_or_equal_than_u256(&hash_input_wires.min_query_primary, index_value); b.connect(less_than_max.target, _true.target); b.connect(greater_than_min.target, _true.target); - // min and max for secondary indexed column let node_min = &column_values[1]; diff --git a/verifiable-db/src/query/utils.rs b/verifiable-db/src/query/utils.rs index bf85c0c93..15ba24110 100644 --- a/verifiable-db/src/query/utils.rs +++ b/verifiable-db/src/query/utils.rs @@ -27,12 +27,13 @@ use plonky2::{ }; use serde::{Deserialize, Serialize}; - use super::{ computational_hash_ids::{ColumnIDs, Identifiers, PlaceholderIdentifier}, universal_circuit::{ universal_circuit_inputs::{BasicOperation, PlaceholderId, Placeholders, ResultStructure}, - universal_query_circuit::{placeholder_hash, placeholder_hash_without_query_bounds, UniversalCircuitInput}, + universal_query_circuit::{ + placeholder_hash, placeholder_hash_without_query_bounds, UniversalCircuitInput, + }, universal_query_gadget::QueryBound, ComputationalHash, PlaceholderHash, }, diff --git a/verifiable-db/src/results_tree/binding/binding_results.rs b/verifiable-db/src/results_tree/binding/binding_results.rs index 1bbb2e41f..b177a37f3 100644 --- a/verifiable-db/src/results_tree/binding/binding_results.rs +++ b/verifiable-db/src/results_tree/binding/binding_results.rs @@ -99,10 +99,13 @@ impl BindingResultsCircuit { mod tests { use super::*; use crate::{ - results_tree::{construction::{ - public_inputs::ResultsConstructionPublicInputs, - tests::{pi_len, random_results_construction_public_inputs}, - }, tests::random_aggregation_public_inputs}, + 
results_tree::{ + construction::{ + public_inputs::ResultsConstructionPublicInputs, + tests::{pi_len, random_results_construction_public_inputs}, + }, + tests::random_aggregation_public_inputs, + }, test_utils::random_aggregation_operations, }; use itertools::Itertools; diff --git a/verifiable-db/src/results_tree/mod.rs b/verifiable-db/src/results_tree/mod.rs index 62b718052..443a8fd90 100644 --- a/verifiable-db/src/results_tree/mod.rs +++ b/verifiable-db/src/results_tree/mod.rs @@ -1,6 +1,6 @@ pub(crate) mod binding; pub(crate) mod construction; -/// Old query public inputs, moved here because the circuits in this module still expects +/// Old query public inputs, moved here because the circuits in this module still expects /// these public inputs for now pub(crate) mod old_public_inputs; @@ -9,7 +9,10 @@ pub(crate) mod tests { use std::array; use mp2_common::{array::ToField, types::CURVE_TARGET_LEN, utils::ToFields, F}; - use plonky2::{field::types::{Field, Sample}, hash::hash_types::NUM_HASH_OUT_ELTS}; + use plonky2::{ + field::types::{Field, Sample}, + hash::hash_types::NUM_HASH_OUT_ELTS, + }; use plonky2_ecgfp5::curve::curve::Point; use rand::{thread_rng, Rng}; @@ -32,7 +35,8 @@ pub(crate) mod tests { ] .map(PublicInputs::::to_range); - let first_value_start = PublicInputs::::to_range(QueryPublicInputs::OutputValues).start; + let first_value_start = + PublicInputs::::to_range(QueryPublicInputs::OutputValues).start; let is_first_op_id = ops[0] == Identifiers::AggregationOperations(AggregationOperation::IdOp).to_field(); @@ -48,7 +52,7 @@ pub(crate) mod tests { }); array::from_fn(|_| { - let mut pi = (0..PublicInputs::::total_len()) + let mut pi = (0..PublicInputs::::total_len()) .map(|_| rng.gen()) .collect::>() .to_fields(); diff --git a/verifiable-db/src/results_tree/old_public_inputs.rs b/verifiable-db/src/results_tree/old_public_inputs.rs index 7f6d07b00..5eb805638 100644 --- a/verifiable-db/src/results_tree/old_public_inputs.rs +++ b/verifiable-db/src/results_tree/old_public_inputs.rs @@ -1,3 +1,5 @@ +#![allow(dead_code)] + use std::iter::once; use alloy::primitives::U256; diff --git a/verifiable-db/src/revelation/api.rs b/verifiable-db/src/revelation/api.rs index ec2264d47..c07f032b1 100644 --- a/verifiable-db/src/revelation/api.rs +++ b/verifiable-db/src/revelation/api.rs @@ -12,7 +12,9 @@ use mp2_common::{ C, D, F, }; use plonky2::plonk::{ - circuit_data::{VerifierCircuitData, VerifierOnlyCircuitData}, config::Hasher, proof::ProofWithPublicInputs, + circuit_data::{VerifierCircuitData, VerifierOnlyCircuitData}, + config::Hasher, + proof::ProofWithPublicInputs, }; use recursion_framework::{ circuit_builder::{CircuitWithUniversalVerifier, CircuitWithUniversalVerifierBuilder}, @@ -24,15 +26,21 @@ use serde::{Deserialize, Serialize}; use crate::{ query::{ - utils::QueryBounds, computational_hash_ids::ColumnIDs, pi_len as query_pi_len, universal_circuit::{output_no_aggregation::Circuit as OutputNoAggCircuit, universal_circuit_inputs::{ - BasicOperation, Placeholders, ResultStructure, - }, universal_query_circuit::{UniversalCircuitInput, UniversalQueryCircuitParams}} + computational_hash_ids::ColumnIDs, + pi_len as query_pi_len, + universal_circuit::{ + output_no_aggregation::Circuit as OutputNoAggCircuit, + universal_circuit_inputs::{BasicOperation, Placeholders, ResultStructure}, + universal_query_circuit::{UniversalCircuitInput, UniversalQueryCircuitParams}, + }, + utils::QueryBounds, }, revelation::{ placeholders_check::CheckPlaceholderGadget, revelation_unproven_offset::{ 
generate_dummy_row_proof_inputs, RecursiveCircuitWires as RecursiveCircuitWiresUnprovenOffset, + TabularQueryOutputModifiers, }, }, }; @@ -40,11 +48,12 @@ use crate::{ use super::{ pi_len, revelation_unproven_offset::{ - CircuitBuilderParams, RecursiveCircuitInputs as RecursiveCircuitInputsUnporvenOffset, RevelationCircuit as RevelationCircuitUnprovenOffset, RowPath + CircuitBuilderParams, RecursiveCircuitInputs as RecursiveCircuitInputsUnporvenOffset, + RevelationCircuit as RevelationCircuitUnprovenOffset, RowPath, }, revelation_without_results_tree::{ - CircuitBuilderParams as CircuitBuilderParamsNoResultsTree, RecursiveCircuitInputs, RecursiveCircuitWires, - RevelationWithoutResultsTreeCircuit, + CircuitBuilderParams as CircuitBuilderParamsNoResultsTree, RecursiveCircuitInputs, + RecursiveCircuitWires, RevelationWithoutResultsTreeCircuit, }, }; #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] @@ -391,9 +400,11 @@ where [column_ids.primary, column_ids.secondary], &results_structure.output_ids, result_values, - limit, - offset, - results_structure.distinct.unwrap_or(false), + TabularQueryOutputModifiers::new( + limit, + offset, + results_structure.distinct.unwrap_or(false), + ), placeholder_inputs, )?; @@ -570,10 +581,13 @@ where #[cfg(test)] mod tests { - use crate::{query::pi_len as query_pi_len, test_utils::{ - TestRevelationData, MAX_NUM_COLUMNS, MAX_NUM_ITEMS_PER_OUTPUT, MAX_NUM_OUTPUTS, - MAX_NUM_PLACEHOLDERS, MAX_NUM_PREDICATE_OPS, MAX_NUM_RESULT_OPS, - }}; + use crate::{ + query::pi_len as query_pi_len, + test_utils::{ + TestRevelationData, MAX_NUM_COLUMNS, MAX_NUM_ITEMS_PER_OUTPUT, MAX_NUM_OUTPUTS, + MAX_NUM_PLACEHOLDERS, MAX_NUM_PREDICATE_OPS, MAX_NUM_RESULT_OPS, + }, + }; use itertools::Itertools; use mp2_common::{ array::ToField, @@ -619,7 +633,8 @@ mod tests { >::default(); let preprocessing_circuits = TestingRecursiveCircuits::::default(); - let dummy_universal_circuit = TestDummyCircuit::<{query_pi_len::()}>::build(); + let dummy_universal_circuit = + TestDummyCircuit::<{ query_pi_len::() }>::build(); println!("building params"); let params = Parameters::< ROW_TREE_MAX_DEPTH, @@ -668,11 +683,7 @@ mod tests { ) .unwrap(); let proof = params - .generate_proof( - input, - query_circuits.get_recursive_circuit_set(), - None, - ) + .generate_proof(input, query_circuits.get_recursive_circuit_set(), None) .unwrap(); let (proof, _) = ProofWithVK::deserialize(&proof).unwrap().into(); let pi = PublicInputs::::from_slice(&proof.public_inputs); diff --git a/verifiable-db/src/revelation/mod.rs b/verifiable-db/src/revelation/mod.rs index d27c49e06..22d5dcb46 100644 --- a/verifiable-db/src/revelation/mod.rs +++ b/verifiable-db/src/revelation/mod.rs @@ -3,7 +3,6 @@ use crate::ivc::NUM_IO; use mp2_common::F; - pub mod api; pub(crate) mod placeholders_check; mod public_inputs; diff --git a/verifiable-db/src/revelation/placeholders_check.rs b/verifiable-db/src/revelation/placeholders_check.rs index e5df6a249..928ab3a68 100644 --- a/verifiable-db/src/revelation/placeholders_check.rs +++ b/verifiable-db/src/revelation/placeholders_check.rs @@ -2,12 +2,12 @@ //! compute and return the `num_placeholders` and the `placeholder_ids_hash`. 
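The `TabularQueryOutputModifiers::new(limit, offset, results_structure.distinct.unwrap_or(false))` call in the api.rs hunk above replaces three loose positional arguments with a single parameter object. A minimal standalone sketch of that pattern, mirroring the struct that `revelation_unproven_offset.rs` defines further below (only the `main` function here is invented for illustration):

```rust
// Grouping limit/offset/distinct behind one type keeps call sites from
// silently swapping the two adjacent `u32` arguments.
#[derive(Clone, Debug)]
struct TabularQueryOutputModifiers {
    limit: u32,
    offset: u32,
    /// Whether the DISTINCT keyword must be applied to the results
    distinct: bool,
}

impl TabularQueryOutputModifiers {
    fn new(limit: u32, offset: u32, distinct: bool) -> Self {
        Self {
            limit,
            offset,
            distinct,
        }
    }
}

fn main() {
    let modifiers = TabularQueryOutputModifiers::new(10, 0, true);
    println!("{modifiers:?}");
}
```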
use crate::query::{ - utils::QueryBounds, computational_hash_ids::PlaceholderIdentifier, universal_circuit::{ universal_circuit_inputs::{PlaceholderId, Placeholders}, universal_query_gadget::QueryBound, }, + utils::QueryBounds, }; use alloy::primitives::U256; use anyhow::{ensure, Result}; diff --git a/verifiable-db/src/revelation/revelation_unproven_offset.rs b/verifiable-db/src/revelation/revelation_unproven_offset.rs index 55ab42c0a..86481c69c 100644 --- a/verifiable-db/src/revelation/revelation_unproven_offset.rs +++ b/verifiable-db/src/revelation/revelation_unproven_offset.rs @@ -34,25 +34,32 @@ use plonky2::{ witness::{PartialWitness, WitnessWrite}, }, plonk::{ - circuit_data::{VerifierCircuitData, VerifierOnlyCircuitData}, config::Hasher, proof::{ProofWithPublicInputs, ProofWithPublicInputsTarget} + circuit_data::{VerifierCircuitData, VerifierOnlyCircuitData}, + config::Hasher, + proof::{ProofWithPublicInputs, ProofWithPublicInputsTarget}, }, }; use plonky2_ecgfp5::gadgets::curve::CircuitBuilderEcGFp5; use recursion_framework::{ circuit_builder::CircuitLogicWires, - framework::{ - RecursiveCircuits, RecursiveCircuitsVerifierGagdet, - }, + framework::{RecursiveCircuits, RecursiveCircuitsVerifierGagdet}, }; use serde::{Deserialize, Serialize}; use crate::{ ivc::PublicInputs as OriginalTreePublicInputs, query::{ - utils::{ChildPosition, NodeInfo, QueryBounds}, public_inputs::PublicInputsUniversalCircuit as QueryProofPublicInputs, computational_hash_ids::{ColumnIDs, ResultIdentifier}, merkle_path::{MerklePathGadget, MerklePathTargetInputs}, universal_circuit::{ + computational_hash_ids::{ColumnIDs, ResultIdentifier}, + merkle_path::{MerklePathGadget, MerklePathTargetInputs}, + public_inputs::PublicInputsUniversalCircuit as QueryProofPublicInputs, + universal_circuit::{ build_cells_tree, - universal_circuit_inputs::{BasicOperation, ColumnCell, Placeholders, ResultStructure, RowCells}, universal_query_circuit::{UniversalCircuitInput, UniversalQueryCircuitInputs}, - } + universal_circuit_inputs::{ + BasicOperation, ColumnCell, Placeholders, ResultStructure, RowCells, + }, + universal_query_circuit::{UniversalCircuitInput, UniversalQueryCircuitInputs}, + }, + utils::{ChildPosition, NodeInfo, QueryBounds}, }, }; @@ -145,6 +152,26 @@ impl NodeInfoTarget { } } +/// Data structure containing the parameters found in tabular +/// queries that specify which outputs should be returned +#[derive(Clone, Debug)] +pub(crate) struct TabularQueryOutputModifiers { + limit: u32, + offset: u32, + /// Boolean flag specifying whether DISTINCT keyword must be applied to results + distinct: bool, +} + +impl TabularQueryOutputModifiers { + pub(crate) fn new(limit: u32, offset: u32, distinct: bool) -> Self { + Self { + limit, + offset, + distinct, + } + } +} + #[derive(Clone, Debug, Serialize, Deserialize)] pub(crate) struct RevelationWires< const ROW_TREE_MAX_DEPTH: usize, @@ -287,9 +314,7 @@ where index_column_ids: [F; 2], item_ids: &[F], results: [Vec; L], - limit: u32, - offset: u32, - distinct: bool, + query_modifiers: TabularQueryOutputModifiers, placeholder_inputs: CheckPlaceholderGadget, ) -> Result { let mut row_tree_paths = [MerklePathGadget::::default(); L]; @@ -337,9 +362,9 @@ where num_actual_items_per_row, ids: padded_ids.try_into().unwrap(), results: results.try_into().unwrap(), - limit, - offset, - distinct, + limit: query_modifiers.limit, + offset: query_modifiers.offset, + distinct: query_modifiers.distinct, check_placeholder_inputs: placeholder_inputs, }) } @@ -481,14 +506,8 @@ where 
b.connect_hashes(row_proof.computational_hash_target(), computational_hash); b.connect_hashes(row_proof.placeholder_hash_target(), placeholder_hash); // check that query bounds on primary index are the same for all the proofs - b.enforce_equal_u256( - &row_proof.min_primary_target(), - &min_query_primary, - ); - b.enforce_equal_u256( - &row_proof.max_primary_target(), - &max_query_primary, - ); + b.enforce_equal_u256(&row_proof.min_primary_target(), &min_query_primary); + b.enforce_equal_u256(&row_proof.max_primary_target(), &max_query_primary); overflow = b.or(overflow, row_proof.overflow_flag_target()); }); @@ -663,13 +682,15 @@ where value: U256::default(), id: column_ids.secondary, }; - let non_indexed_columns = column_ids.non_indexed_columns().iter().map(|id| - ColumnCell::new(*id, U256::default()) - ).collect_vec(); + let non_indexed_columns = column_ids + .non_indexed_columns() + .iter() + .map(|id| ColumnCell::new(*id, U256::default())) + .collect_vec(); let cells = RowCells::new( primary_index_column, secondary_index_column, - non_indexed_columns + non_indexed_columns, ); let universal_query_circuit = UniversalQueryCircuitInputs::new( &cells, @@ -680,11 +701,7 @@ where results, true, // we generate proof for a dummy row )?; - Ok( - UniversalCircuitInput::QueryNoAgg( - universal_query_circuit - ) - ) + Ok(UniversalCircuitInput::QueryNoAgg(universal_query_circuit)) } pub struct CircuitBuilderParams { @@ -765,10 +782,8 @@ where _verified_proofs: [&ProofWithPublicInputsTarget; 0], builder_parameters: Self::CircuitBuilderParams, ) -> Self { - let row_verifiers = [0; L].map(|_| verify_proof_fixed_circuit( - builder, - &builder_parameters.universal_query_vk, - )); + let row_verifiers = [0; L] + .map(|_| verify_proof_fixed_circuit(builder, &builder_parameters.universal_query_vk)); let preprocessing_verifier = RecursiveCircuitsVerifierGagdet::::new( default_config(), @@ -780,11 +795,7 @@ where ); let row_pis = row_verifiers .iter() - .map(|verifier| { - QueryProofPublicInputs::from_slice( - &verifier.public_inputs, - ) - }) + .map(|verifier| QueryProofPublicInputs::from_slice(&verifier.public_inputs)) .collect_vec(); let preprocessing_pi = OriginalTreePublicInputs::from_slice(&preprocessing_proof.public_inputs); @@ -844,11 +855,16 @@ mod tests { PublicInputs as OriginalTreePublicInputs, }, query::{ + pi_len as query_pi_len, + public_inputs::{ + PublicInputsUniversalCircuit as QueryProofPublicInputs, + QueryPublicInputsUniversalCircuit, + }, utils::{ChildPosition, NodeInfo}, - public_inputs::{PublicInputsUniversalCircuit as QueryProofPublicInputs, QueryPublicInputsUniversalCircuit}, pi_len as query_pi_len, }, revelation::{ - revelation_unproven_offset::RowPath, tests::TestPlaceholders, + revelation_unproven_offset::{RowPath, TabularQueryOutputModifiers}, + tests::TestPlaceholders, NUM_PREPROCESSING_IO, }, test_utils::random_aggregation_operations, @@ -950,7 +966,7 @@ mod tests { let placeholder_hash = test_placeholders.query_placeholder_hash; let min_query_primary = test_placeholders.min_query; let max_query_primary = test_placeholders.max_query; - // set same primary index query bounds, computational hash and placeholder hash for all proofs; + // set same primary index query bounds, computational hash and placeholder hash for all proofs; // set also num matching rows to 1 for all proofs row_pis.iter_mut().for_each(|pis| { let [min_primary_range, max_primary_range, ch_range, ph_range, count_range] = [ @@ -967,8 +983,11 @@ mod tests { 
pis[ph_range].copy_from_slice(&placeholder_hash.to_fields()); pis[count_range].copy_from_slice(&[F::ONE]); }); - let hash_range = QueryProofPublicInputs::::to_range(QueryPublicInputsUniversalCircuit::TreeHash); - let index_value_range = QueryProofPublicInputs::::to_range(QueryPublicInputsUniversalCircuit::PrimaryIndexValue); + let hash_range = + QueryProofPublicInputs::::to_range(QueryPublicInputsUniversalCircuit::TreeHash); + let index_value_range = QueryProofPublicInputs::::to_range( + QueryPublicInputsUniversalCircuit::PrimaryIndexValue, + ); // build a test tree containing the rows 0..5 found in row_pis // Index tree: // A @@ -1015,8 +1034,7 @@ mod tests { }; let node_2 = { let row_pi = QueryProofPublicInputs::<_, S>::from_slice(&row_pis[2]); - let embedded_tree_hash = - HashOutput::from(gen_random_field_hash::()); + let embedded_tree_hash = HashOutput::from(gen_random_field_hash::()); let node_value = row_pi.secondary_index_value(); NodeInfo::new( &embedded_tree_hash, @@ -1032,8 +1050,7 @@ mod tests { row_pis[2][hash_range.clone()].copy_from_slice(&node_2_hash.to_fields()); let node_4 = { let row_pi = QueryProofPublicInputs::<_, S>::from_slice(&row_pis[4]); - let embedded_tree_hash = - HashOutput::from(gen_random_field_hash::()); + let embedded_tree_hash = HashOutput::from(gen_random_field_hash::()); let node_value = row_pi.secondary_index_value(); NodeInfo::new( &embedded_tree_hash, @@ -1049,8 +1066,7 @@ mod tests { row_pis[4][hash_range.clone()].copy_from_slice(&node_4_hash.to_fields()); let node_5 = { // can use all dummy values for this node, since there is no proof associated to it - let embedded_tree_hash = - HashOutput::from(gen_random_field_hash::()); + let embedded_tree_hash = HashOutput::from(gen_random_field_hash::()); let [node_value, node_min, node_max] = array::from_fn(|_| gen_random_u256(rng)); NodeInfo::new( &embedded_tree_hash, @@ -1062,8 +1078,7 @@ mod tests { ) }; let node_4_hash = HashOutput::from(node_4_hash); - let node_5_hash = - HashOutput::from(node_5.compute_node_hash(index_ids[1])); + let node_5_hash = HashOutput::from(node_5.compute_node_hash(index_ids[1])); let node_3 = { let row_pi = QueryProofPublicInputs::<_, S>::from_slice(&row_pis[3]); let embedded_tree_hash = HashOutput::from(row_pi.tree_hash()); @@ -1079,8 +1094,7 @@ mod tests { }; let node_b = { let row_pi = QueryProofPublicInputs::<_, S>::from_slice(&row_pis[2]); - let embedded_tree_hash = - HashOutput::from(node_2.compute_node_hash(index_ids[1])); + let embedded_tree_hash = HashOutput::from(node_2.compute_node_hash(index_ids[1])); let node_value = row_pi.primary_index_value(); NodeInfo::new( &embedded_tree_hash, @@ -1093,12 +1107,11 @@ mod tests { }; let node_c = { let row_pi = QueryProofPublicInputs::<_, S>::from_slice(&row_pis[3]); - let embedded_tree_hash = - HashOutput::from(node_3.compute_node_hash(index_ids[1])); + let embedded_tree_hash = HashOutput::from(node_3.compute_node_hash(index_ids[1])); let node_value = row_pi.primary_index_value(); // we need to set index value in `row_pis[4]` to the same value of `row_pis[3]`, as // they are in the same index tree - row_pis[4][index_value_range.clone()].copy_from_slice(&node_value.to_fields()); + row_pis[4][index_value_range.clone()].copy_from_slice(&node_value.to_fields()); NodeInfo::new( &embedded_tree_hash, None, @@ -1108,18 +1121,15 @@ mod tests { node_value, ) }; - let node_b_hash = - HashOutput::from(node_b.compute_node_hash(index_ids[0])); - let node_c_hash = - HashOutput::from(node_c.compute_node_hash(index_ids[0])); + let node_b_hash = 
HashOutput::from(node_b.compute_node_hash(index_ids[0])); + let node_c_hash = HashOutput::from(node_c.compute_node_hash(index_ids[0])); let node_a = { let row_pi = QueryProofPublicInputs::<_, S>::from_slice(&row_pis[0]); - let embedded_tree_hash = - HashOutput::from(node_0.compute_node_hash(index_ids[1])); + let embedded_tree_hash = HashOutput::from(node_0.compute_node_hash(index_ids[1])); let node_value = row_pi.primary_index_value(); // we need to set index value in `row_pis[1]` to the same value of `row_pis[0]`, as // they are in the same index tree - row_pis[1][index_value_range].copy_from_slice(&node_value.to_fields()); + row_pis[1][index_value_range].copy_from_slice(&node_value.to_fields()); NodeInfo::new( &embedded_tree_hash, Some(&node_b_hash), // left child is node B @@ -1182,8 +1192,9 @@ mod tests { .await; row_pis.iter_mut().zip(digests).for_each(|(pis, digest)| { - let values_range = - QueryProofPublicInputs::::to_range(QueryPublicInputsUniversalCircuit::OutputValues); + let values_range = QueryProofPublicInputs::::to_range( + QueryPublicInputsUniversalCircuit::OutputValues, + ); pis[values_range.start..values_range.start + CURVE_TARGET_LEN] .copy_from_slice(&digest.to_fields()) }); @@ -1237,9 +1248,7 @@ mod tests { index_ids, &ids, results.map(|res| res.to_vec()), - 0, - 0, - false, + TabularQueryOutputModifiers::new(0, 0, false), test_placeholders.check_placeholder_inputs, ) .unwrap(), diff --git a/verifiable-db/src/revelation/revelation_without_results_tree.rs b/verifiable-db/src/revelation/revelation_without_results_tree.rs index 282a9d75c..5def42ae4 100644 --- a/verifiable-db/src/revelation/revelation_without_results_tree.rs +++ b/verifiable-db/src/revelation/revelation_without_results_tree.rs @@ -3,7 +3,8 @@ use crate::{ ivc::PublicInputs as OriginalTreePublicInputs, query::{ - public_inputs::PublicInputs as QueryProofPublicInputs, computational_hash_ids::AggregationOperation, pi_len as query_pi_len, + computational_hash_ids::AggregationOperation, pi_len as query_pi_len, + public_inputs::PublicInputs as QueryProofPublicInputs, }, revelation::PublicInputs, }; @@ -97,21 +98,25 @@ where let mut results = Vec::with_capacity(L * S); // flag to determine whether entry count is zero let is_entry_count_zero = b.add_virtual_bool_target_unsafe(); - query_proof.operation_ids_target().into_iter().enumerate().for_each(|(i, op)| { - let is_op_avg = b.is_equal(op, op_avg); - let is_op_count = b.is_equal(op, op_count); - let result = query_proof.value_target_at_index(i); + query_proof + .operation_ids_target() + .into_iter() + .enumerate() + .for_each(|(i, op)| { + let is_op_avg = b.is_equal(op, op_avg); + let is_op_count = b.is_equal(op, op_count); + let result = query_proof.value_target_at_index(i); - // Compute the AVG result (and it's set to zero if the divisor is zero). - let (avg_result, _, is_divisor_zero) = b.div_u256(&result, &entry_count); + // Compute the AVG result (and it's set to zero if the divisor is zero). 
+ let (avg_result, _, is_divisor_zero) = b.div_u256(&result, &entry_count); - let result = b.select_u256(is_op_avg, &avg_result, &result); - let result = b.select_u256(is_op_count, &entry_count, &result); + let result = b.select_u256(is_op_avg, &avg_result, &result); + let result = b.select_u256(is_op_count, &entry_count, &result); - b.connect(is_divisor_zero.target, is_entry_count_zero.target); + b.connect(is_divisor_zero.target, is_entry_count_zero.target); - results.push(result); - }); + results.push(result); + }); results.resize(L * S, u256_zero); // Pre-compute the final placeholder hash then check it in the @@ -132,7 +137,10 @@ where // Check that the tree employed to build the queries is the same as the // tree constructed in pre-processing. - b.connect_hashes(query_proof.tree_hash_target(), original_tree_proof.merkle_hash()); + b.connect_hashes( + query_proof.tree_hash_target(), + original_tree_proof.merkle_hash(), + ); // Add the hash of placeholder identifiers and pre-processing metadata // hash to the computational hash: @@ -323,12 +331,11 @@ where _verified_proofs: [&ProofWithPublicInputsTarget; 0], builder_parameters: Self::CircuitBuilderParams, ) -> Self { - let query_verifier = RecursiveCircuitsVerifierGagdet::< - F, - C, - D, - { query_pi_len::() }, - >::new(default_config(), &builder_parameters.query_circuit_set); + let query_verifier = + RecursiveCircuitsVerifierGagdet::() }>::new( + default_config(), + &builder_parameters.query_circuit_set, + ); let query_verifier = query_verifier.verify_proof_in_circuit_set(builder); let preprocessing_verifier = RecursiveCircuitsVerifierGagdet::::new( @@ -343,8 +350,7 @@ where OriginalTreePublicInputs::from_slice(&preprocessing_proof.public_inputs); let revelation_circuit = { let query_pi = QueryProofPublicInputs::from_slice( - query_verifier - .get_public_input_targets::() }>(), + query_verifier.get_public_input_targets::() }>(), ); RevelationWithoutResultsTreeCircuit::build(builder, &query_pi, &preprocessing_pi) }; @@ -393,15 +399,12 @@ mod tests { use crate::{ ivc::PublicInputs as OriginalTreePublicInputs, query::{ - utils::{QueryBoundSource, QueryBounds}, - public_inputs::{ - PublicInputs as QueryProofPublicInputs, - QueryPublicInputs, - }, computational_hash_ids::AggregationOperation, + public_inputs::{PublicInputs as QueryProofPublicInputs, QueryPublicInputs}, universal_circuit::{ universal_circuit_inputs::Placeholders, universal_query_gadget::OutputValues, }, + utils::{QueryBoundSource, QueryBounds}, }, revelation::{ revelation_without_results_tree::{ @@ -410,7 +413,10 @@ mod tests { tests::{compute_results_from_query_proof_outputs, TestPlaceholders}, PublicInputs, NUM_PREPROCESSING_IO, }, - test_utils::{random_aggregation_operations, random_original_tree_proof, sample_boundary_rows_for_revelation}, + test_utils::{ + random_aggregation_operations, random_original_tree_proof, + sample_boundary_rows_for_revelation, + }, }; // L: maximum number of results @@ -434,7 +440,6 @@ mod tests { } } } - #[derive(Clone, Debug)] struct TestRevelationCircuit<'a> { @@ -516,7 +521,8 @@ mod tests { Some(QueryBoundSource::Constant(max_secondary)), ) .unwrap(); - let (left_boundary_row, right_boundary_row) = sample_boundary_rows_for_revelation(&query_bounds, rng); + let (left_boundary_row, right_boundary_row) = + sample_boundary_rows_for_revelation(&query_bounds, rng); proof[left_row_range].copy_from_slice(&left_boundary_row.to_fields()); proof[right_row_range].copy_from_slice(&right_boundary_row.to_fields()); @@ -665,4 +671,3 @@ mod tests { 
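The loop refactored above selects each revealed output from the raw query result and the entry count. A hedged out-of-circuit sketch of the same selection logic follows; the `AggOp` enum and `final_result` helper are illustrative names, not part of this codebase, and `alloy` is the dependency already used throughout this PR:

```rust
// Plain-Rust mirror of the in-circuit selection: AVG divides the aggregated
// value by the entry count (yielding 0 when the count is 0, matching the
// zero-divisor flag returned by div_u256), COUNT returns the entry count
// itself, and any other operation passes its value through unchanged.
use alloy::primitives::U256;

#[derive(Clone, Copy)]
enum AggOp {
    Avg,
    Count,
    Other,
}

fn final_result(op: AggOp, value: U256, entry_count: U256) -> U256 {
    match op {
        // guard first: U256 division panics on a zero divisor
        AggOp::Avg if entry_count.is_zero() => U256::ZERO,
        AggOp::Avg => value / entry_count,
        AggOp::Count => entry_count,
        AggOp::Other => value,
    }
}

fn main() {
    assert_eq!(final_result(AggOp::Avg, U256::from(10), U256::ZERO), U256::ZERO);
    assert_eq!(final_result(AggOp::Avg, U256::from(10), U256::from(2)), U256::from(5));
}
```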
test_revelation_batching_circuit(&ops, Some(0)); } } - diff --git a/verifiable-db/src/test_utils.rs b/verifiable-db/src/test_utils.rs index 5b508b3e1..2000e0b1c 100644 --- a/verifiable-db/src/test_utils.rs +++ b/verifiable-db/src/test_utils.rs @@ -5,9 +5,16 @@ use crate::{ query::{ computational_hash_ids::{ AggregationOperation, ColumnIDs, Identifiers, Operation, PlaceholderIdentifier, - }, public_inputs::{PublicInputs as QueryPI, PublicInputsFactory, QueryPublicInputs}, row_chunk_gadgets::BoundaryRowData, universal_circuit::{universal_circuit_inputs::{ - BasicOperation, ColumnCell, InputOperand, OutputItem, Placeholders, ResultStructure, - }, universal_query_gadget::OutputValues}, utils::{QueryBoundSource, QueryBounds, QueryHashNonExistenceCircuits} + }, + public_inputs::{PublicInputs as QueryPI, PublicInputsFactory, QueryPublicInputs}, + row_chunk_gadgets::BoundaryRowData, + universal_circuit::{ + universal_circuit_inputs::{ + BasicOperation, ColumnCell, InputOperand, OutputItem, Placeholders, ResultStructure, + }, + universal_query_gadget::OutputValues, + }, + utils::{QueryBoundSource, QueryBounds, QueryHashNonExistenceCircuits}, }, revelation::NUM_PREPROCESSING_IO, }; @@ -20,7 +27,7 @@ use mp2_common::{ }; use mp2_test::utils::{gen_random_field_hash, gen_random_u256}; use plonky2::{ - field::types::{PrimeField64, Sample, Field}, + field::types::{Field, PrimeField64, Sample}, hash::hash_types::HashOut, plonk::config::GenericHashOut, }; @@ -44,7 +51,6 @@ pub const ROW_TREE_MAX_DEPTH: usize = 10; pub const INDEX_TREE_MAX_DEPTH: usize = 15; pub const NUM_COLUMNS: usize = 4; - /// Generate a set of values in a given range ensuring that the i+1-th generated value is /// bigger than the i-th generated value pub(crate) fn gen_values_in_range( @@ -96,7 +102,9 @@ pub fn random_aggregation_operations() -> [F; S] { }) } -impl PublicInputsFactory<'_, F, S, UNIVERSAL_CIRCUIT> { +impl + PublicInputsFactory<'_, F, S, UNIVERSAL_CIRCUIT> +{ pub(crate) fn sample_from_ops(ops: &[F; S]) -> [Vec; NUM_INPUTS] where [(); S - 1]:, @@ -286,23 +294,24 @@ impl TestRevelationData { let computational_hash = non_existence_circuits.computational_hash(); let placeholder_hash = non_existence_circuits.placeholder_hash(); - let [mut query_pi_raw] = QueryPI::::sample_from_ops( - &ops_ids.try_into().unwrap(), - ); - let [min_query_primary, max_query_primary, min_query_secondary, max_query_secondary, p_hash_range, c_hash_range, left_row_range, right_row_range] = [ - QueryPublicInputs::MinPrimary, - QueryPublicInputs::MaxPrimary, - QueryPublicInputs::MinSecondary, - QueryPublicInputs::MaxSecondary, - QueryPublicInputs::PlaceholderHash, - QueryPublicInputs::ComputationalHash, - QueryPublicInputs::LeftBoundaryRow, - QueryPublicInputs::RightBoundaryRow, - ] - .map(QueryPI::::to_range); - + let [mut query_pi_raw] = + QueryPI::::sample_from_ops(&ops_ids.try_into().unwrap()); + let [min_query_primary, max_query_primary, min_query_secondary, max_query_secondary, p_hash_range, c_hash_range, left_row_range, right_row_range] = + [ + QueryPublicInputs::MinPrimary, + QueryPublicInputs::MaxPrimary, + QueryPublicInputs::MinSecondary, + QueryPublicInputs::MaxSecondary, + QueryPublicInputs::PlaceholderHash, + QueryPublicInputs::ComputationalHash, + QueryPublicInputs::LeftBoundaryRow, + QueryPublicInputs::RightBoundaryRow, + ] + .map(QueryPI::::to_range); + // sample left boundary row and right boundary row to satisfy revelation circuit constraints - let (left_boundary_row, right_boundary_row) = 
sample_boundary_rows_for_revelation(&query_bounds, rng); + let (left_boundary_row, right_boundary_row) = + sample_boundary_rows_for_revelation(&query_bounds, rng); // Set the minimum, maximum query, placeholder hash andn computational hash to expected values. [ @@ -325,20 +334,14 @@ impl TestRevelationData { (p_hash_range, placeholder_hash.to_vec()), (c_hash_range, computational_hash.to_vec()), (left_row_range, left_boundary_row.to_fields()), - (right_row_range, right_boundary_row.to_fields()) + (right_row_range, right_boundary_row.to_fields()), ] .into_iter() .for_each(|(range, fields)| query_pi_raw[range].copy_from_slice(&fields)); let query_pi = QueryPI::::from_slice(&query_pi_raw); - assert_eq!( - query_pi.min_primary(), - query_bounds.min_query_primary(), - ); - assert_eq!( - query_pi.max_primary(), - query_bounds.max_query_primary(), - ); + assert_eq!(query_pi.min_primary(), query_bounds.min_query_primary(),); + assert_eq!(query_pi.max_primary(), query_bounds.max_query_primary(),); assert_eq!( query_pi.min_secondary(), query_bounds.min_query_secondary().value, @@ -393,52 +396,51 @@ pub(crate) fn sample_boundary_rows_for_revelation( query_bounds: &QueryBounds, rng: &mut R, ) -> (BoundaryRowData, BoundaryRowData) { - let min_secondary = *query_bounds.min_query_secondary().value(); - let max_secondary = *query_bounds.max_query_secondary().value(); - let mut left_boundary_row = BoundaryRowData::sample(rng, &query_bounds); - // for predecessor of `left_boundary_row` in index tree, we need to either mark it as - // non-existent or to make its value out of range - if rng.gen() || query_bounds.min_query_primary() == U256::ZERO { - left_boundary_row.index_node_info.predecessor_info.is_found = false; - } else { - let [predecessor_value] = gen_values_in_range( - rng, - U256::ZERO, - query_bounds.min_query_primary() - U256::from(1), - ); - left_boundary_row.index_node_info.predecessor_info.value = predecessor_value; - } - // for predecessor of `left_boundary_row` in rows tree, we need to either mark it as - // non-existent or to make its value out of range - if rng.gen() || min_secondary == U256::ZERO { - left_boundary_row.row_node_info.predecessor_info.is_found = false; - } else { - let [predecessor_value] = - gen_values_in_range(rng, U256::ZERO, min_secondary - U256::from(1)); - left_boundary_row.row_node_info.predecessor_info.value = predecessor_value; - } - let mut right_boundary_row = BoundaryRowData::sample(rng, &query_bounds); - // for successor of `right_boundary_row` in index tree, we need to either mark it as - // non-existent or to make its value out of range - if rng.gen() || query_bounds.max_query_primary() == U256::MAX { - right_boundary_row.index_node_info.successor_info.is_found = false; - } else { - let [successor_value] = gen_values_in_range( - rng, - query_bounds.max_query_primary() + U256::from(1), - U256::MAX, - ); - right_boundary_row.index_node_info.successor_info.value = successor_value; - } - // for successor of `right_boundary_row` in rows tree, we need to either mark it as - // non-existent or to make its value out of range - if rng.gen() || max_secondary == U256::MAX { - right_boundary_row.row_node_info.successor_info.is_found = false; - } else { - let [successor_value] = - gen_values_in_range(rng, max_secondary + U256::from(1), U256::MAX); - right_boundary_row.row_node_info.successor_info.value = successor_value; - } + let min_secondary = *query_bounds.min_query_secondary().value(); + let max_secondary = *query_bounds.max_query_secondary().value(); + let mut 
left_boundary_row = BoundaryRowData::sample(rng, query_bounds); + // for predecessor of `left_boundary_row` in index tree, we need to either mark it as + // non-existent or to make its value out of range + if rng.gen() || query_bounds.min_query_primary() == U256::ZERO { + left_boundary_row.index_node_info.predecessor_info.is_found = false; + } else { + let [predecessor_value] = gen_values_in_range( + rng, + U256::ZERO, + query_bounds.min_query_primary() - U256::from(1), + ); + left_boundary_row.index_node_info.predecessor_info.value = predecessor_value; + } + // for predecessor of `left_boundary_row` in rows tree, we need to either mark it as + // non-existent or to make its value out of range + if rng.gen() || min_secondary == U256::ZERO { + left_boundary_row.row_node_info.predecessor_info.is_found = false; + } else { + let [predecessor_value] = + gen_values_in_range(rng, U256::ZERO, min_secondary - U256::from(1)); + left_boundary_row.row_node_info.predecessor_info.value = predecessor_value; + } + let mut right_boundary_row = BoundaryRowData::sample(rng, query_bounds); + // for successor of `right_boundary_row` in index tree, we need to either mark it as + // non-existent or to make its value out of range + if rng.gen() || query_bounds.max_query_primary() == U256::MAX { + right_boundary_row.index_node_info.successor_info.is_found = false; + } else { + let [successor_value] = gen_values_in_range( + rng, + query_bounds.max_query_primary() + U256::from(1), + U256::MAX, + ); + right_boundary_row.index_node_info.successor_info.value = successor_value; + } + // for successor of `right_boundary_row` in rows tree, we need to either mark it as + // non-existent or to make its value out of range + if rng.gen() || max_secondary == U256::MAX { + right_boundary_row.row_node_info.successor_info.is_found = false; + } else { + let [successor_value] = gen_values_in_range(rng, max_secondary + U256::from(1), U256::MAX); + right_boundary_row.row_node_info.successor_info.value = successor_value; + } - (left_boundary_row, right_boundary_row) -} \ No newline at end of file + (left_boundary_row, right_boundary_row) +} From 2e7db73931aeacee6cea2b004eab7d709492f159 Mon Sep 17 00:00:00 2001 From: nicholas-mainardi Date: Fri, 6 Dec 2024 13:03:39 +0100 Subject: [PATCH 04/12] Test distinct in revelation circuit tabular queries --- .../revelation/revelation_unproven_offset.rs | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/verifiable-db/src/revelation/revelation_unproven_offset.rs b/verifiable-db/src/revelation/revelation_unproven_offset.rs index 86481c69c..aac6d69aa 100644 --- a/verifiable-db/src/revelation/revelation_unproven_offset.rs +++ b/verifiable-db/src/revelation/revelation_unproven_offset.rs @@ -942,7 +942,7 @@ mod tests { // test function for this revelation circuit. 
If `distinct` is true, then the
     // results are enforced to be distinct
-    async fn test_revelation_unproven_offset_circuit() {
+    async fn test_revelation_unproven_offset_circuit(distinct: bool) {
         const ROW_TREE_MAX_DEPTH: usize = 10;
         const INDEX_TREE_MAX_DEPTH: usize = 10;
         const L: usize = 5;
@@ -1145,8 +1145,17 @@ mod tests {
 
         // sample final results and set order-agnostic digests in row_pis proofs accordingly
         const NUM_ACTUAL_ITEMS_PER_OUTPUT: usize = 4;
-        let mut results: [[U256; NUM_ACTUAL_ITEMS_PER_OUTPUT]; L] =
-            array::from_fn(|_| array::from_fn(|_| gen_random_u256(rng)));
+        let mut results: [[U256; NUM_ACTUAL_ITEMS_PER_OUTPUT]; L] = if distinct {
+            // generate all the output values distinct from each other; generating at
+            // random will make them distinct with overwhelming probability
+            array::from_fn(|_| array::from_fn(|_| gen_random_u256(rng)))
+        } else {
+            // generate some values which are the same
+            let mut res = array::from_fn(|_| array::from_fn(|_| gen_random_u256(rng)));
+            res[L - 1] = res[0];
+            res
+        };
+
         // sort them to ensure that DISTINCT constraints are satisfied
         results.sort_by(|a, b| {
             let (is_smaller, is_eq) = is_less_than_or_equal_to_u256_arr(a, b);
@@ -1261,11 +1270,11 @@ mod tests {
 
     #[tokio::test]
     async fn test_revelation_unproven_offset_circuit_no_distinct() {
-        test_revelation_unproven_offset_circuit().await
+        test_revelation_unproven_offset_circuit(false).await
     }
 
     #[tokio::test]
     async fn test_revelation_unproven_offset_circuit_distinct() {
-        test_revelation_unproven_offset_circuit().await
+        test_revelation_unproven_offset_circuit(true).await
     }
 }

From cf681c395e8151a1064b87c45b1d3fc65dfc716a Mon Sep 17 00:00:00 2001
From: nicholas-mainardi
Date: Fri, 6 Dec 2024 15:30:38 +0100
Subject: [PATCH 05/12] Working rust toolchain

---
 rust-toolchain | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/rust-toolchain b/rust-toolchain
index bf867e0ae..a7a456242 100644
--- a/rust-toolchain
+++ b/rust-toolchain
@@ -1 +1 @@
-nightly
+nightly-2024-12-03

From 0666bb9ba327edf45bd1899b89d63cf95036a179 Mon Sep 17 00:00:00 2001
From: nicholas-mainardi
Date: Wed, 11 Dec 2024 18:50:42 +0100
Subject: [PATCH 06/12] Address review comments

---
 verifiable-db/Cargo.toml                      |  1 +
 verifiable-db/src/lib.rs                      |  1 +
 .../src/query/circuits/chunk_aggregation.rs   | 31 +++---
 .../src/query/circuits/non_existence.rs       |  9 +-
 .../query/circuits/row_chunk_processing.rs    | 47 ++++++----
 .../src/query/computational_hash_ids.rs       |  2 +-
 verifiable-db/src/query/mod.rs                |  4 +-
 verifiable-db/src/query/output_computation.rs | 14 +--
 verifiable-db/src/query/public_inputs.rs      | 53 ++++++-----
 .../row_chunk_gadgets/aggregate_chunks.rs     |  8 +-
 .../universal_query_circuit.rs                | 94 +++++++++----------
 .../universal_query_gadget.rs                 | 30 ++++--
 verifiable-db/src/query/utils.rs              |  4 +-
 .../src/results_tree/old_public_inputs.rs     |  2 -
 verifiable-db/src/revelation/api.rs           |  4 +-
 .../revelation_without_results_tree.rs        |  6 +-
 verifiable-db/src/test_utils.rs               |  4 +-
 17 files changed, 183 insertions(+), 131 deletions(-)

diff --git a/verifiable-db/Cargo.toml b/verifiable-db/Cargo.toml
index 3ab92c430..8b50d33d5 100644
--- a/verifiable-db/Cargo.toml
+++ b/verifiable-db/Cargo.toml
@@ -30,3 +30,4 @@ tokio.workspace = true
 
 [features]
 original_poseidon = ["mp2_common/original_poseidon"]
+results_tree = [] # temporary feature to disable compiling results_tree code by default, as it is still WIP

diff --git a/verifiable-db/src/lib.rs b/verifiable-db/src/lib.rs
index 1cac73092..b9e0856fe 100644
--- a/verifiable-db/src/lib.rs
+++ 
b/verifiable-db/src/lib.rs @@ -12,6 +12,7 @@ pub mod extraction; pub mod ivc; /// Module for circuits for simple queries pub mod query; +#[cfg(feature = "results_tree")] pub mod results_tree; /// Module for the query revelation circuits pub mod revelation; diff --git a/verifiable-db/src/query/circuits/chunk_aggregation.rs b/verifiable-db/src/query/circuits/chunk_aggregation.rs index 93da64d6d..7bac563e3 100644 --- a/verifiable-db/src/query/circuits/chunk_aggregation.rs +++ b/verifiable-db/src/query/circuits/chunk_aggregation.rs @@ -23,7 +23,8 @@ use recursion_framework::circuit_builder::CircuitLogicWires; use serde::{Deserialize, Serialize}; use crate::query::{ - pi_len, public_inputs::PublicInputs, row_chunk_gadgets::aggregate_chunks::aggregate_chunks, + pi_len, public_inputs::PublicInputsQueryCircuits, + row_chunk_gadgets::aggregate_chunks::aggregate_chunks, }; #[derive(Clone, Debug, Serialize, Deserialize)] @@ -48,7 +49,7 @@ impl { pub(crate) fn build( b: &mut CircuitBuilder, - chunk_proofs: &[PublicInputs; NUM_CHUNKS], + chunk_proofs: &[PublicInputsQueryCircuits; NUM_CHUNKS], ) -> ChunkAggregationWires where [(); MAX_NUM_RESULTS - 1]:, @@ -112,7 +113,7 @@ impl b.is_not_equal(row_chunk.chunk_outputs.num_overflows, zero) }; - PublicInputs::::new( + PublicInputsQueryCircuits::::new( &row_chunk.chunk_outputs.tree_hash.to_targets(), &row_chunk.chunk_outputs.values.to_targets(), &[row_chunk.chunk_outputs.count], @@ -161,7 +162,8 @@ where verified_proofs: [&ProofWithPublicInputsTarget; NUM_CHUNKS], _builder_parameters: Self::CircuitBuilderParams, ) -> Self { - let pis = verified_proofs.map(|proof| PublicInputs::from_slice(&proof.public_inputs)); + let pis = verified_proofs + .map(|proof| PublicInputsQueryCircuits::from_slice(&proof.public_inputs)); ChunkAggregationCircuit::build(builder, &pis) } @@ -199,7 +201,7 @@ mod tests { use crate::{ query::{ computational_hash_ids::{AggregationOperation, Identifiers}, - public_inputs::PublicInputs, + public_inputs::PublicInputsQueryCircuits, universal_circuit::universal_query_gadget::OutputValues, utils::tests::aggregate_output_values, }, @@ -249,11 +251,13 @@ mod tests { fn build(c: &mut CircuitBuilder) -> Self::Wires { let raw_pis = array::from_fn(|_| { - c.add_virtual_targets(PublicInputs::::total_len()) + c.add_virtual_targets( + PublicInputsQueryCircuits::::total_len(), + ) }); let pis = raw_pis .iter() - .map(|pi| PublicInputs::from_slice(pi)) + .map(|pi| PublicInputsQueryCircuits::from_slice(pi)) .collect_vec() .try_into() .unwrap(); @@ -283,9 +287,13 @@ mod tests { // if we test with dummy chunks to be aggregated, we generate `ACTUAL_NUM_CHUNKS <= NUM_CHUNKS` // inputs, so that the remaining `NUM_CHUNKS - ACTUAL_NUM_CHUNKS` input slots are dummies const NUM_ACTUAL_CHUNKS: usize = 3; - PublicInputs::::sample_from_ops::(&ops).to_vec() + PublicInputsQueryCircuits::::sample_from_ops::( + &ops, + ) + .to_vec() } else { - PublicInputs::::sample_from_ops::(&ops).to_vec() + PublicInputsQueryCircuits::::sample_from_ops::(&ops) + .to_vec() }; let circuit = TestChunkAggregationCircuit::::new(&raw_pis); @@ -294,7 +302,7 @@ mod tests { let input_pis = raw_pis .iter() - .map(|pi| PublicInputs::::from_slice(pi)) + .map(|pi| PublicInputsQueryCircuits::::from_slice(pi)) .collect_vec(); let (expected_outputs, expected_overflow) = { @@ -326,7 +334,8 @@ mod tests { let expected_left_row = input_pis[0].to_left_row_raw(); let expected_right_row = input_pis.last().unwrap().to_right_row_raw(); - let result_pis = PublicInputs::::from_slice(&proof.public_inputs); + let 
result_pis = + PublicInputsQueryCircuits::::from_slice(&proof.public_inputs); // check public inputs assert_eq!( diff --git a/verifiable-db/src/query/circuits/non_existence.rs b/verifiable-db/src/query/circuits/non_existence.rs index 68f51b3e1..18cc90303 100644 --- a/verifiable-db/src/query/circuits/non_existence.rs +++ b/verifiable-db/src/query/circuits/non_existence.rs @@ -28,7 +28,7 @@ use crate::query::{ }, output_computation::compute_dummy_output_targets, pi_len, - public_inputs::PublicInputs, + public_inputs::PublicInputsQueryCircuits, row_chunk_gadgets::{BoundaryRowDataTarget, BoundaryRowNodeInfoTarget}, universal_circuit::{ ComputationalHash, ComputationalHashTarget, PlaceholderHash, PlaceholderHashTarget, @@ -189,7 +189,7 @@ where // can be dummy values since they are un-used in this circuit let min_secondary = b.zero_u256(); let max_secondary = b.constant_u256(U256::MAX); - PublicInputs::::new( + PublicInputsQueryCircuits::::new( &index_path.root.to_targets(), &outputs, &[zero], // there are no matching rows @@ -291,7 +291,7 @@ mod tests { api::TreePathInputs, merkle_path::{tests::generate_test_tree, NeighborInfo}, output_computation::tests::compute_dummy_output_values, - public_inputs::PublicInputs, + public_inputs::PublicInputsQueryCircuits, row_chunk_gadgets::{BoundaryRowData, BoundaryRowNodeInfo}, universal_circuit::universal_circuit_inputs::Placeholders, utils::{ChildPosition, QueryBounds}, @@ -360,7 +360,8 @@ mod tests { expected_index_node_info: BoundaryRowNodeInfo, expected_query_bounds: &QueryBounds, test_name: &str| { - let pis = PublicInputs::::from_slice(&proof.public_inputs); + let pis = + PublicInputsQueryCircuits::::from_slice(&proof.public_inputs); assert_eq!( pis.tree_hash(), expected_root, diff --git a/verifiable-db/src/query/circuits/row_chunk_processing.rs b/verifiable-db/src/query/circuits/row_chunk_processing.rs index 56a9a75ec..685d45a3e 100644 --- a/verifiable-db/src/query/circuits/row_chunk_processing.rs +++ b/verifiable-db/src/query/circuits/row_chunk_processing.rs @@ -12,7 +12,7 @@ use serde::{de::DeserializeOwned, Deserialize, Serialize}; use crate::query::{ computational_hash_ids::ColumnIDs, pi_len, - public_inputs::PublicInputs, + public_inputs::PublicInputsQueryCircuits, row_chunk_gadgets::{ aggregate_chunks::aggregate_chunks, row_process_gadget::{RowProcessingGadgetInputWires, RowProcessingGadgetInputs}, @@ -243,7 +243,7 @@ where b.is_not_equal(num_overflows, zero) }; - PublicInputs::::new( + PublicInputsQueryCircuits::::new( &row_chunk.chunk_outputs.tree_hash.to_targets(), &row_chunk.chunk_outputs.values.to_targets(), &[row_chunk.chunk_outputs.count], @@ -287,16 +287,6 @@ where self.universal_query_inputs .assign(pw, &wires.universal_query_inputs); } - - /// This method returns the ids of the placeholders employed to compute the placeholder hash, - /// in the same order, so that those ids can be provided as input to other circuits that need - /// to recompute this hash - #[cfg(test)] // only used in test for now - pub(crate) fn ids_for_placeholder_hash( - &self, - ) -> Vec { - self.universal_query_inputs.ids_for_placeholder_hash() - } } impl< @@ -382,14 +372,14 @@ mod tests { use crate::query::{ circuits::{ - row_chunk_processing::RowChunkProcessingCircuit, + row_chunk_processing::{RowChunkProcessingCircuit, UniversalQueryHashInputs}, tests::{build_test_tree, compute_output_values_for_row}, }, computational_hash_ids::{ AggregationOperation, ColumnIDs, Identifiers, Operation, PlaceholderIdentifier, }, merkle_path::{MerklePathWithNeighborsGadget, 
NeighborInfo}, - public_inputs::PublicInputs, + public_inputs::PublicInputsQueryCircuits, row_chunk_gadgets::{ row_process_gadget::RowProcessingGadgetInputs, BoundaryRowData, BoundaryRowNodeInfo, }, @@ -773,13 +763,22 @@ mod tests { .unwrap(); // compute placeholder hash for `circuit` - let placeholder_hash_ids = circuit.ids_for_placeholder_hash(); + let placeholder_hash_ids = UniversalQueryHashInputs::< + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + AggOutputCircuit, + >::ids_for_placeholder_hash( + &predicate_operations, &results, &placeholders, &bounds + ) + .unwrap(); let placeholder_hash = placeholder_hash(&placeholder_hash_ids, &placeholders, &bounds).unwrap(); let proof = run_circuit::(circuit); // check public inputs - let pis = PublicInputs::::from_slice(&proof.public_inputs); + let pis = PublicInputsQueryCircuits::::from_slice(&proof.public_inputs); let root = node_1.node.compute_node_hash(primary_index); assert_eq!(root, pis.tree_hash(),); @@ -1319,13 +1318,25 @@ mod tests { .unwrap(); // compute placeholder hash for `circuit` - let placeholder_hash_ids = circuit.ids_for_placeholder_hash(); + let placeholder_hash_ids = UniversalQueryHashInputs::< + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + NoAggOutputCircuit, + >::ids_for_placeholder_hash( + &predicate_operations, + &results, + &placeholders, + &query_bounds, + ) + .unwrap(); let placeholder_hash = placeholder_hash(&placeholder_hash_ids, &placeholders, &query_bounds).unwrap(); let proof = run_circuit::(circuit); // check public inputs - let pis = PublicInputs::::from_slice(&proof.public_inputs); + let pis = PublicInputsQueryCircuits::::from_slice(&proof.public_inputs); let root = node_1.node.compute_node_hash(primary_index); assert_eq!(root, pis.tree_hash(),); diff --git a/verifiable-db/src/query/computational_hash_ids.rs b/verifiable-db/src/query/computational_hash_ids.rs index a672f8da7..73a1e1be1 100644 --- a/verifiable-db/src/query/computational_hash_ids.rs +++ b/verifiable-db/src/query/computational_hash_ids.rs @@ -234,7 +234,7 @@ impl ToField for Identifiers { } } /// Data structure to provide identifiers of columns of a table to compute computational hash -#[derive(Clone, Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Default, Serialize, Deserialize)] pub struct ColumnIDs { pub(crate) primary: F, pub(crate) secondary: F, diff --git a/verifiable-db/src/query/mod.rs b/verifiable-db/src/query/mod.rs index a99fbfba8..e20a6987f 100644 --- a/verifiable-db/src/query/mod.rs +++ b/verifiable-db/src/query/mod.rs @@ -1,5 +1,5 @@ use plonky2::iop::target::Target; -use public_inputs::PublicInputs; +use public_inputs::PublicInputsQueryCircuits; pub mod api; pub(crate) mod circuits; @@ -12,5 +12,5 @@ pub mod universal_circuit; pub mod utils; pub const fn pi_len() -> usize { - PublicInputs::::total_len() + PublicInputsQueryCircuits::::total_len() } diff --git a/verifiable-db/src/query/output_computation.rs b/verifiable-db/src/query/output_computation.rs index 5c7b4d91e..70b5f232c 100644 --- a/verifiable-db/src/query/output_computation.rs +++ b/verifiable-db/src/query/output_computation.rs @@ -156,7 +156,7 @@ pub(crate) mod tests { use super::*; use crate::{ query::{ - pi_len, public_inputs::PublicInputs, + pi_len, public_inputs::PublicInputsQueryCircuits, universal_circuit::universal_query_gadget::CurveOrU256, utils::tests::compute_output_item_value, }, @@ -177,7 +177,7 @@ pub(crate) mod tests { pub(crate) fn compute_output_item( b: &mut CBuilder, 
i: usize, - proofs: &[&PublicInputs], + proofs: &[&PublicInputsQueryCircuits], ) -> (Vec, Target) where [(); S - 1]:, @@ -271,7 +271,8 @@ pub(crate) mod tests { }); // Build the public inputs. - let pis = [0; PROOF_NUM].map(|i| PublicInputs::::from_slice(&proofs[i])); + let pis = [0; PROOF_NUM] + .map(|i| PublicInputsQueryCircuits::::from_slice(&proofs[i])); let pis = [0; PROOF_NUM].map(|i| &pis[i]); // Check if the outputs as expected. @@ -308,7 +309,8 @@ pub(crate) mod tests { [(); S - 1]:, { fn new(proofs: [Vec; PROOF_NUM]) -> Self { - let pis = [0; PROOF_NUM].map(|i| PublicInputs::::from_slice(&proofs[i])); + let pis = + [0; PROOF_NUM].map(|i| PublicInputsQueryCircuits::::from_slice(&proofs[i])); let pis = [0; PROOF_NUM].map(|i| &pis[i]); let exp_outputs = array::from_fn(|i| { @@ -333,7 +335,7 @@ pub(crate) mod tests { let ops: [_; S] = random_aggregation_operations(); // Build the input proofs. - let inputs = PublicInputs::::sample_from_ops(&ops); + let inputs = PublicInputsQueryCircuits::::sample_from_ops(&ops); // Construct the test circuit. let test_circuit = TestOutputComputationCircuit::::new(inputs); @@ -354,7 +356,7 @@ pub(crate) mod tests { ops[0] = Identifiers::AggregationOperations(AggregationOperation::IdOp).to_field(); // Build the input proofs. - let inputs = PublicInputs::::sample_from_ops(&ops); + let inputs = PublicInputsQueryCircuits::::sample_from_ops(&ops); // Construct the test circuit. let test_circuit = TestOutputComputationCircuit::::new(inputs); diff --git a/verifiable-db/src/query/public_inputs.rs b/verifiable-db/src/query/public_inputs.rs index a3b22f50e..3f37294c3 100644 --- a/verifiable-db/src/query/public_inputs.rs +++ b/verifiable-db/src/query/public_inputs.rs @@ -125,9 +125,9 @@ impl From for QueryPublicInputs { } } } -/// Public inputs for generic query circuits -pub type PublicInputs<'a, T, const S: usize> = PublicInputsFactory<'a, T, S, false>; -/// Public inputs for universal query circuit +/// Public inputs for query circuits +pub type PublicInputsQueryCircuits<'a, T, const S: usize> = PublicInputsFactory<'a, T, S, false>; +/// Public inputs only for universal query circuit pub type PublicInputsUniversalCircuit<'a, T, const S: usize> = PublicInputsFactory<'a, T, S, true>; /// This is the data structure employed for both public inputs of generic query circuits @@ -436,7 +436,7 @@ impl } } -impl PublicInputs<'_, Target, S> { +impl PublicInputsQueryCircuits<'_, Target, S> { pub(crate) fn left_boundary_row_target(&self) -> BoundaryRowDataTarget { BoundaryRowDataTarget::from_targets(self.to_left_row_raw()) } @@ -574,7 +574,7 @@ where } } -impl PublicInputs<'_, F, S> { +impl PublicInputsQueryCircuits<'_, F, S> { pub fn min_secondary(&self) -> U256 { U256::from_fields(self.to_min_secondary_raw()) } @@ -613,7 +613,7 @@ pub(crate) mod tests { plonk::circuit_builder::CircuitBuilder, }; - use super::{PublicInputs, QueryPublicInputs}; + use super::{PublicInputsQueryCircuits, QueryPublicInputs}; const S: usize = 10; #[derive(Clone, Debug)] @@ -625,8 +625,10 @@ pub(crate) mod tests { type Wires = Vec; fn build(c: &mut CircuitBuilder) -> Self::Wires { - let targets = c.add_virtual_target_arr::<{ PublicInputs::::total_len() }>(); - let pi_targets = PublicInputs::::from_slice(targets.as_slice()); + let targets = c + .add_virtual_target_arr::<{ PublicInputsQueryCircuits::::total_len() }>( + ); + let pi_targets = PublicInputsQueryCircuits::::from_slice(targets.as_slice()); pi_targets.register_args(c); pi_targets.to_vec() } @@ -638,59 +640,64 @@ pub(crate) mod tests { 
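The `test_batching_query_public_inputs` assertions just below all rely on `to_range` mapping each `QueryPublicInputs` item onto its own contiguous slice of the flat public-input vector. A hedged sketch of that prefix-sum layout convention; the lengths below are invented for the example and only the slicing scheme matches the real code:

```rust
// Each item owns the slice starting at the sum of all preceding item
// lengths, so slicing the raw vector with to_range(item) yields exactly
// that item's fields and the ranges never overlap.
use std::ops::Range;

fn to_range(lens: &[usize], item: usize) -> Range<usize> {
    let start: usize = lens[..item].iter().sum();
    start..start + lens[item]
}

fn main() {
    // e.g. [tree hash, output values, num matching rows] (illustrative sizes)
    let lens = [4, 8, 1];
    let raw: Vec<u64> = (0..13).collect();
    assert_eq!(to_range(&lens, 1), 4..12);
    assert_eq!(&raw[to_range(&lens, 2)], &[12]);
}
```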
#[test] fn test_batching_query_public_inputs() { - let pis_raw: Vec = random_vector::(PublicInputs::::total_len()).to_fields(); - let pis = PublicInputs::::from_slice(pis_raw.as_slice()); + let pis_raw: Vec = + random_vector::(PublicInputsQueryCircuits::::total_len()).to_fields(); + let pis = PublicInputsQueryCircuits::::from_slice(pis_raw.as_slice()); // check public inputs are constructed correctly assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::TreeHash)], + &pis_raw[PublicInputsQueryCircuits::::to_range(QueryPublicInputs::TreeHash)], pis.to_hash_raw(), ); assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::OutputValues)], + &pis_raw[PublicInputsQueryCircuits::::to_range(QueryPublicInputs::OutputValues)], pis.to_values_raw(), ); assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::NumMatching)], + &pis_raw[PublicInputsQueryCircuits::::to_range(QueryPublicInputs::NumMatching)], &[*pis.to_count_raw()], ); assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::OpIds)], + &pis_raw[PublicInputsQueryCircuits::::to_range(QueryPublicInputs::OpIds)], pis.to_ops_raw(), ); assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::LeftBoundaryRow)], + &pis_raw + [PublicInputsQueryCircuits::::to_range(QueryPublicInputs::LeftBoundaryRow)], pis.to_left_row_raw(), ); assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::RightBoundaryRow)], + &pis_raw + [PublicInputsQueryCircuits::::to_range(QueryPublicInputs::RightBoundaryRow)], pis.to_right_row_raw(), ); assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::MinPrimary)], + &pis_raw[PublicInputsQueryCircuits::::to_range(QueryPublicInputs::MinPrimary)], pis.to_min_primary_raw(), ); assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::MaxPrimary)], + &pis_raw[PublicInputsQueryCircuits::::to_range(QueryPublicInputs::MaxPrimary)], pis.to_max_primary_raw(), ); assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::MinSecondary)], + &pis_raw[PublicInputsQueryCircuits::::to_range(QueryPublicInputs::MinSecondary)], pis.to_min_secondary_raw(), ); assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::MaxSecondary)], + &pis_raw[PublicInputsQueryCircuits::::to_range(QueryPublicInputs::MaxSecondary)], pis.to_max_secondary_raw(), ); assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::Overflow)], + &pis_raw[PublicInputsQueryCircuits::::to_range(QueryPublicInputs::Overflow)], &[*pis.to_overflow_raw()], ); assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::ComputationalHash)], + &pis_raw + [PublicInputsQueryCircuits::::to_range(QueryPublicInputs::ComputationalHash)], pis.to_computational_hash_raw(), ); assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::PlaceholderHash)], + &pis_raw + [PublicInputsQueryCircuits::::to_range(QueryPublicInputs::PlaceholderHash)], pis.to_placeholder_hash_raw(), ); // use public inputs in circuit diff --git a/verifiable-db/src/query/row_chunk_gadgets/aggregate_chunks.rs b/verifiable-db/src/query/row_chunk_gadgets/aggregate_chunks.rs index b942968e2..8cd1d1ec4 100644 --- a/verifiable-db/src/query/row_chunk_gadgets/aggregate_chunks.rs +++ b/verifiable-db/src/query/row_chunk_gadgets/aggregate_chunks.rs @@ -129,7 +129,7 @@ mod tests { tests::{build_node, generate_test_tree}, MerklePathWithNeighborsGadget, MerklePathWithNeighborsTargetInputs, NeighborInfo, }, - public_inputs::PublicInputs, + public_inputs::PublicInputsQueryCircuits, row_chunk_gadgets::{ 
tests::RowChunkData, BoundaryRowData, BoundaryRowDataTarget, BoundaryRowNodeInfo, BoundaryRowNodeInfoTarget, RowChunkDataTarget, @@ -445,12 +445,14 @@ mod tests { let root = index_node.compute_node_hash(primary_index_id); // generate the output values associated to each chunk - let inputs = PublicInputs::::sample_from_ops::<2>(&ops); + let inputs = PublicInputsQueryCircuits::::sample_from_ops::<2>(&ops); let [(first_chunk_count, first_chunk_outputs, fist_chunk_num_overflows), (second_chunk_count, second_chunk_outputs, second_chunk_num_overflows)] = inputs .into_iter() .map(|input| { - let pis = PublicInputs::::from_slice(input.as_slice()); + let pis = PublicInputsQueryCircuits::::from_slice( + input.as_slice(), + ); ( pis.num_matching_rows(), OutputValues::from_fields(pis.to_values_raw()), diff --git a/verifiable-db/src/query/universal_circuit/universal_query_circuit.rs b/verifiable-db/src/query/universal_circuit/universal_query_circuit.rs index edecea6b6..3a8522ee9 100644 --- a/verifiable-db/src/query/universal_circuit/universal_query_circuit.rs +++ b/verifiable-db/src/query/universal_circuit/universal_query_circuit.rs @@ -234,13 +234,6 @@ where self.hash_gadget_inputs.assign(pw, &wires.hash_wires); self.value_gadget_inputs.assign(pw, &wires.value_wires); } - - /// This method returns the ids of the placeholders employed to compute the placeholder hash, - /// in the same order, so that those ids can be provided as input to other circuits that need - /// to recompute this hash - pub(crate) fn ids_for_placeholder_hash(&self) -> Vec { - self.hash_gadget_inputs.ids_for_placeholder_hash() - } } pub(crate) fn dummy_placeholder_id() -> PlaceholderId { @@ -462,45 +455,26 @@ where placeholders: &Placeholders, query_bounds: &QueryBounds, ) -> Result<[PlaceholderId; 2 * (MAX_NUM_PREDICATE_OPS + MAX_NUM_RESULT_OPS)]> { - let row_cells = &RowCells::default(); Ok(match results.output_variant { - Output::Aggregation => { - let circuit = UniversalQueryCircuitInputs::< - MAX_NUM_COLUMNS, - MAX_NUM_PREDICATE_OPS, - MAX_NUM_RESULT_OPS, - MAX_NUM_RESULTS, - AggOutputCircuit, - >::new( - row_cells, - predicate_operations, - placeholders, - false, // doesn't matter for placeholder hash computation - query_bounds, - results, - false, // doesn't matter for placeholder hash computation - )?; - circuit.ids_for_placeholder_hash() - } - Output::NoAggregation => { - let circuit = UniversalQueryCircuitInputs::< - MAX_NUM_COLUMNS, - MAX_NUM_PREDICATE_OPS, - MAX_NUM_RESULT_OPS, - MAX_NUM_RESULTS, - NoAggOutputCircuit, - >::new( - row_cells, - predicate_operations, - placeholders, - false, // doesn't matter for placeholder hash computation - query_bounds, - results, - false, // doesn't matter for placeholder hash computation - )?; - circuit.ids_for_placeholder_hash() - } - } + Output::Aggregation => UniversalQueryHashInputs::< + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + AggOutputCircuit, + >::ids_for_placeholder_hash( + predicate_operations, results, placeholders, query_bounds + ), + Output::NoAggregation => UniversalQueryHashInputs::< + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + NoAggOutputCircuit, + >::ids_for_placeholder_hash( + predicate_operations, results, placeholders, query_bounds + ), + }? 
     .try_into()
     .unwrap())
 }
@@ -548,7 +522,9 @@ mod tests {
             BasicOperation, ColumnCell, InputOperand, OutputItem, PlaceholderId, Placeholders,
             ResultStructure, RowCells,
         },
-        universal_query_circuit::{placeholder_hash, UniversalQueryCircuitParams},
+        universal_query_circuit::{
+            placeholder_hash, UniversalCircuitInput, UniversalQueryCircuitParams,
+        },
         ComputationalHash,
     },
     utils::{QueryBoundSource, QueryBounds},
@@ -892,7 +868,18 @@ mod tests {
             })
             .collect_vec();
 
-        let placeholder_hash_ids = circuit.ids_for_placeholder_hash();
+        let placeholder_hash_ids = UniversalCircuitInput::<
+            MAX_NUM_COLUMNS,
+            MAX_NUM_PREDICATE_OPS,
+            MAX_NUM_RESULT_OPS,
+            MAX_NUM_RESULTS,
+        >::ids_for_placeholder_hash(
+            &predicate_operations,
+            &results,
+            &placeholders,
+            &query_bounds,
+        )
+        .unwrap();
         let placeholder_hash =
             placeholder_hash(&placeholder_hash_ids, &placeholders, &query_bounds).unwrap();
         let computational_hash = ComputationalHash::from_bytes(
@@ -1285,7 +1272,18 @@ mod tests {
             Point::NEUTRAL
         };
 
-        let placeholder_hash_ids = circuit.ids_for_placeholder_hash();
+        let placeholder_hash_ids = UniversalCircuitInput::<
+            MAX_NUM_COLUMNS,
+            MAX_NUM_PREDICATE_OPS,
+            MAX_NUM_RESULT_OPS,
+            MAX_NUM_RESULTS,
+        >::ids_for_placeholder_hash(
+            &predicate_operations,
+            &results,
+            &placeholders,
+            &query_bounds,
+        )
+        .unwrap();
         let placeholder_hash =
             placeholder_hash(&placeholder_hash_ids, &placeholders, &query_bounds).unwrap();
         let computational_hash = ComputationalHash::from_bytes(
diff --git a/verifiable-db/src/query/universal_circuit/universal_query_gadget.rs b/verifiable-db/src/query/universal_circuit/universal_query_gadget.rs
index 97eeedaf2..268c2adbe 100644
--- a/verifiable-db/src/query/universal_circuit/universal_query_gadget.rs
+++ b/verifiable-db/src/query/universal_circuit/universal_query_gadget.rs
@@ -833,15 +833,33 @@ where
     /// This method returns the ids of the placeholders employed to compute the placeholder hash,
     /// in the same order, so that those ids can be provided as input to other circuits that need
     /// to recompute this hash
-    pub(crate) fn ids_for_placeholder_hash(&self) -> Vec<PlaceholderId> {
-        self.filtering_predicate_inputs
+    pub(crate) fn ids_for_placeholder_hash(
+        predicate_operations: &[BasicOperation],
+        results: &ResultStructure,
+        placeholders: &Placeholders,
+        query_bounds: &QueryBounds,
+    ) -> Result<Vec<PlaceholderId>> {
+        let hash_input_gadget = Self::new(
+            &ColumnIDs::default(),
+            predicate_operations,
+            placeholders,
+            query_bounds,
+            results,
+        )?;
+        Ok(hash_input_gadget
+            .filtering_predicate_inputs
             .iter()
             .flat_map(|op_inputs| vec![op_inputs.placeholder_ids[0], op_inputs.placeholder_ids[1]])
-            .chain(self.result_values_inputs.iter().flat_map(|op_inputs| {
-                vec![op_inputs.placeholder_ids[0], op_inputs.placeholder_ids[1]]
-            }))
+            .chain(
+                hash_input_gadget
+                    .result_values_inputs
+                    .iter()
+                    .flat_map(|op_inputs| {
+                        vec![op_inputs.placeholder_ids[0], op_inputs.placeholder_ids[1]]
+                    }),
+            )
             .map(|id| PlaceholderIdentifier::from_fields(&[id]))
-            .collect_vec()
+            .collect_vec())
     }
 
     /// Utility function to compute the `BasicOperationInputs` corresponding to the set of `operations` specified
diff --git a/verifiable-db/src/query/utils.rs b/verifiable-db/src/query/utils.rs
index 15ba24110..29e217664 100644
--- a/verifiable-db/src/query/utils.rs
+++ b/verifiable-db/src/query/utils.rs
@@ -416,7 +416,7 @@ pub struct NonExistenceInput {
 pub(crate) mod tests {
     use crate::query::{
         computational_hash_ids::{AggregationOperation, Identifiers},
-        public_inputs::PublicInputs,
+        public_inputs::PublicInputsQueryCircuits,
         universal_circuit::universal_query_gadget::{CurveOrU256, OutputValues},
     };
     use alloy::primitives::U256;
@@ -518,7 +518,7 @@ pub(crate) mod tests {
     /// the proofs. It's the test function corresponding to `compute_output_item`.
     pub(crate) fn compute_output_item_value<const S: usize>(
         i: usize,
-        proofs: &[&PublicInputs<F, S>],
+        proofs: &[&PublicInputsQueryCircuits<F, S>],
     ) -> (Vec<U256>, u32)
     where
         [(); S - 1]:,
diff --git a/verifiable-db/src/results_tree/old_public_inputs.rs b/verifiable-db/src/results_tree/old_public_inputs.rs
index 5eb805638..7f6d07b00 100644
--- a/verifiable-db/src/results_tree/old_public_inputs.rs
+++ b/verifiable-db/src/results_tree/old_public_inputs.rs
@@ -1,5 +1,3 @@
-#![allow(dead_code)]
-
 use std::iter::once;
 
 use alloy::primitives::U256;
diff --git a/verifiable-db/src/revelation/api.rs b/verifiable-db/src/revelation/api.rs
index c07f032b1..2fd9ae453 100644
--- a/verifiable-db/src/revelation/api.rs
+++ b/verifiable-db/src/revelation/api.rs
@@ -403,7 +403,7 @@ where
                 TabularQueryOutputModifiers::new(
                     limit,
                     offset,
-                    results_structure.distinct.unwrap_or(false),
+                    results_structure.distinct.unwrap_or_default(),
                 ),
                 placeholder_inputs,
             )?;
@@ -605,7 +605,7 @@ mod tests {
         ivc::PublicInputs as PreprocessingPI,
         query::{
             computational_hash_ids::{ColumnIDs, Identifiers},
-            public_inputs::PublicInputs as QueryPI,
+            public_inputs::PublicInputsQueryCircuits as QueryPI,
         },
         revelation::{
             api::{CircuitInput, Parameters},
diff --git a/verifiable-db/src/revelation/revelation_without_results_tree.rs b/verifiable-db/src/revelation/revelation_without_results_tree.rs
index 5def42ae4..45ee1bb91 100644
--- a/verifiable-db/src/revelation/revelation_without_results_tree.rs
+++ b/verifiable-db/src/revelation/revelation_without_results_tree.rs
@@ -4,7 +4,7 @@ use crate::{
     ivc::PublicInputs as OriginalTreePublicInputs,
     query::{
         computational_hash_ids::AggregationOperation, pi_len as query_pi_len,
-        public_inputs::PublicInputs as QueryProofPublicInputs,
+        public_inputs::PublicInputsQueryCircuits as QueryProofPublicInputs,
     },
     revelation::PublicInputs,
 };
@@ -400,7 +400,9 @@ mod tests {
         ivc::PublicInputs as OriginalTreePublicInputs,
         query::{
             computational_hash_ids::AggregationOperation,
-            public_inputs::{PublicInputs as QueryProofPublicInputs, QueryPublicInputs},
+            public_inputs::{
+                PublicInputsQueryCircuits as QueryProofPublicInputs, QueryPublicInputs,
+            },
             universal_circuit::{
                 universal_circuit_inputs::Placeholders, universal_query_gadget::OutputValues,
             },
diff --git a/verifiable-db/src/test_utils.rs b/verifiable-db/src/test_utils.rs
index 2000e0b1c..864c1451f 100644
--- a/verifiable-db/src/test_utils.rs
+++ b/verifiable-db/src/test_utils.rs
@@ -6,7 +6,9 @@ use crate::{
         computational_hash_ids::{
             AggregationOperation, ColumnIDs, Identifiers, Operation, PlaceholderIdentifier,
         },
-        public_inputs::{PublicInputs as QueryPI, PublicInputsFactory, QueryPublicInputs},
+        public_inputs::{
+            PublicInputsFactory, PublicInputsQueryCircuits as QueryPI, QueryPublicInputs,
+        },
         row_chunk_gadgets::BoundaryRowData,
         universal_circuit::{
             universal_circuit_inputs::{
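One pattern worth calling out in the patch above: `ids_for_placeholder_hash` turns from a method read off a fully built circuit instance (which forced callers to construct the circuit from dummy `RowCells`) into an associated function that builds only the hash-input gadget it needs. A minimal sketch of the pattern, with hypothetical types standing in for the real circuit structs:

```rust
// Minimal sketch of the refactor pattern; all names here are hypothetical.
#[derive(Default)]
struct RowCells; // stand-in for heavyweight witness data the ids don't depend on

struct CircuitInputs {
    placeholder_ids: Vec<u64>,
    _row_cells: RowCells,
}

impl CircuitInputs {
    fn new(row_cells: RowCells, ops: &[u64]) -> Self {
        Self {
            placeholder_ids: ops.iter().flat_map(|op| [2 * op, 2 * op + 1]).collect(),
            _row_cells: row_cells,
        }
    }

    // Before: callers built a full instance from `RowCells::default()` just for this.
    fn ids_via_instance(&self) -> Vec<u64> {
        self.placeholder_ids.clone()
    }

    // After: an associated function derives the same ids from the query
    // description alone, so no dummy witness data is ever constructed.
    fn ids_for_placeholder_hash(ops: &[u64]) -> Vec<u64> {
        ops.iter().flat_map(|op| [2 * op, 2 * op + 1]).collect()
    }
}

fn main() {
    let ops = [7u64, 9];
    let via_instance = CircuitInputs::new(RowCells, &ops).ids_via_instance();
    assert_eq!(via_instance, CircuitInputs::ids_for_placeholder_hash(&ops));
}
```

The payoff is visible in the two test hunks above: call sites no longer need a `circuit` value in scope to recompute the placeholder hash.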
From 33284df3c558a6ac97f8f64ea5d37011e8209d19 Mon Sep 17 00:00:00 2001
From: T
Date: Fri, 13 Dec 2024 14:22:39 +0800
Subject: [PATCH 07/12] fix: add common `derives` for integration (#419)

---
 mp2-v1/src/query/batching_planner.rs          | 21 ++++++++++---------
 mp2-v1/src/query/planner.rs                   |  2 +-
 .../common/cases/query/aggregated_queries.rs  |  2 +-
 verifiable-db/src/query/api.rs                |  2 +-
 4 files changed, 14 insertions(+), 13 deletions(-)

diff --git a/mp2-v1/src/query/batching_planner.rs b/mp2-v1/src/query/batching_planner.rs
index d0c21d029..0a280e136 100644
--- a/mp2-v1/src/query/batching_planner.rs
+++ b/mp2-v1/src/query/batching_planner.rs
@@ -10,6 +10,7 @@ use ryhope::{
     storage::{updatetree::UpdateTree, WideLineage},
     Epoch,
 };
+use serde::{Deserialize, Serialize};
 use verifiable_db::query::{
     api::{NodePath, RowInput, TreePathInputs},
     computational_hash_ids::ColumnIDs,
@@ -195,8 +196,10 @@ async fn generate_chunks(
 ///
 ///    (2,0) (2,1)   (2,2)   (2,3) (2,4)
 /// ```
-#[derive(Clone, Debug, Hash, Eq, PartialEq, Default)]
-pub struct UTKey((usize, usize));
+#[derive(
+    Clone, Copy, Debug, Default, PartialEq, PartialOrd, Ord, Eq, Hash, Serialize, Deserialize,
+)]
+pub struct UTKey(pub (usize, usize));
 
 impl UTKey {
     /// Compute the key of the child node of `self` that has `num_left_siblings`
@@ -318,15 +321,13 @@ impl ProvingTree {
         let num_childrens = parent_node.children_keys.len();
         let new_child_key = parent_key.children_key(num_childrens);
         let child_node = ProvingTreeNode {
-            parent_key: Some(parent_key.clone()),
+            parent_key: Some(*parent_key),
             children_keys: vec![],
         };
         // insert new child in the set of children of the parent
-        parent_node.children_keys.push(new_child_key.clone());
+        parent_node.children_keys.push(new_child_key);
         assert!(
-            self.nodes
-                .insert(new_child_key.clone(), child_node)
-                .is_none(),
+            self.nodes.insert(new_child_key, child_node).is_none(),
             "Node with key {:?} already found in the tree",
             new_child_key
         );
@@ -339,7 +340,7 @@ impl ProvingTree {
         };
         let root_key = UTKey((0, 0));
         assert!(
-            self.nodes.insert(root_key.clone(), root).is_none(),
+            self.nodes.insert(root_key, root).is_none(),
             "Error: root node inserted multiple times"
         );
         root_key
@@ -412,7 +413,7 @@ impl ProvingTree {
         while node_key.is_some() {
             // place node key in the path
             let key = node_key.unwrap();
-            path.push(key.clone());
+            path.push(*key);
             // fetch key of the parent node, if any
             node_key = self
                 .nodes
@@ -449,7 +450,7 @@ impl UTForChunksBuilder {
             let path = tree.compute_path_for_leaf(node_index);
             (
                 (
-                    path.last().unwrap().clone(), // chunk node is always a leaf of the tree, so it is the last node
+                    *path.last().unwrap(), // chunk node is always a leaf of the tree, so it is the last node
                                            // in the path
                     chunk,
                 ),
diff --git a/mp2-v1/src/query/planner.rs b/mp2-v1/src/query/planner.rs
index 54734abd4..c73e9039d 100644
--- a/mp2-v1/src/query/planner.rs
+++ b/mp2-v1/src/query/planner.rs
@@ -65,7 +65,7 @@ impl<'a, C: ContextProvider> NonExistenceInput<'a, C> {
         }
     }
 
-    pub(crate) async fn find_row_node_for_non_existence(
+    pub async fn find_row_node_for_non_existence(
         &self,
         primary: BlockPrimaryIndex,
     ) -> anyhow::Result<RowTreeKey> {
diff --git a/mp2-v1/tests/common/cases/query/aggregated_queries.rs b/mp2-v1/tests/common/cases/query/aggregated_queries.rs
index 8aad454f1..cae029b4d 100644
--- a/mp2-v1/tests/common/cases/query/aggregated_queries.rs
+++ b/mp2-v1/tests/common/cases/query/aggregated_queries.rs
@@ -237,7 +237,7 @@ pub(crate) async fn prove_query(
         let proof_key = ProofKey::QueryAggregate((
             planner.query.query.clone(),
             planner.query.placeholders.placeholder_values(),
-            k.clone(),
+            *k,
         ));
         planner.ctx.storage.store_proof(proof_key.clone(), proof)?;
         proof_id = Some(proof_key);
diff --git a/verifiable-db/src/query/api.rs b/verifiable-db/src/query/api.rs
index 18d5d7e0e..58e902a9b 100644
--- a/verifiable-db/src/query/api.rs
+++ b/verifiable-db/src/query/api.rs
@@ -119,7 +119,7 @@ impl NodePath {
     }
 }
 
-#[derive(Clone, Debug)]
+#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
/// Data structure containing the inputs necessary to prove a query for a row
/// of the DB table.
 pub struct RowInput {
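The widened derive list on `UTKey` is what the rest of this patch relies on: `Copy` lets call sites write `*k` instead of `k.clone()`, `Ord`/`Eq` make the key usable in ordered maps, and `Serialize`/`Deserialize` let it travel inside persisted proof keys. A small sketch of those three uses (assuming `serde` with the `derive` feature and `serde_json` purely as an example codec; neither choice is dictated by the patch):

```rust
use std::collections::BTreeMap;

use serde::{Deserialize, Serialize};

#[derive(
    Clone, Copy, Debug, Default, PartialEq, PartialOrd, Ord, Eq, Hash, Serialize, Deserialize,
)]
pub struct UTKey(pub (usize, usize));

fn main() -> serde_json::Result<()> {
    let k = UTKey((2, 3));
    let copied = k; // `Copy`: move by value, no `k.clone()` needed
    let mut by_key: BTreeMap<UTKey, &str> = BTreeMap::new(); // `Ord` + `Eq`: ordered map key
    by_key.insert(copied, "chunk proof");
    let encoded = serde_json::to_string(&k)?; // `Serialize`: embeddable in stored proof keys
    assert_eq!(serde_json::from_str::<UTKey>(&encoded)?, k); // `Deserialize` round-trips
    assert_eq!(by_key[&k], "chunk proof");
    Ok(())
}
```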
From 98d9ed08bb0a3653f14892f0df4764ab117a89ae Mon Sep 17 00:00:00 2001
From: T
Date: Mon, 16 Dec 2024 17:17:32 +0800
Subject: [PATCH 08/12] fix: replace the `stream::iter` implementation to fix
 the `Send` issue in DQ (#420)

I added a TODO above the previous implementation (there may be a better
solution).
---
 mp2-v1/src/query/batching_planner.rs | 96 +++++++++++++++-------------
 1 file changed, 52 insertions(+), 44 deletions(-)

diff --git a/mp2-v1/src/query/batching_planner.rs b/mp2-v1/src/query/batching_planner.rs
index 0a280e136..7ce4125c9 100644
--- a/mp2-v1/src/query/batching_planner.rs
+++ b/mp2-v1/src/query/batching_planner.rs
@@ -119,54 +119,62 @@ async fn generate_chunks(
         .cloned()
         .collect::<Vec<_>>();
 
-    Ok(stream::iter(sorted_index_values.into_iter())
-        .then(async |index_value| {
-            let index_path = index_cache
-                .compute_path(&index_value, current_epoch)
+    let prove_rows = async |index_value| {
+        let index_path = index_cache
+            .compute_path(&index_value, current_epoch)
+            .await
+            .unwrap_or_else(|| panic!("node with key {index_value} not found in index tree cache"));
+        let proven_rows = if let Some(matching_rows) =
+            row_keys_by_epochs.get(&(index_value as Epoch))
+        {
+            let sorted_rows = matching_rows.iter().collect::<Vec<_>>();
+            stream::iter(sorted_rows.iter())
+                .then(async |&row_key| {
+                    compute_input_for_row(&row_cache, row_key, index_value, &index_path, column_ids)
+                        .await
+                })
+                .collect::<Vec<_>>()
+                .await
+        } else {
+            let proven_node = non_existence_inputs
+                .find_row_node_for_non_existence(index_value)
                 .await
-                .unwrap_or_else(|| {
-                    panic!("node with key {index_value} not found in index tree cache")
+                .unwrap_or_else(|_| {
+                    panic!("node for non-existence not found for index value {index_value}")
                 });
-            let proven_rows =
-                if let Some(matching_rows) = row_keys_by_epochs.get(&(index_value as Epoch)) {
-                    let sorted_rows = matching_rows.iter().collect::<Vec<_>>();
-                    stream::iter(sorted_rows.iter())
-                        .then(async |&row_key| {
-                            compute_input_for_row(
-                                &row_cache,
-                                row_key,
-                                index_value,
-                                &index_path,
-                                column_ids,
-                            )
-                            .await
-                        })
-                        .collect::<Vec<_>>()
-                        .await
-                } else {
-                    let proven_node = non_existence_inputs
-                        .find_row_node_for_non_existence(index_value)
-                        .await
-                        .unwrap_or_else(|_| {
-                            panic!("node for non-existence not found for index value {index_value}")
-                        });
-                    let row_input = compute_input_for_row(
-                        non_existence_inputs.row_tree,
-                        &proven_node,
-                        index_value,
-                        &index_path,
-                        column_ids,
-                    )
-                    .await;
-                    vec![row_input]
-                };
-            proven_rows
-        })
-        .concat()
-        .await
+            let row_input = compute_input_for_row(
+                non_existence_inputs.row_tree,
+                &proven_node,
+                index_value,
+                &index_path,
+                column_ids,
+            )
+            .await;
+            vec![row_input]
+        };
+        proven_rows
+    };
+
+    // TODO: This implementation causes an error in DQ:
+    // `implementation of `std::marker::Send` is not general enough`
+    /*
+    let chunks = stream::iter(sorted_index_values.into_iter())
+        .then(prove_rows)
+        .concat()
+        .await
+    */
+    let mut chunks = vec![];
+    for index_value in sorted_index_values {
+        let chunk = prove_rows(index_value).await;
+        chunks.extend(chunk);
+    }
+
+    let chunks = chunks
         .chunks(CHUNK_SIZE)
         .map(|chunk| chunk.to_vec())
-        .collect_vec())
+        .collect_vec();
+
+    Ok(chunks)
 }
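The shape of this workaround generalizes beyond this crate: a `stream::iter(..).then(..).concat()` chain wraps the async closure's future inside the combinator's types, and the compiler must prove `Send` through those generic wrappers, which is exactly where the "implementation of `Send` is not general enough" limitation with async closures tends to bite; driving the same async function from a plain `for` loop gives the compiler a concrete future to check at each `await`. A minimal, compilable sketch of the two equivalent shapes (assuming `futures` and `tokio` with the `macros`/`rt` features as dependencies; all names hypothetical):

```rust
use futures::{stream, StreamExt};

// stand-in for the per-index proving work done in the patch
async fn prove_rows(index_value: u64) -> Vec<u64> {
    vec![index_value, index_value + 1]
}

#[tokio::main]
async fn main() {
    let sorted_index_values = vec![1u64, 10, 100];

    // Combinator version: concise, but its `Send`-ness must be established
    // through the `Then`/`Concat` wrapper types around the closure's future.
    let via_stream: Vec<u64> = stream::iter(sorted_index_values.clone())
        .then(prove_rows)
        .concat()
        .await;

    // Loop version used in the patch: only plain `Vec` values live across
    // iterations, so the surrounding future is straightforwardly `Send`.
    let mut via_loop = vec![];
    for index_value in sorted_index_values {
        via_loop.extend(prove_rows(index_value).await);
    }

    assert_eq!(via_stream, via_loop);
}
```

Keeping the combinator version in a comment next to the loop, as the patch does, preserves the more declarative formulation for when the compiler limitation is lifted.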
From 34abd40abcddadbf390920917c1be7d376ecfa26 Mon Sep 17 00:00:00 2001
From: Steven Gu
Date: Wed, 8 Jan 2025 18:40:24 +0800
Subject: [PATCH 09/12] Fix `test_pidgy_pinguin_mapping_slot`.

---
 mp2-common/src/eth.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mp2-common/src/eth.rs b/mp2-common/src/eth.rs
index ba863d475..cfcc44912 100644
--- a/mp2-common/src/eth.rs
+++ b/mp2-common/src/eth.rs
@@ -432,7 +432,7 @@ mod test {
         // holder: 0x188b264aa1456b869c3a92eeed32117ebb835f47
         // NFT id https://opensea.io/assets/ethereum/0xbd3531da5cf5857e7cfaa92426877b022e612cf8/1116
         let mapping_value =
-            Address::from_str("0x188B264AA1456B869C3a92eeeD32117EbB835f47").unwrap();
+            Address::from_str("0x29469395eAf6f95920E59F858042f0e28D98a20B").unwrap();
         let nft_id: u32 = 1116;
         let mapping_key = left_pad32(&nft_id.to_be_bytes());
         let url = get_mainnet_url();
From 5a6424b1e9d0478415b0bdd765616443a53341ee Mon Sep 17 00:00:00 2001
From: Steven Gu
Date: Wed, 8 Jan 2025 18:55:38 +0800
Subject: [PATCH 10/12] Fix

---
 mp2-common/src/eth.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mp2-common/src/eth.rs b/mp2-common/src/eth.rs
index cfcc44912..af183352b 100644
--- a/mp2-common/src/eth.rs
+++ b/mp2-common/src/eth.rs
@@ -432,7 +432,7 @@ mod test {
         // holder: 0x188b264aa1456b869c3a92eeed32117ebb835f47
         // NFT id https://opensea.io/assets/ethereum/0xbd3531da5cf5857e7cfaa92426877b022e612cf8/1116
         let mapping_value =
-            Address::from_str("0x29469395eAf6f95920E59F858042f0e28D98a20B").unwrap();
+            Address::from_str("0xee5ac9c6db07c26e71207a41e64df42e1a2b05cf").unwrap();
         let nft_id: u32 = 1116;
         let mapping_key = left_pad32(&nft_id.to_be_bytes());
         let url = get_mainnet_url();
From b27d873d884ed7d880e1d8bff2a4ad8ddd04215f Mon Sep 17 00:00:00 2001
From: Steven Gu
Date: Wed, 8 Jan 2025 19:03:25 +0800
Subject: [PATCH 11/12] Remove `test_pidgy_pinguin_mapping_slot` test case
 which is useless.

---
 mp2-common/src/eth.rs | 33 ---------------------------------
 1 file changed, 33 deletions(-)

diff --git a/mp2-common/src/eth.rs b/mp2-common/src/eth.rs
index af183352b..3ba641d06 100644
--- a/mp2-common/src/eth.rs
+++ b/mp2-common/src/eth.rs
@@ -426,39 +426,6 @@ mod test {
         Ok(())
     }
 
-    #[tokio::test]
-    async fn test_pidgy_pinguin_mapping_slot() -> Result<()> {
-        // first pinguin holder https://dune.com/queries/2450476/4027653
-        // holder: 0x188b264aa1456b869c3a92eeed32117ebb835f47
-        // NFT id https://opensea.io/assets/ethereum/0xbd3531da5cf5857e7cfaa92426877b022e612cf8/1116
-        let mapping_value =
-            Address::from_str("0xee5ac9c6db07c26e71207a41e64df42e1a2b05cf").unwrap();
-        let nft_id: u32 = 1116;
-        let mapping_key = left_pad32(&nft_id.to_be_bytes());
-        let url = get_mainnet_url();
-        let provider = ProviderBuilder::new().on_http(url.parse().unwrap());
-
-        // extracting from
-        // https://github.com/OpenZeppelin/openzeppelin-contracts/blob/master/contracts/token/ERC721/ERC721.sol
-        // assuming it's using ERC731Enumerable that inherits ERC721
-        let mapping_slot = 2;
-        // pudgy pinguins
-        let pudgy_address = Address::from_str("0xBd3531dA5CF5857e7CfAA92426877b022e612cf8")?;
-        let query = ProofQuery::new_mapping_slot(pudgy_address, mapping_slot, mapping_key.to_vec());
-        let res = query
-            .query_mpt_proof(&provider, BlockNumberOrTag::Latest)
-            .await?;
-        let raw_address = ProofQuery::verify_storage_proof(&res)?;
-        // the value is actually RLP encoded !
-        let decoded_address: Vec<u8> = rlp::decode(&raw_address).unwrap();
-        let leaf_node: Vec<Vec<u8>> = rlp::decode_list(res.storage_proof[0].proof.last().unwrap());
-        println!("leaf_node[1].len() = {}", leaf_node[1].len());
-        // this is read in the same order
-        let found_address = Address::from_slice(&decoded_address.into_iter().collect::<Vec<u8>>());
-        assert_eq!(found_address, mapping_value);
-        Ok(())
-    }
-
     #[tokio::test]
     async fn test_kashish_contract_proof_query() -> Result<()> {
         // https://sepolia.etherscan.io/address/0xd6a2bFb7f76cAa64Dad0d13Ed8A9EFB73398F39E#code
From 10528ae1cd2b9db02990a1449192d48408b135e0 Mon Sep 17 00:00:00 2001
From: Steven Gu
Date: Wed, 8 Jan 2025 19:09:47 +0800
Subject: [PATCH 12/12] Fix lint.

---
 mp2-common/src/eth.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mp2-common/src/eth.rs b/mp2-common/src/eth.rs
index 3ba641d06..ee8eda75b 100644
--- a/mp2-common/src/eth.rs
+++ b/mp2-common/src/eth.rs
@@ -286,7 +286,7 @@ mod test {
         types::MAX_BLOCK_LEN,
         utils::{Endianness, Packer},
     };
-    use mp2_test::eth::{get_mainnet_url, get_sepolia_url};
+    use mp2_test::eth::get_sepolia_url;
 
     #[tokio::test]
     #[ignore]