Skip to content

Commit

Permalink
write assignment tables directly to output stream instead of using in…
Browse files Browse the repository at this point in the history
…termediate vectors and table #569
  • Loading branch information
CblPOK-git committed Mar 28, 2024
1 parent 9dd594f commit 63bac68
Showing 1 changed file with 96 additions and 103 deletions.
199 changes: 96 additions & 103 deletions bin/assigner/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ void print_circuit(const circuit_proxy<ArithmetizationType> &circuit_proxy,

std::vector<std::uint8_t> cv;
cv.resize(filled_val.length(), 0x00);
auto write_iter = cv.begin();
nil::marshalling::status_type status = filled_val.write(write_iter, cv.size());
auto cv_iter = cv.begin();
nil::marshalling::status_type status = filled_val.write(cv_iter, cv.size());
out.write(reinterpret_cast<char*>(cv.data()), cv.size());
}

Expand All @@ -173,11 +173,51 @@ enum class print_column_kind {
SELECTOR
};

template<typename ValueType, typename ContainerType>
void fill_vector_value(std::vector<ValueType> &table_values, const ContainerType &table_col, typename std::vector<ValueType>::iterator start) {
std::copy(table_col.begin(), table_col.end(), start);
template<typename Endianness>
void print_size_t(
std::size_t input,
std::ostream &out
) {
using TTypeBase = nil::marshalling::field_type<Endianness>;
auto integer_container = nil::marshalling::types::integral<TTypeBase, std::size_t>(input);
std::vector<std::uint8_t> char_vector;
char_vector.resize(integer_container.length(), 0x00);
auto write_iter = char_vector.begin();
nil::marshalling::status_type status = integer_container.write(write_iter, char_vector.size());
out.write(reinterpret_cast<char*>(char_vector.data()), char_vector.size());
}

template<typename Endianness, typename ArithmetizationType>
void print_field(
const typename assignment_proxy<ArithmetizationType>::field_type::value_type &input,
std::ostream &out
) {
using TTypeBase = nil::marshalling::field_type<Endianness>;
using AssignmentTableType = assignment_proxy<ArithmetizationType>;
auto field_container = nil::crypto3::marshalling::types::field_element<TTypeBase, typename AssignmentTableType::field_type::value_type>(input);
std::vector<std::uint8_t> char_vector;
char_vector.resize(field_container.length(), 0x00);
auto write_iter = char_vector.begin();
nil::marshalling::status_type status = field_container.write(write_iter, char_vector.size());
out.write(reinterpret_cast<char*>(char_vector.data()), char_vector.size());
}

template<typename Endianness, typename ArithmetizationType, typename ContainerType>
void print_vector_value(
const std::size_t padded_rows_amount,
const ContainerType &table_col,
std::ostream &out
) {
for (std::size_t i = 0; i < padded_rows_amount; i++) {
if (i < table_col.size()) {
print_field<Endianness, ArithmetizationType>(table_col[i], out);
} else {
print_field<Endianness, ArithmetizationType>(0, out);
}
}
}


template<typename Endianness, typename ArithmetizationType, typename BlueprintFieldType>
void print_assignment_table(const assignment_proxy<ArithmetizationType> &table_proxy,
print_table_kind print_kind,
Expand All @@ -203,7 +243,6 @@ void print_assignment_table(const assignment_proxy<ArithmetizationType> &table_p
max_public_inputs_size = std::max(max_public_inputs_size, table_proxy.public_input_column_size(i));
}

auto calc_params_start = std::chrono::high_resolution_clock::now();
if (print_kind == print_table_kind::MULTI_PROVER) {
total_columns = witness_size + shared_size + public_input_size + constant_size + selector_size;
std::uint32_t max_shared_size = 0;
Expand Down Expand Up @@ -247,171 +286,122 @@ void print_assignment_table(const assignment_proxy<ArithmetizationType> &table_p
nil::crypto3::marshalling::types::plonk_assignment_table<TTypeBase, AssignmentTableType>;

using column_type = typename crypto3::zk::snark::plonk_column<BlueprintFieldType>;
auto calc_params_duration = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - calc_params_start);
BOOST_LOG_TRIVIAL(debug) << "calc_params_duration: " << calc_params_duration.count() << "ms";

auto fill_columns_start = std::chrono::high_resolution_clock::now();
std::vector<typename AssignmentTableType::field_type::value_type> table_witness_values( padded_rows_amount * witness_size , 0);
std::vector<typename AssignmentTableType::field_type::value_type> table_public_input_values(padded_rows_amount * (public_input_size + shared_size), 0);
std::vector<typename AssignmentTableType::field_type::value_type> table_constant_values( padded_rows_amount * constant_size, 0);
std::vector<typename AssignmentTableType::field_type::value_type> table_selector_values( padded_rows_amount * selector_size, 0);

print_size_t<Endianness>(witness_size, out);
print_size_t<Endianness>(public_input_size + shared_size, out);
print_size_t<Endianness>(constant_size, out);
print_size_t<Endianness>(selector_size, out);
print_size_t<Endianness>(usable_rows_amount, out);
print_size_t<Endianness>(padded_rows_amount, out);

if (print_kind == print_table_kind::SINGLE_PROVER) {
auto it = table_witness_values.begin();
print_size_t<Endianness>(witness_size * padded_rows_amount, out);
for (std::uint32_t i = 0; i < witness_size; i++) {
fill_vector_value<typename AssignmentTableType::field_type::value_type, column_type>
(table_witness_values, table_proxy.witness(i), it);
it += padded_rows_amount;
print_vector_value<Endianness, ArithmetizationType, column_type>(padded_rows_amount, table_proxy.witness(i), out);
}
it = table_public_input_values.begin();
print_size_t<Endianness>((public_input_size + shared_size) * padded_rows_amount, out);
for (std::uint32_t i = 0; i < public_input_size; i++) {
fill_vector_value<typename AssignmentTableType::field_type::value_type, column_type>
(table_public_input_values, table_proxy.public_input(i), it);
it += padded_rows_amount;
print_vector_value<Endianness, ArithmetizationType, column_type>(padded_rows_amount, table_proxy.public_input(i), out);
}
it = table_constant_values.begin();
print_size_t<Endianness>(constant_size * padded_rows_amount, out);
for (std::uint32_t i = 0; i < constant_size; i++) {
fill_vector_value<typename AssignmentTableType::field_type::value_type, column_type>
(table_constant_values, table_proxy.constant(i), it);
it += padded_rows_amount;
print_vector_value<Endianness, ArithmetizationType, column_type>(padded_rows_amount, table_proxy.constant(i), out);
}
it = table_selector_values.begin();
print_size_t<Endianness>(selector_size * padded_rows_amount, out);
for (std::uint32_t i = 0; i < selector_size; i++) {
fill_vector_value<typename AssignmentTableType::field_type::value_type, column_type>
(table_selector_values, table_proxy.selector(i), it);
it += padded_rows_amount;
print_vector_value<Endianness, ArithmetizationType, column_type>(padded_rows_amount, table_proxy.selector(i), out);
}
} else {
const auto& rows = table_proxy.get_used_rows();
const auto& selector_rows = table_proxy.get_used_selector_rows();
std::uint32_t witness_idx = 0;

// witness
print_size_t<Endianness>(witness_size * padded_rows_amount, out);
for( std::size_t i = 0; i < witness_size; i++ ){
const auto column_size = table_proxy.witness_column_size(i);
std::uint32_t offset = 0;
for(const auto& j : rows){
if (j < column_size) {
table_witness_values[witness_idx + offset] = table_proxy.witness(i, j);
print_field<Endianness, ArithmetizationType>(table_proxy.witness(i, j), out);
offset++;
}
}
ASSERT(offset < padded_rows_amount);
while(offset < padded_rows_amount) {
print_field<Endianness, ArithmetizationType>(0, out);
offset++;
}
witness_idx += padded_rows_amount;
}
// public input
std::uint32_t pub_inp_idx = 0;
auto it_pub_inp = table_public_input_values.begin();
print_size_t<Endianness>((public_input_size + shared_size) * padded_rows_amount, out);
for (std::uint32_t i = 0; i < public_input_size; i++) {
fill_vector_value<typename AssignmentTableType::field_type::value_type, column_type>
(table_public_input_values, table_proxy.public_input(i), it_pub_inp);
it_pub_inp += padded_rows_amount;
print_vector_value<Endianness, ArithmetizationType, column_type>(padded_rows_amount, table_proxy.public_input(i), out);
pub_inp_idx += padded_rows_amount;
}
for (std::uint32_t i = 0; i < shared_size; i++) {
fill_vector_value<typename AssignmentTableType::field_type::value_type, column_type>
(table_public_input_values, table_proxy.shared(i), it_pub_inp);
it_pub_inp += padded_rows_amount;
print_vector_value<Endianness, ArithmetizationType, column_type>(padded_rows_amount, table_proxy.shared(i), out);
pub_inp_idx += padded_rows_amount;
}
// constant
print_size_t<Endianness>(constant_size * padded_rows_amount, out);
std::uint32_t constant_idx = 0;
for (std::uint32_t i = 0; i < ComponentConstantColumns; i++) {
const auto column_size = table_proxy.constant_column_size(i);
std::uint32_t offset = 0;
for(const auto& j : rows){
if (j < column_size) {
table_constant_values[constant_idx + offset] = table_proxy.constant(i, j);
print_field<Endianness, ArithmetizationType>(table_proxy.constant(i, j), out);
offset++;
}
}
ASSERT(offset < padded_rows_amount);
while(offset < padded_rows_amount) {
print_field<Endianness, ArithmetizationType>(0, out);
offset++;
}

constant_idx += padded_rows_amount;
}

auto it_const = table_constant_values.begin() + constant_idx;
for (std::uint32_t i = ComponentConstantColumns; i < constant_size; i++) {
fill_vector_value<typename AssignmentTableType::field_type::value_type, column_type>
(table_constant_values, table_proxy.constant(i), it_const);
it_const += padded_rows_amount;
print_vector_value<Endianness, ArithmetizationType, column_type>(padded_rows_amount, table_proxy.constant(i), out);
constant_idx += padded_rows_amount;
}

// selector
print_size_t<Endianness>(selector_size * padded_rows_amount, out);
std::uint32_t selector_idx = 0;
for (std::uint32_t i = 0; i < ComponentSelectorColumns; i++) {
const auto column_size = table_proxy.selector_column_size(i);
std::uint32_t offset = 0;
for(const auto& j : rows){
if (j < column_size) {
if (selector_rows.find(j) != selector_rows.end()) {

table_selector_values[selector_idx + offset] = table_proxy.selector(i, j);
print_field<Endianness, ArithmetizationType>(table_proxy.selector(i, j), out);
} else {
print_field<Endianness, ArithmetizationType>(0, out);
}
offset++;
}
}
ASSERT(offset < padded_rows_amount);
while(offset < padded_rows_amount) {
print_field<Endianness, ArithmetizationType>(0, out);
offset++;
}

selector_idx += padded_rows_amount;
}

auto it_selector = table_selector_values.begin();
for (std::uint32_t i = ComponentSelectorColumns; i < selector_size; i++) {
fill_vector_value<typename AssignmentTableType::field_type::value_type, column_type>
(table_selector_values, table_proxy.selector(i), it_selector);
it_selector += padded_rows_amount;
print_vector_value<Endianness, ArithmetizationType, column_type>(padded_rows_amount, table_proxy.selector(i), out);
selector_idx += padded_rows_amount;
}
ASSERT_MSG(witness_idx + pub_inp_idx + constant_idx + selector_idx == total_size, "Printed index not equal required assignment size" );
}
auto fill_columns_duration = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - fill_columns_start);
BOOST_LOG_TRIVIAL(debug) << "fill_columns_duration: " << fill_columns_duration.count() << "ms";

auto fill_table_start = std::chrono::high_resolution_clock::now();
auto filled_val = table_value_marshalling_type(std::make_tuple(
nil::marshalling::types::integral<TTypeBase, std::size_t>(witness_size),
nil::marshalling::types::integral<TTypeBase, std::size_t>(public_input_size + shared_size),
nil::marshalling::types::integral<TTypeBase, std::size_t>(constant_size),
nil::marshalling::types::integral<TTypeBase, std::size_t>(selector_size),
nil::marshalling::types::integral<TTypeBase, std::size_t>(usable_rows_amount),
nil::marshalling::types::integral<TTypeBase, std::size_t>(padded_rows_amount),
nil::crypto3::marshalling::types::fill_field_element_vector<typename AssignmentTableType::field_type::value_type, Endianness>(table_witness_values),
nil::crypto3::marshalling::types::fill_field_element_vector<typename AssignmentTableType::field_type::value_type, Endianness>(table_public_input_values),
nil::crypto3::marshalling::types::fill_field_element_vector<typename AssignmentTableType::field_type::value_type, Endianness>(table_constant_values),
nil::crypto3::marshalling::types::fill_field_element_vector<typename AssignmentTableType::field_type::value_type, Endianness>(table_selector_values)
));
auto fill_table_duration = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - fill_table_start);
BOOST_LOG_TRIVIAL(debug) << "fill_table_duration: " << fill_table_duration.count() << "ms";

auto clear_vectors_start = std::chrono::high_resolution_clock::now();
table_witness_values.clear();
table_witness_values.shrink_to_fit();

table_public_input_values.clear();
table_public_input_values.shrink_to_fit();

table_constant_values.clear();
table_constant_values.shrink_to_fit();

table_selector_values.clear();
table_selector_values.shrink_to_fit();
auto clear_vectors_duration = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - clear_vectors_start);
BOOST_LOG_TRIVIAL(debug) << "clear_vectors_duration: " << clear_vectors_duration.count() << "ms";


auto create_cv_vector_start = std::chrono::high_resolution_clock::now();
std::vector<std::uint8_t> cv;
cv.resize(filled_val.length(), 0x00);
auto create_cv_vector_duration = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - create_cv_vector_start);
BOOST_LOG_TRIVIAL(debug) << "create_cv_vector_duration: " << create_cv_vector_duration.count() << "ms";


auto filled_val_write_start = std::chrono::high_resolution_clock::now();
auto write_iter = cv.begin();
nil::marshalling::status_type status = filled_val.write(write_iter, cv.size());
auto filled_val_write_duration = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - filled_val_write_start);
BOOST_LOG_TRIVIAL(debug) << "filled_val_write_duration: " << filled_val_write_duration.count() << "ms";


auto out_write_start = std::chrono::high_resolution_clock::now();
out.write(reinterpret_cast<char*>(cv.data()), cv.size());
auto out_write_duration = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - out_write_start);
BOOST_LOG_TRIVIAL(debug) << "out_write_duration: " << out_write_duration.count() << "ms";
}

bool read_json(
Expand Down Expand Up @@ -466,14 +456,17 @@ void assignment_table_printer(
const std::size_t &ComponentConstantColumns,
const std::size_t &ComponentSelectorColumns
) {
BOOST_LOG_TRIVIAL(info) << "start thread " << idx;
std::ofstream otable;
otable.open(assignment_table_file_name + std::to_string(idx),
std::ios_base::binary | std::ios_base::out);
if (!otable) {
throw std::runtime_error("Failed to open file: " + assignment_table_file_name + std::to_string(idx));
}

print_assignment_table<nil::marshalling::option::big_endian, ArithmetizationType, BlueprintFieldType>(
assigner_instance.assignments[idx], print_table_kind::MULTI_PROVER, ComponentConstantColumns,
ComponentSelectorColumns, otable);

otable.close();
}

Expand Down

0 comments on commit 63bac68

Please sign in to comment.