From 0f3290c3a7c1a4df279db590a65427cca520ff42 Mon Sep 17 00:00:00 2001 From: Martun Karapetyan Date: Mon, 26 Aug 2024 13:26:10 +0400 Subject: [PATCH] Separate preprocessing from proving and verification by adding a parameter --stage to proof producer --- .../nil/proof-generator/arg_parser.hpp | 2 + .../include/nil/proof-generator/prover.hpp | 190 ++++++++++++------ bin/proof-producer/src/arg_parser.cpp | 7 +- bin/proof-producer/src/main.cpp | 45 ++++- 4 files changed, 176 insertions(+), 68 deletions(-) diff --git a/bin/proof-producer/include/nil/proof-generator/arg_parser.hpp b/bin/proof-producer/include/nil/proof-generator/arg_parser.hpp index fe00a924..c174432d 100644 --- a/bin/proof-producer/include/nil/proof-generator/arg_parser.hpp +++ b/bin/proof-producer/include/nil/proof-generator/arg_parser.hpp @@ -34,9 +34,11 @@ namespace nil { typename tuple_to_variant::type>::type; struct ProverOptions { + std::string stage = "all"; boost::filesystem::path proof_file_path = "proof.bin"; boost::filesystem::path json_file_path = "proof.json"; boost::filesystem::path preprocessed_common_data_path = "preprocessed_common_data.dat"; + boost::filesystem::path preprocessed_public_data_path = "preprocessed_data.dat"; boost::filesystem::path circuit_file_path; boost::filesystem::path assignment_table_file_path; boost::log::trivial::severity_level log_level = boost::log::trivial::severity_level::info; diff --git a/bin/proof-producer/include/nil/proof-generator/prover.hpp b/bin/proof-producer/include/nil/proof-generator/prover.hpp index ac6fbcad..6d6cae33 100644 --- a/bin/proof-producer/include/nil/proof-generator/prover.hpp +++ b/bin/proof-producer/include/nil/proof-generator/prover.hpp @@ -26,14 +26,21 @@ #include +#include +#include +#include + #include #include #include #include +#include #include #include #include + #include + #include #include #include @@ -45,9 +52,8 @@ #include #include -#include -#include -#include + + #include #include @@ -68,7 +74,7 @@ namespace nil { auto read_iter = v->begin(); auto status = marshalled_data.read(read_iter, v->size()); if (status != nil::marshalling::status_type::success) { - BOOST_LOG_TRIVIAL(error) << "Marshalled structure decoding failed"; + BOOST_LOG_TRIVIAL(error) << "When reading a Marshalled structure from file " << path << ", decoding step failed"; return std::nullopt; } return marshalled_data; @@ -92,62 +98,59 @@ namespace nil { return hex ? write_vector_to_hex_file(v, path.c_str()) : write_vector_to_file(v, path.c_str()); } - std::vector generate_random_step_list(const std::size_t r, const int max_step) { - using Distribution = std::uniform_int_distribution; - static std::random_device random_engine; - - std::vector step_list; - std::size_t steps_sum = 0; - while (steps_sum != r) { - if (r - steps_sum <= max_step) { - while (r - steps_sum != 1) { - step_list.emplace_back(r - steps_sum - 1); - steps_sum += step_list.back(); - } - step_list.emplace_back(1); - steps_sum += step_list.back(); - } else { - step_list.emplace_back(Distribution(1, max_step)(random_engine)); - steps_sum += step_list.back(); - } + enum class ProverStage { + ALL = 0, + PREPROCESS = 1, + PROVE = 2, + VERIFY = 3 + }; + + ProverStage prover_stage_from_string(const std::string& stage) { + static std::unordered_map stage_map = { + {"all", ProverStage::ALL}, + {"preprocess", ProverStage::PREPROCESS}, + {"prove", ProverStage::PROVE}, + {"verify", ProverStage::VERIFY} + }; + auto it = stage_map.find(stage); + if (it == stage_map.end()) { + throw std::invalid_argument("Invalid stage: " + stage); } - return step_list; + return it->second; } + } // namespace detail + template class Prover { public: Prover( boost::filesystem::path circuit_file_name, - boost::filesystem::path preprocessed_common_data_file_name, boost::filesystem::path assignment_table_file_name, - boost::filesystem::path proof_file, - boost::filesystem::path json_file, std::size_t lambda, std::size_t expand_factor, std::size_t max_q_chunks, std::size_t grind ) : circuit_file_(circuit_file_name) - , preprocessed_common_data_file_(preprocessed_common_data_file_name) , assignment_table_file_(assignment_table_file_name) - , proof_file_(proof_file) - , json_file_(json_file) , lambda_(lambda) , expand_factor_(expand_factor) , max_quotient_chunks_(max_q_chunks) , grind_(grind) { } - bool generate_to_file(bool skip_verification) { + // The caller must call the preprocessor or load the preprocessed data before calling this function. + bool generate_to_file( + boost::filesystem::path proof_file_, + boost::filesystem::path json_file_, + bool skip_verification) { if (!nil::proof_generator::can_write_to_file(proof_file_.string())) { BOOST_LOG_TRIVIAL(error) << "Can't write to file " << proof_file_; return false; } - prepare_for_operation(); - BOOST_ASSERT(public_preprocessed_data_); BOOST_ASSERT(private_preprocessed_data_); BOOST_ASSERT(table_description_); @@ -182,7 +185,9 @@ namespace nil { true ); if (res) { - BOOST_LOG_TRIVIAL(info) << "Proof written"; + BOOST_LOG_TRIVIAL(info) << "Proof written."; + } else { + BOOST_LOG_TRIVIAL(error) << "Failed to write proof to file."; } BOOST_LOG_TRIVIAL(info) << "Writing json proof to " << json_file_; @@ -204,10 +209,14 @@ namespace nil { return res; } - bool verify_from_file() { - prepare_for_operation(); + bool verify_from_file(boost::filesystem::path proof_file_, boost::filesystem::path preprocessed_data_file) { + read_circuit(); + read_assignment_table(); + read_public_preprocessed_data_from_file(preprocessed_data_file); + using ProofMarshalling = nil::crypto3::marshalling::types:: placeholder_proof, Proof>; + BOOST_LOG_TRIVIAL(info) << "Reading proof from file"; auto marshalled_proof = detail::decode_marshalling_from_file(proof_file_, true); if (!marshalled_proof) { @@ -217,12 +226,12 @@ namespace nil { verify(nil::crypto3::marshalling::types::make_placeholder_proof(*marshalled_proof )); if (res) { - BOOST_LOG_TRIVIAL(info) << "Proof verified"; + BOOST_LOG_TRIVIAL(info) << "Proof verification passed."; } return res; } - bool save_preprocessed_common_data_to_file() { + bool save_preprocessed_common_data_to_file(boost::filesystem::path preprocessed_common_data_file) { BOOST_LOG_TRIVIAL(info) << "Writing preprocessed common data to file..."; using Endianness = nil::marshalling::option::big_endian; using TTypeBase = nil::marshalling::field_type; @@ -232,16 +241,60 @@ namespace nil { public_preprocessed_data_->common_data ); bool res = nil::proof_generator::detail::encode_marshalling_to_file( - preprocessed_common_data_file_, + preprocessed_common_data_file, marshalled_common_data ); if (res) { - BOOST_LOG_TRIVIAL(info) << "Preprocessed common data written"; + BOOST_LOG_TRIVIAL(info) << "Preprocessed common data written."; } return res; } - private: + // This includes not only the common data, but also merkle trees, polynomials, etc, everything that a + // public preprocessor generates. + bool save_public_preprocessed_data_to_file(boost::filesystem::path preprocessed_data_file) { + using namespace nil::crypto3::marshalling::types; + + BOOST_LOG_TRIVIAL(info) << "Writing all preprocessed public data to " << + preprocessed_data_file << std::endl; + using Endianness = nil::marshalling::option::big_endian; + using TTypeBase = nil::marshalling::field_type; + using PreprocessedPublicDataType = typename PublicPreprocessedData::preprocessed_data_type; + + auto marshalled_preprocessed_public_data = + fill_placeholder_preprocessed_public_data( + *public_preprocessed_data_ + ); + bool res = nil::proof_generator::detail::encode_marshalling_to_file( + preprocessed_data_file, + marshalled_preprocessed_public_data + ); + if (res) { + BOOST_LOG_TRIVIAL(info) << "Preprocessed public data written."; + } + return res; + } + + bool read_public_preprocessed_data_from_file(boost::filesystem::path preprocessed_data_file) { + using namespace nil::crypto3::marshalling::types; + + using BlueprintField = typename CurveType::base_field_type; + using TTypeBase = nil::marshalling::field_type; + using PreprocessedPublicDataType = typename PublicPreprocessedData::preprocessed_data_type; + using PublicPreprocessedDataMarshalling = + placeholder_preprocessed_public_data; + + auto marshalled_value = detail::decode_marshalling_from_file( + preprocessed_data_file); + if (!marshalled_value) { + return false; + } + public_preprocessed_data_.emplace( + make_placeholder_preprocessed_public_data(*marshalled_value) + ); + return true; + } + using BlueprintField = typename CurveType::base_field_type; using LpcParams = nil::crypto3::zk::commitments::list_polynomial_commitment_params; using Lpc = nil::crypto3::zk::commitments::list_polynomial_commitment; @@ -280,23 +333,27 @@ namespace nil { return verification_result; } - bool prepare_for_operation() { + bool read_circuit() { using BlueprintField = typename CurveType::base_field_type; using TTypeBase = nil::marshalling::field_type; using ConstraintMarshalling = nil::crypto3::marshalling::types::plonk_constraint_system; - { - auto marshalled_value = detail::decode_marshalling_from_file(circuit_file_); - if (!marshalled_value) { - return false; - } - constraint_system_.emplace( - nil::crypto3::marshalling::types::make_plonk_constraint_system( - *marshalled_value - ) - ); + auto marshalled_value = detail::decode_marshalling_from_file(circuit_file_); + if (!marshalled_value) { + return false; } + constraint_system_.emplace( + nil::crypto3::marshalling::types::make_plonk_constraint_system( + *marshalled_value + ) + ); + return true; + } + + bool read_assignment_table() { + using BlueprintField = typename CurveType::base_field_type; + using TTypeBase = nil::marshalling::field_type; using TableValueMarshalling = nil::crypto3::marshalling::types::plonk_assignment_table; @@ -309,11 +366,17 @@ namespace nil { nil::crypto3::marshalling::types::make_assignment_table( *marshalled_table ); - - public_inputs_.emplace(assignment_table.public_inputs()); table_description_.emplace(table_description); + assignment_table_.emplace(std::move(assignment_table)); + return true; + } + + bool preprocess_public_data() { + using BlueprintField = typename CurveType::base_field_type; + + public_inputs_.emplace(assignment_table_->public_inputs()); - // Lambdas and grinding bits should be passed threw preprocessor directives + // Lambdas and grinding bits should be passed through preprocessor directives std::size_t table_rows_log = std::ceil(std::log2(table_description_->rows_amount)); fri_params_.emplace(FriParams(1, table_rows_log, lambda_, expand_factor_)); @@ -324,37 +387,44 @@ namespace nil { nil::crypto3::zk::snark::placeholder_public_preprocessor:: process( *constraint_system_, - assignment_table.move_public_table(), + assignment_table_->move_public_table(), *table_description_, *lpc_scheme_, max_quotient_chunks_ ) ); + return true; + } + + bool preprocess_private_data() { + using BlueprintField = typename CurveType::base_field_type; BOOST_LOG_TRIVIAL(info) << "Preprocessing private data"; private_preprocessed_data_.emplace( nil::crypto3::zk::snark::placeholder_private_preprocessor:: - process(*constraint_system_, assignment_table.move_private_table(), *table_description_) + process(*constraint_system_, assignment_table_->move_private_table(), *table_description_) ); + + // This is the last stage of preprocessor, and the assignment table is not used after this function call. + assignment_table_.reset(); + return true; } + private: const boost::filesystem::path circuit_file_; - const boost::filesystem::path preprocessed_common_data_file_; const boost::filesystem::path assignment_table_file_; - const boost::filesystem::path proof_file_; - const boost::filesystem::path json_file_; const std::size_t expand_factor_; const std::size_t max_quotient_chunks_; const std::size_t lambda_; const std::size_t grind_; - // All set on prepare_for_operation() std::optional public_preprocessed_data_; std::optional private_preprocessed_data_; std::optional public_inputs_; std::optional table_description_; std::optional constraint_system_; + std::optional assignment_table_; std::optional fri_params_; std::optional lpc_scheme_; }; diff --git a/bin/proof-producer/src/arg_parser.cpp b/bin/proof-producer/src/arg_parser.cpp index 06bf3b9f..599f02ce 100644 --- a/bin/proof-producer/src/arg_parser.cpp +++ b/bin/proof-producer/src/arg_parser.cpp @@ -72,9 +72,12 @@ namespace nil { ); // clang-format off auto options_appender = config.add_options() + ("stage", make_defaulted_option(prover_options.stage), + "Stage of the prover to run, one of (all, preprocess, prove, verify). Defaults to 'all'.") ("proof,p", make_defaulted_option(prover_options.proof_file_path), "Output proof file") ("json,j", make_defaulted_option(prover_options.json_file_path), "JSON proof file") ("common-data,d", make_defaulted_option(prover_options.preprocessed_common_data_path), "Output preprocessed common data file") + ("preprocessed-data,d", make_defaulted_option(prover_options.preprocessed_public_data_path), "Output preprocessed public data file") ("circuit", po::value(&prover_options.circuit_file_path)->required(), "Circuit input file") ("assignment-table,t", po::value(&prover_options.assignment_table_file_path)->required(), "Assignment table input file") ("log-level,l", make_defaulted_option(prover_options.log_level), "Log level (trace, debug, info, warning, error, fatal)") @@ -84,8 +87,8 @@ namespace nil { ("grind-param", make_defaulted_option(prover_options.grind), "Grind param (69)") ("expand-factor,x", make_defaulted_option(prover_options.expand_factor), "Expand factor") ("max-quotient-chunks,q", make_defaulted_option(prover_options.max_quotient_chunks), "Maximum quotient polynomial parts amount") - ("skip-verification", po::bool_switch(&prover_options.skip_verification), "Skip generated proof verifying step") - ("verification-only", po::bool_switch(&prover_options.verification_only), "Read proof for verification instead of writing to it"); + ("skip-verification", po::bool_switch(&prover_options.skip_verification), "Skip generated proof verifying step"); + // clang-format on po::options_description cmdline_options("nil; Proof Producer"); cmdline_options.add(generic).add(config); diff --git a/bin/proof-producer/src/main.cpp b/bin/proof-producer/src/main.cpp index 15d37be3..eede2d20 100644 --- a/bin/proof-producer/src/main.cpp +++ b/bin/proof-producer/src/main.cpp @@ -34,10 +34,7 @@ int run_prover(const nil::proof_generator::ProverOptions& prover_options) { auto prover_task = [&] { auto prover = nil::proof_generator::Prover( prover_options.circuit_file_path, - prover_options.preprocessed_common_data_path, prover_options.assignment_table_file_path, - prover_options.proof_file_path, - prover_options.json_file_path, prover_options.lambda, prover_options.expand_factor, prover_options.max_quotient_chunks, @@ -45,9 +42,45 @@ int run_prover(const nil::proof_generator::ProverOptions& prover_options) { ); bool prover_result; try { - prover_result = prover_options.verification_only ? prover.verify_from_file() - : prover.generate_to_file(prover_options.skip_verification) - && prover.save_preprocessed_common_data_to_file(); + switch (nil::proof_generator::detail::prover_stage_from_string(prover_options.stage)) { + case nil::proof_generator::detail::ProverStage::ALL: + prover_result = + prover.read_circuit() && + prover.read_assignment_table() && + prover.preprocess_public_data() && + prover.preprocess_private_data() && + prover.generate_to_file( + prover_options.proof_file_path, + prover_options.json_file_path, + false/*don't skip verification*/) && + prover.save_preprocessed_common_data_to_file(prover_options.preprocessed_common_data_path) && + prover.save_public_preprocessed_data_to_file(prover_options.preprocessed_public_data_path); + break; + case nil::proof_generator::detail::ProverStage::PREPROCESS: + prover_result = + prover.read_circuit() && + prover.read_assignment_table() && + prover.preprocess_public_data() && + prover.save_preprocessed_common_data_to_file(prover_options.preprocessed_common_data_path) && + prover.save_public_preprocessed_data_to_file(prover_options.preprocessed_public_data_path); + break; + case nil::proof_generator::detail::ProverStage::PROVE: + // Load preprocessed data from file and generate the proof. + prover_result = + prover.read_circuit() && + prover.read_assignment_table() && + prover.read_public_preprocessed_data_from_file(prover_options.preprocessed_public_data_path) && + prover.preprocess_private_data() && + prover.generate_to_file( + prover_options.proof_file_path, + prover_options.json_file_path, + true/*skip verification*/); + break; + case nil::proof_generator::detail::ProverStage::VERIFY: + prover_result = prover.verify_from_file( + prover_options.proof_file_path, prover_options.preprocessed_public_data_path); + break; + } } catch (const std::exception& e) { BOOST_LOG_TRIVIAL(error) << e.what(); return 1;