Skip to content

Commit

Permalink
Move columns params from templates to runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
x-mass committed Mar 5, 2024
1 parent c975db4 commit d6d8d18
Show file tree
Hide file tree
Showing 10 changed files with 52 additions and 157 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ namespace nil {
bool verification_only = false;
CurvesVariant elliptic_curve_type = type_identity<nil::crypto3::algebra::curves::pallas>{};
HashesVariant hash_type = type_identity<nil::crypto3::hashes::keccak_1600<256>>{};
ColumnsParams columns = all_columns_params[0];
LambdaParam lambda = all_lambda_params[0];
GrindParam grind = all_grind_params[0];
std::size_t component_constant_columns = 5;
std::size_t expand_factor = 2;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,6 @@
namespace nil {
namespace proof_generator {

// Available in runtime columns params. Order is:
// witness | public_input | component_constant | component_selector |
// lookup_constant | lookup_selector
constexpr std::array<ColumnsParams, 4> all_columns_params = {{
{15, 1, 5, 50, 30, 6},
{15, 1, 5, 60, 0, 0},
{15, 1, 2, 50, 14, 6},
{15, 1, 5, 30, 30, 6}
// Add more params as needed.
}};

constexpr std::array<LambdaParam, 1> all_lambda_params = {
9
// Add more params as needed.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,6 @@
namespace nil {
namespace proof_generator {

struct ColumnsParams {
std::size_t witness_columns;
std::size_t public_input_columns;
std::size_t component_constant_columns;
std::size_t component_selector_columns;
std::size_t lookup_constant_columns;
std::size_t lookup_selector_columns;

bool operator==(const ColumnsParams& other) const {
return witness_columns == other.witness_columns && public_input_columns == other.public_input_columns
&& component_constant_columns == other.component_constant_columns
&& component_selector_columns == other.component_selector_columns
&& lookup_constant_columns == other.lookup_constant_columns
&& lookup_selector_columns == other.lookup_selector_columns;
}
};

// Need this class to be derived into actual params, so we could overload
// read/write operators for parsing.
class SizeTParam {
Expand Down
65 changes: 20 additions & 45 deletions bin/proof-generator/include/nil/proof-generator/prover.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,25 +130,23 @@ namespace nil {
}
} // namespace detail

template<
typename CurveType,
typename HashType,
std::size_t ColumnsParamsIdx,
std::size_t LambdaParamIdx,
std::size_t GrindParamIdx>
template<typename CurveType, typename HashType, std::size_t LambdaParamIdx, std::size_t GrindParamIdx>
class Prover {
public:
Prover(
boost::filesystem::path circuit_file_name,
boost::filesystem::path preprocessed_common_data_file_name,
boost::filesystem::path assignment_table_file_name,
boost::filesystem::path proof_file,
std::size_t component_constant_columns, // We need it to calculate permutation size, and it couldn't be
// established form assignment table yet
std::size_t expand_factor
)
: circuit_file_(circuit_file_name)
, preprocessed_common_data_file_(preprocessed_common_data_file_name)
, assignment_table_file_(assignment_table_file_name)
, proof_file_(proof_file)
, component_constant_columns_(component_constant_columns)
, expand_factor_(expand_factor) {
}

Expand Down Expand Up @@ -238,39 +236,20 @@ namespace nil {
}

private:
// C++20 allows passing non-type template param by value, so we could make the
// current function use
// columns_policy type directly instead of index after standard upgrade.
// clang-format off
static constexpr std::size_t WitnessColumns = all_columns_params[ColumnsParamsIdx].witness_columns;
static constexpr std::size_t PublicInputColumns = all_columns_params[ColumnsParamsIdx].public_input_columns;
static constexpr std::size_t ComponentConstantColumns = all_columns_params[ColumnsParamsIdx].component_constant_columns;
static constexpr std::size_t LookupConstantColumns = all_columns_params[ColumnsParamsIdx].lookup_constant_columns;
static constexpr std::size_t ConstantColumns = ComponentConstantColumns + LookupConstantColumns;
static constexpr std::size_t ComponentSelectorColumns = all_columns_params[ColumnsParamsIdx].component_selector_columns;
static constexpr std::size_t LookupSelectorColumns = all_columns_params[ColumnsParamsIdx].lookup_selector_columns;
static constexpr std::size_t SelectorColumns = ComponentSelectorColumns + LookupSelectorColumns;
// clang-format on

using ArithmetizationParams = nil::crypto3::zk::snark::
plonk_arithmetization_params<WitnessColumns, PublicInputColumns, ConstantColumns, SelectorColumns>;
using BlueprintField = typename CurveType::base_field_type;
using LpcParams = nil::crypto3::zk::commitments::
list_polynomial_commitment_params<HashType, HashType, all_lambda_params[LambdaParamIdx], 2>;
using Lpc = nil::crypto3::zk::commitments::list_polynomial_commitment<BlueprintField, LpcParams>;
using LpcScheme = typename nil::crypto3::zk::commitments::lpc_commitment_scheme<Lpc>;
using CircuitParams =
nil::crypto3::zk::snark::placeholder_circuit_params<BlueprintField, ArithmetizationParams>;
using CircuitParams = nil::crypto3::zk::snark::placeholder_circuit_params<BlueprintField>;
using PlaceholderParams = nil::crypto3::zk::snark::placeholder_params<CircuitParams, LpcScheme>;
using Proof = nil::crypto3::zk::snark::placeholder_proof<BlueprintField, PlaceholderParams>;
using PublicPreprocessedData = typename nil::crypto3::zk::snark::
placeholder_public_preprocessor<BlueprintField, PlaceholderParams>::preprocessed_data_type;
using PrivatePreprocessedData = typename nil::crypto3::zk::snark::
placeholder_private_preprocessor<BlueprintField, PlaceholderParams>::preprocessed_data_type;
using ConstraintSystem =
nil::crypto3::zk::snark::plonk_constraint_system<BlueprintField, ArithmetizationParams>;
using TableDescription =
nil::crypto3::zk::snark::plonk_table_description<BlueprintField, ArithmetizationParams>;
using ConstraintSystem = nil::crypto3::zk::snark::plonk_constraint_system<BlueprintField>;
using TableDescription = nil::crypto3::zk::snark::plonk_table_description<BlueprintField>;
using Endianness = nil::marshalling::option::big_endian;
using FriParams = typename Lpc::fri_type::params_type;

Expand All @@ -280,6 +259,7 @@ namespace nil {
nil::crypto3::zk::snark::placeholder_verifier<BlueprintField, PlaceholderParams>::process(
*public_preprocessed_data_,
proof,
*table_description_,
*constraint_system_,
*lpc_scheme_
);
Expand All @@ -300,8 +280,7 @@ namespace nil {
nil::crypto3::marshalling::types::plonk_constraint_system<TTypeBase, ConstraintSystem>;

using Column = nil::crypto3::zk::snark::plonk_column<BlueprintField>;
using AssignmentTable =
nil::crypto3::zk::snark::plonk_table<BlueprintField, ArithmetizationParams, Column>;
using AssignmentTable = nil::crypto3::zk::snark::plonk_table<BlueprintField, Column>;

{
auto marshalled_value = detail::decode_marshalling_from_file<ConstraintMarshalling>(circuit_file_);
Expand All @@ -317,21 +296,16 @@ namespace nil {

using TableValueMarshalling =
nil::crypto3::marshalling::types::plonk_assignment_table<TTypeBase, AssignmentTable>;
AssignmentTable assignment_table;
{
TableDescription table_description;
auto marshalled_table =
detail::decode_marshalling_from_file<TableValueMarshalling>(assignment_table_file_);
if (!marshalled_table) {
return false;
}
std::tie(table_description.usable_rows_amount, assignment_table) =
nil::crypto3::marshalling::types::make_assignment_table<Endianness, AssignmentTable>(
*marshalled_table
);
table_description.rows_amount = assignment_table.rows_amount();
table_description_.emplace(table_description);
auto marshalled_table =
detail::decode_marshalling_from_file<TableValueMarshalling>(assignment_table_file_);
if (!marshalled_table) {
return false;
}
auto [table_description, assignment_table] =
nil::crypto3::marshalling::types::make_assignment_table<Endianness, AssignmentTable>(
*marshalled_table
);
table_description_.emplace(table_description);

// Lambdas and grinding bits should be passed threw preprocessor directives
std::size_t table_rows_log = std::ceil(std::log2(table_description_->rows_amount));
Expand All @@ -341,7 +315,7 @@ namespace nil {
);

std::size_t permutation_size = table_description_->witness_columns
+ table_description_->public_input_columns + ComponentConstantColumns;
+ table_description_->public_input_columns + component_constant_columns_;
lpc_scheme_.emplace(*fri_params_);

BOOST_LOG_TRIVIAL(info) << "Preprocessing public data";
Expand Down Expand Up @@ -369,6 +343,7 @@ namespace nil {
const boost::filesystem::path assignment_table_file_;
const boost::filesystem::path proof_file_;
const std::size_t expand_factor_;
const std::size_t component_constant_columns_;

// All set on prepare_for_operation()
std::optional<PublicPreprocessedData> public_preprocessed_data_;
Expand Down
64 changes: 15 additions & 49 deletions bin/proof-generator/src/arg_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,6 @@ namespace nil {
namespace proof_generator {
namespace po = boost::program_options;

void print_all_columns_params() {
std::cout << "Available Policies:\n";
std::cout << "Index: witness, public input, component constant, component "
"selector, lookup constant, "
"lookup selector\n";

for (std::size_t i = 0; i < all_columns_params.size(); ++i) {
const auto& params = all_columns_params[i];
std::cout << std::setw(5) << i << ":\t" << params.witness_columns << "," << params.public_input_columns
<< "," << params.component_constant_columns << "," << params.component_selector_columns << ","
<< params.lookup_constant_columns << "," << params.lookup_selector_columns << "\n";
}
}

void check_exclusive_options(const po::variables_map& vm, const std::vector<std::string>& opts) {
std::vector<std::string> found_opts;
for (const auto& opt : opts) {
Expand All @@ -55,6 +41,11 @@ namespace nil {
}
}

template<typename T>
po::typed_value<T>* make_defaulted_option(T& variable) {
return po::value(&variable)->default_value(variable);
}

std::optional<ProverOptions> parse_args(int argc, char* argv[]) {
po::options_description options("Nil; Proof Generator Options");
// Declare a group of options that will be
Expand All @@ -64,8 +55,7 @@ namespace nil {
generic.add_options()
("help,h", "Produce help message")
("version,v", "Print version string")
("config,c", po::value<std::string>(), "Config file path")
("list-columns-params", "Print available columns params");
("config,c", po::value<std::string>(), "Config file path");
// clang-format on

ProverOptions prover_options;
Expand All @@ -80,17 +70,17 @@ namespace nil {
);
// clang-format off
auto options_appender = config.add_options()
("proof,p", po::value(&prover_options.proof_file_path)->default_value(prover_options.proof_file_path), "Output proof file")
("common-data", po::value(&prover_options.preprocessed_common_data_path)->default_value(prover_options.preprocessed_common_data_path), "Output preprocessed common data file")
("proof,p", make_defaulted_option(prover_options.proof_file_path), "Output proof file")
("common-data", make_defaulted_option(prover_options.preprocessed_common_data_path), "Output preprocessed common data file")
("circuit,c", po::value(&prover_options.circuit_file_path)->required(), "Circuit input file")
("assignment-table,t", po::value(&prover_options.assignment_table_file_path)->required(), "Assignment table input file")
("log-level,l", po::value(&prover_options.log_level)->default_value(prover_options.log_level), "Log level (trace, debug, info, warning, error, fatal)")
("elliptic-curve-type,e", po::value(&prover_options.elliptic_curve_type)->default_value(prover_options.elliptic_curve_type), "Elliptic curve type (pallas)")
("hash-type", po::value(&prover_options.hash_type)->default_value(prover_options.hash_type), "Hash type (keccak)")
("columns-params", po::value(&prover_options.columns)->default_value(prover_options.columns), "Columns params, use --list-columns-params to list")
("lambda-param", po::value(&prover_options.lambda)->default_value(prover_options.lambda), "Lambda param (9)")
("grind-param", po::value(&prover_options.lambda)->default_value(prover_options.lambda), "Grind param (69)")
("expand-factor", po::value(&prover_options.expand_factor)->default_value(prover_options.expand_factor), "Expand factor")
("log-level,l", make_defaulted_option(prover_options.log_level), "Log level (trace, debug, info, warning, error, fatal)")
("elliptic-curve-type,e", make_defaulted_option(prover_options.elliptic_curve_type), "Elliptic curve type (pallas)")
("hash-type", make_defaulted_option(prover_options.hash_type), "Hash type (keccak)")
("lambda-param", make_defaulted_option(prover_options.lambda), "Lambda param (9)")
("grind-param", make_defaulted_option(prover_options.grind), "Grind param (69)")
("expand-factor", make_defaulted_option(prover_options.expand_factor), "Expand factor")
("component-constant-columns", make_defaulted_option(prover_options.component_constant_columns), "Component constant columns")
("skip-verification", po::bool_switch(&prover_options.skip_verification), "Skip generated proof verifying step")
("verification-only", po::bool_switch(&prover_options.verification_only), "Read proof for verification instead of writing to it");
// clang-format on
Expand Down Expand Up @@ -134,11 +124,6 @@ namespace nil {
}
}

if (vm.count("list-columns-params")) {
print_all_columns_params();
return std::nullopt;
}

// Calling notify(vm) after handling no-op cases prevent parser from alarming
// about absence of required args
try {
Expand All @@ -163,25 +148,6 @@ namespace nil {
// >> and << operators are needed for Boost porgram_options to read values and
// to print default values to help message: The rest of the file contains them:

std::ostream& operator<<(std::ostream& strm, const ColumnsParams& columns) {
auto it = std::find(all_columns_params.cbegin(), all_columns_params.cend(), columns);
strm << std::distance(all_columns_params.cbegin(), it);
return strm;
}

std::istream& operator>>(std::istream& strm, ColumnsParams& columns) {
std::string str;
strm >> str;
std::size_t pos;
int idx = std::stoi(str, &pos);
if (pos < str.size() || idx < 0 || static_cast<std::size_t>(idx) >= all_columns_params.size()) {
strm.setstate(std::ios_base::failbit);
} else {
columns = all_columns_params[idx];
}
return strm;
}

std::ostream& operator<<(std::ostream& strm, const LambdaParam& lambda) {
strm << static_cast<size_t>(lambda);
return strm;
Expand Down
Loading

0 comments on commit d6d8d18

Please sign in to comment.