Skip to content

Commit

Permalink
Separate preprocessing from proving and verification by adding a para…
Browse files Browse the repository at this point in the history
…meter --stage to proof producer
  • Loading branch information
martun committed Aug 26, 2024
1 parent 9928ef4 commit 0f3290c
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 68 deletions.
2 changes: 2 additions & 0 deletions bin/proof-producer/include/nil/proof-generator/arg_parser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,11 @@ namespace nil {
typename tuple_to_variant<typename transform_tuple<HashTypes, to_type_identity>::type>::type;

struct ProverOptions {
std::string stage = "all";
boost::filesystem::path proof_file_path = "proof.bin";
boost::filesystem::path json_file_path = "proof.json";
boost::filesystem::path preprocessed_common_data_path = "preprocessed_common_data.dat";
boost::filesystem::path preprocessed_public_data_path = "preprocessed_data.dat";
boost::filesystem::path circuit_file_path;
boost::filesystem::path assignment_table_file_path;
boost::log::trivial::severity_level log_level = boost::log::trivial::severity_level::info;
Expand Down
190 changes: 130 additions & 60 deletions bin/proof-producer/include/nil/proof-generator/prover.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,21 @@

#include <boost/log/trivial.hpp>

#include <nil/marshalling/endianness.hpp>
#include <nil/marshalling/field_type.hpp>
#include <nil/marshalling/status_type.hpp>

#include <nil/crypto3/algebra/fields/arithmetic_params/pallas.hpp>
#include <nil/crypto3/marshalling/zk/types/commitments/eval_storage.hpp>
#include <nil/crypto3/marshalling/zk/types/commitments/lpc.hpp>
#include <nil/crypto3/marshalling/zk/types/placeholder/common_data.hpp>
#include <nil/crypto3/marshalling/zk/types/placeholder/preprocessed_public_data.hpp>
#include <nil/crypto3/marshalling/zk/types/placeholder/proof.hpp>
#include <nil/crypto3/marshalling/zk/types/plonk/assignment_table.hpp>
#include <nil/crypto3/marshalling/zk/types/plonk/constraint_system.hpp>

#include <nil/crypto3/math/algorithms/calculate_domain_set.hpp>

#include <nil/crypto3/zk/snark/arithmetization/plonk/constraint_system.hpp>
#include <nil/crypto3/zk/snark/arithmetization/plonk/params.hpp>
#include <nil/crypto3/zk/snark/systems/plonk/placeholder/detail/placeholder_policy.hpp>
Expand All @@ -45,9 +52,8 @@
#include <nil/crypto3/zk/snark/systems/plonk/placeholder/verifier.hpp>

#include <nil/blueprint/transpiler/recursive_verifier_generator.hpp>
#include <nil/marshalling/endianness.hpp>
#include <nil/marshalling/field_type.hpp>
#include <nil/marshalling/status_type.hpp>


#include <nil/proof-generator/arithmetization_params.hpp>
#include <nil/proof-generator/file_operations.hpp>

Expand All @@ -68,7 +74,7 @@ namespace nil {
auto read_iter = v->begin();
auto status = marshalled_data.read(read_iter, v->size());
if (status != nil::marshalling::status_type::success) {
BOOST_LOG_TRIVIAL(error) << "Marshalled structure decoding failed";
BOOST_LOG_TRIVIAL(error) << "When reading a Marshalled structure from file " << path << ", decoding step failed";
return std::nullopt;
}
return marshalled_data;
Expand All @@ -92,62 +98,59 @@ namespace nil {
return hex ? write_vector_to_hex_file(v, path.c_str()) : write_vector_to_file(v, path.c_str());
}

std::vector<std::size_t> generate_random_step_list(const std::size_t r, const int max_step) {
using Distribution = std::uniform_int_distribution<int>;
static std::random_device random_engine;

std::vector<std::size_t> step_list;
std::size_t steps_sum = 0;
while (steps_sum != r) {
if (r - steps_sum <= max_step) {
while (r - steps_sum != 1) {
step_list.emplace_back(r - steps_sum - 1);
steps_sum += step_list.back();
}
step_list.emplace_back(1);
steps_sum += step_list.back();
} else {
step_list.emplace_back(Distribution(1, max_step)(random_engine));
steps_sum += step_list.back();
}
enum class ProverStage {
ALL = 0,
PREPROCESS = 1,
PROVE = 2,
VERIFY = 3
};

ProverStage prover_stage_from_string(const std::string& stage) {
static std::unordered_map<std::string, ProverStage> stage_map = {
{"all", ProverStage::ALL},
{"preprocess", ProverStage::PREPROCESS},
{"prove", ProverStage::PROVE},
{"verify", ProverStage::VERIFY}
};
auto it = stage_map.find(stage);
if (it == stage_map.end()) {
throw std::invalid_argument("Invalid stage: " + stage);
}
return step_list;
return it->second;
}

} // namespace detail


template<typename CurveType, typename HashType>
class Prover {
public:
Prover(
boost::filesystem::path circuit_file_name,
boost::filesystem::path preprocessed_common_data_file_name,
boost::filesystem::path assignment_table_file_name,
boost::filesystem::path proof_file,
boost::filesystem::path json_file,
std::size_t lambda,
std::size_t expand_factor,
std::size_t max_q_chunks,
std::size_t grind
)
: circuit_file_(circuit_file_name)
, preprocessed_common_data_file_(preprocessed_common_data_file_name)
, assignment_table_file_(assignment_table_file_name)
, proof_file_(proof_file)
, json_file_(json_file)
, lambda_(lambda)
, expand_factor_(expand_factor)
, max_quotient_chunks_(max_q_chunks)
, grind_(grind) {
}

bool generate_to_file(bool skip_verification) {
// The caller must call the preprocessor or load the preprocessed data before calling this function.
bool generate_to_file(
boost::filesystem::path proof_file_,
boost::filesystem::path json_file_,
bool skip_verification) {
if (!nil::proof_generator::can_write_to_file(proof_file_.string())) {
BOOST_LOG_TRIVIAL(error) << "Can't write to file " << proof_file_;
return false;
}

prepare_for_operation();

BOOST_ASSERT(public_preprocessed_data_);
BOOST_ASSERT(private_preprocessed_data_);
BOOST_ASSERT(table_description_);
Expand Down Expand Up @@ -182,7 +185,9 @@ namespace nil {
true
);
if (res) {
BOOST_LOG_TRIVIAL(info) << "Proof written";
BOOST_LOG_TRIVIAL(info) << "Proof written.";
} else {
BOOST_LOG_TRIVIAL(error) << "Failed to write proof to file.";
}

BOOST_LOG_TRIVIAL(info) << "Writing json proof to " << json_file_;
Expand All @@ -204,10 +209,14 @@ namespace nil {
return res;
}

bool verify_from_file() {
prepare_for_operation();
bool verify_from_file(boost::filesystem::path proof_file_, boost::filesystem::path preprocessed_data_file) {
read_circuit();
read_assignment_table();
read_public_preprocessed_data_from_file(preprocessed_data_file);

using ProofMarshalling = nil::crypto3::marshalling::types::
placeholder_proof<nil::marshalling::field_type<Endianness>, Proof>;

BOOST_LOG_TRIVIAL(info) << "Reading proof from file";
auto marshalled_proof = detail::decode_marshalling_from_file<ProofMarshalling>(proof_file_, true);
if (!marshalled_proof) {
Expand All @@ -217,12 +226,12 @@ namespace nil {
verify(nil::crypto3::marshalling::types::make_placeholder_proof<Endianness, Proof>(*marshalled_proof
));
if (res) {
BOOST_LOG_TRIVIAL(info) << "Proof verified";
BOOST_LOG_TRIVIAL(info) << "Proof verification passed.";
}
return res;
}

bool save_preprocessed_common_data_to_file() {
bool save_preprocessed_common_data_to_file(boost::filesystem::path preprocessed_common_data_file) {
BOOST_LOG_TRIVIAL(info) << "Writing preprocessed common data to file...";
using Endianness = nil::marshalling::option::big_endian;
using TTypeBase = nil::marshalling::field_type<Endianness>;
Expand All @@ -232,16 +241,60 @@ namespace nil {
public_preprocessed_data_->common_data
);
bool res = nil::proof_generator::detail::encode_marshalling_to_file(
preprocessed_common_data_file_,
preprocessed_common_data_file,
marshalled_common_data
);
if (res) {
BOOST_LOG_TRIVIAL(info) << "Preprocessed common data written";
BOOST_LOG_TRIVIAL(info) << "Preprocessed common data written.";
}
return res;
}

private:
// This includes not only the common data, but also merkle trees, polynomials, etc, everything that a
// public preprocessor generates.
bool save_public_preprocessed_data_to_file(boost::filesystem::path preprocessed_data_file) {
using namespace nil::crypto3::marshalling::types;

BOOST_LOG_TRIVIAL(info) << "Writing all preprocessed public data to " <<
preprocessed_data_file << std::endl;
using Endianness = nil::marshalling::option::big_endian;
using TTypeBase = nil::marshalling::field_type<Endianness>;
using PreprocessedPublicDataType = typename PublicPreprocessedData::preprocessed_data_type;

auto marshalled_preprocessed_public_data =
fill_placeholder_preprocessed_public_data<Endianness, PreprocessedPublicDataType>(
*public_preprocessed_data_
);
bool res = nil::proof_generator::detail::encode_marshalling_to_file(
preprocessed_data_file,
marshalled_preprocessed_public_data
);
if (res) {
BOOST_LOG_TRIVIAL(info) << "Preprocessed public data written.";
}
return res;
}

bool read_public_preprocessed_data_from_file(boost::filesystem::path preprocessed_data_file) {
using namespace nil::crypto3::marshalling::types;

using BlueprintField = typename CurveType::base_field_type;
using TTypeBase = nil::marshalling::field_type<Endianness>;
using PreprocessedPublicDataType = typename PublicPreprocessedData::preprocessed_data_type;
using PublicPreprocessedDataMarshalling =
placeholder_preprocessed_public_data<TTypeBase, PreprocessedPublicDataType>;

auto marshalled_value = detail::decode_marshalling_from_file<PublicPreprocessedDataMarshalling>(
preprocessed_data_file);
if (!marshalled_value) {
return false;
}
public_preprocessed_data_.emplace(
make_placeholder_preprocessed_public_data<Endianness, PreprocessedPublicDataType>(*marshalled_value)
);
return true;
}

using BlueprintField = typename CurveType::base_field_type;
using LpcParams = nil::crypto3::zk::commitments::list_polynomial_commitment_params<HashType, HashType, 2>;
using Lpc = nil::crypto3::zk::commitments::list_polynomial_commitment<BlueprintField, LpcParams>;
Expand Down Expand Up @@ -280,23 +333,27 @@ namespace nil {
return verification_result;
}

bool prepare_for_operation() {
bool read_circuit() {
using BlueprintField = typename CurveType::base_field_type;
using TTypeBase = nil::marshalling::field_type<Endianness>;
using ConstraintMarshalling =
nil::crypto3::marshalling::types::plonk_constraint_system<TTypeBase, ConstraintSystem>;

{
auto marshalled_value = detail::decode_marshalling_from_file<ConstraintMarshalling>(circuit_file_);
if (!marshalled_value) {
return false;
}
constraint_system_.emplace(
nil::crypto3::marshalling::types::make_plonk_constraint_system<Endianness, ConstraintSystem>(
*marshalled_value
)
);
auto marshalled_value = detail::decode_marshalling_from_file<ConstraintMarshalling>(circuit_file_);
if (!marshalled_value) {
return false;
}
constraint_system_.emplace(
nil::crypto3::marshalling::types::make_plonk_constraint_system<Endianness, ConstraintSystem>(
*marshalled_value
)
);
return true;
}

bool read_assignment_table() {
using BlueprintField = typename CurveType::base_field_type;
using TTypeBase = nil::marshalling::field_type<Endianness>;

using TableValueMarshalling =
nil::crypto3::marshalling::types::plonk_assignment_table<TTypeBase, AssignmentTable>;
Expand All @@ -309,11 +366,17 @@ namespace nil {
nil::crypto3::marshalling::types::make_assignment_table<Endianness, AssignmentTable>(
*marshalled_table
);

public_inputs_.emplace(assignment_table.public_inputs());
table_description_.emplace(table_description);
assignment_table_.emplace(std::move(assignment_table));
return true;
}

bool preprocess_public_data() {
using BlueprintField = typename CurveType::base_field_type;

public_inputs_.emplace(assignment_table_->public_inputs());

// Lambdas and grinding bits should be passed threw preprocessor directives
// Lambdas and grinding bits should be passed through preprocessor directives
std::size_t table_rows_log = std::ceil(std::log2(table_description_->rows_amount));

fri_params_.emplace(FriParams(1, table_rows_log, lambda_, expand_factor_));
Expand All @@ -324,37 +387,44 @@ namespace nil {
nil::crypto3::zk::snark::placeholder_public_preprocessor<BlueprintField, PlaceholderParams>::
process(
*constraint_system_,
assignment_table.move_public_table(),
assignment_table_->move_public_table(),
*table_description_,
*lpc_scheme_,
max_quotient_chunks_
)
);
return true;
}

bool preprocess_private_data() {
using BlueprintField = typename CurveType::base_field_type;

BOOST_LOG_TRIVIAL(info) << "Preprocessing private data";
private_preprocessed_data_.emplace(
nil::crypto3::zk::snark::placeholder_private_preprocessor<BlueprintField, PlaceholderParams>::
process(*constraint_system_, assignment_table.move_private_table(), *table_description_)
process(*constraint_system_, assignment_table_->move_private_table(), *table_description_)
);

// This is the last stage of preprocessor, and the assignment table is not used after this function call.
assignment_table_.reset();

return true;
}

private:
const boost::filesystem::path circuit_file_;
const boost::filesystem::path preprocessed_common_data_file_;
const boost::filesystem::path assignment_table_file_;
const boost::filesystem::path proof_file_;
const boost::filesystem::path json_file_;
const std::size_t expand_factor_;
const std::size_t max_quotient_chunks_;
const std::size_t lambda_;
const std::size_t grind_;

// All set on prepare_for_operation()
std::optional<PublicPreprocessedData> public_preprocessed_data_;
std::optional<PrivatePreprocessedData> private_preprocessed_data_;
std::optional<typename AssignmentTable::public_input_container_type> public_inputs_;
std::optional<TableDescription> table_description_;
std::optional<ConstraintSystem> constraint_system_;
std::optional<AssignmentTable> assignment_table_;
std::optional<FriParams> fri_params_;
std::optional<LpcScheme> lpc_scheme_;
};
Expand Down
7 changes: 5 additions & 2 deletions bin/proof-producer/src/arg_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,12 @@ namespace nil {
);
// clang-format off
auto options_appender = config.add_options()
("stage", make_defaulted_option(prover_options.stage),
"Stage of the prover to run, one of (all, preprocess, prove, verify). Defaults to 'all'.")
("proof,p", make_defaulted_option(prover_options.proof_file_path), "Output proof file")
("json,j", make_defaulted_option(prover_options.json_file_path), "JSON proof file")
("common-data,d", make_defaulted_option(prover_options.preprocessed_common_data_path), "Output preprocessed common data file")
("preprocessed-data,d", make_defaulted_option(prover_options.preprocessed_public_data_path), "Output preprocessed public data file")
("circuit", po::value(&prover_options.circuit_file_path)->required(), "Circuit input file")
("assignment-table,t", po::value(&prover_options.assignment_table_file_path)->required(), "Assignment table input file")
("log-level,l", make_defaulted_option(prover_options.log_level), "Log level (trace, debug, info, warning, error, fatal)")
Expand All @@ -84,8 +87,8 @@ namespace nil {
("grind-param", make_defaulted_option(prover_options.grind), "Grind param (69)")
("expand-factor,x", make_defaulted_option(prover_options.expand_factor), "Expand factor")
("max-quotient-chunks,q", make_defaulted_option(prover_options.max_quotient_chunks), "Maximum quotient polynomial parts amount")
("skip-verification", po::bool_switch(&prover_options.skip_verification), "Skip generated proof verifying step")
("verification-only", po::bool_switch(&prover_options.verification_only), "Read proof for verification instead of writing to it");
("skip-verification", po::bool_switch(&prover_options.skip_verification), "Skip generated proof verifying step");

// clang-format on
po::options_description cmdline_options("nil; Proof Producer");
cmdline_options.add(generic).add(config);
Expand Down
Loading

0 comments on commit 0f3290c

Please sign in to comment.