Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: Improve Code Readability #2

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions benches/bench_client_prepare_query.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ template<size_t λ, size_t db_entry_count, size_t db_entry_byte_len, size_t mat_
static void
bench_client_prepare_query(benchmark::State& state)
{
using server_t = frodoPIR_server::server_t<λ, db_entry_count, db_entry_byte_len, mat_element_bitlen, lwe_dimension>;
using client_t = frodoPIR_client::client_t<λ, db_entry_count, db_entry_byte_len, mat_element_bitlen, lwe_dimension>;

constexpr size_t db_byte_len = db_entry_count * db_entry_byte_len;
constexpr size_t parsed_db_column_count = frodoPIR_matrix::get_required_num_columns(db_entry_byte_len, mat_element_bitlen);
constexpr size_t pub_matM_byte_len = frodoPIR_matrix::matrix_t<lwe_dimension, parsed_db_column_count>::get_byte_len();
Expand All @@ -33,10 +36,10 @@ bench_client_prepare_query(benchmark::State& state)
prng.read(seed_μ_span);
prng.read(db_bytes_span);

auto [server, M] = frodoPIR_server::server_t<λ, db_entry_count, db_entry_byte_len, mat_element_bitlen, lwe_dimension>::setup(seed_μ_span, db_bytes_span);
auto [server, M] = server_t::setup(seed_μ_span, db_bytes_span);

M.to_le_bytes(pub_matM_bytes_span);
auto client = frodoPIR_client::client_t<λ, db_entry_count, db_entry_byte_len, mat_element_bitlen, lwe_dimension>::setup(seed_μ, pub_matM_bytes_span);
auto client = client_t::setup(seed_μ, pub_matM_bytes_span);

size_t rand_db_row_index = [&]() {
size_t buffer = 0;
Expand Down
7 changes: 5 additions & 2 deletions benches/bench_client_process_response.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ template<size_t λ, size_t db_entry_count, size_t db_entry_byte_len, size_t mat_
static void
bench_client_process_response(benchmark::State& state)
{
using server_t = frodoPIR_server::server_t<λ, db_entry_count, db_entry_byte_len, mat_element_bitlen, lwe_dimension>;
using client_t = frodoPIR_client::client_t<λ, db_entry_count, db_entry_byte_len, mat_element_bitlen, lwe_dimension>;

constexpr size_t db_byte_len = db_entry_count * db_entry_byte_len;
constexpr size_t parsed_db_column_count = frodoPIR_matrix::get_required_num_columns(db_entry_byte_len, mat_element_bitlen);
constexpr size_t pub_matM_byte_len = frodoPIR_matrix::matrix_t<lwe_dimension, parsed_db_column_count>::get_byte_len();
Expand All @@ -33,10 +36,10 @@ bench_client_process_response(benchmark::State& state)
prng.read(seed_μ_span);
prng.read(db_bytes_span);

auto [server, M] = frodoPIR_server::server_t<λ, db_entry_count, db_entry_byte_len, mat_element_bitlen, lwe_dimension>::setup(seed_μ_span, db_bytes_span);
auto [server, M] = server_t::setup(seed_μ_span, db_bytes_span);

M.to_le_bytes(pub_matM_bytes_span);
auto client = frodoPIR_client::client_t<λ, db_entry_count, db_entry_byte_len, mat_element_bitlen, lwe_dimension>::setup(seed_μ, pub_matM_bytes_span);
auto client = client_t::setup(seed_μ, pub_matM_bytes_span);

size_t rand_db_row_index = [&]() {
size_t buffer = 0;
Expand Down
7 changes: 5 additions & 2 deletions benches/bench_client_query.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ template<size_t λ, size_t db_entry_count, size_t db_entry_byte_len, size_t mat_
static void
bench_client_query(benchmark::State& state)
{
using server_t = frodoPIR_server::server_t<λ, db_entry_count, db_entry_byte_len, mat_element_bitlen, lwe_dimension>;
using client_t = frodoPIR_client::client_t<λ, db_entry_count, db_entry_byte_len, mat_element_bitlen, lwe_dimension>;

constexpr size_t db_byte_len = db_entry_count * db_entry_byte_len;
constexpr size_t parsed_db_column_count = frodoPIR_matrix::get_required_num_columns(db_entry_byte_len, mat_element_bitlen);
constexpr size_t pub_matM_byte_len = frodoPIR_matrix::matrix_t<lwe_dimension, parsed_db_column_count>::get_byte_len();
Expand All @@ -33,10 +36,10 @@ bench_client_query(benchmark::State& state)
prng.read(seed_μ_span);
prng.read(db_bytes_span);

auto [server, M] = frodoPIR_server::server_t<λ, db_entry_count, db_entry_byte_len, mat_element_bitlen, lwe_dimension>::setup(seed_μ_span, db_bytes_span);
auto [server, M] = server_t::setup(seed_μ_span, db_bytes_span);

M.to_le_bytes(pub_matM_bytes_span);
auto client = frodoPIR_client::client_t<λ, db_entry_count, db_entry_byte_len, mat_element_bitlen, lwe_dimension>::setup(seed_μ, pub_matM_bytes_span);
auto client = client_t::setup(seed_μ, pub_matM_bytes_span);

size_t rand_db_row_index = [&]() {
size_t buffer = 0;
Expand Down
7 changes: 5 additions & 2 deletions benches/bench_client_setup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ template<size_t λ, size_t db_entry_count, size_t db_entry_byte_len, size_t mat_
static void
bench_client_setup(benchmark::State& state)
{
using server_t = frodoPIR_server::server_t<λ, db_entry_count, db_entry_byte_len, mat_element_bitlen, lwe_dimension>;
using client_t = frodoPIR_client::client_t<λ, db_entry_count, db_entry_byte_len, mat_element_bitlen, lwe_dimension>;

constexpr size_t db_byte_len = db_entry_count * db_entry_byte_len;
constexpr size_t parsed_db_column_count = frodoPIR_matrix::get_required_num_columns(db_entry_byte_len, mat_element_bitlen);
constexpr size_t pub_matM_byte_len = frodoPIR_matrix::matrix_t<lwe_dimension, parsed_db_column_count>::get_byte_len();
Expand All @@ -25,12 +28,12 @@ bench_client_setup(benchmark::State& state)
prng.read(seed_μ_span);
prng.read(db_bytes_span);

auto [server, M] = frodoPIR_server::server_t<λ, db_entry_count, db_entry_byte_len, mat_element_bitlen, lwe_dimension>::setup(seed_μ_span, db_bytes_span);
auto [server, M] = server_t::setup(seed_μ_span, db_bytes_span);

M.to_le_bytes(pub_matM_bytes_span);

for (auto _ : state) {
auto client = frodoPIR_client::client_t<λ, db_entry_count, db_entry_byte_len, mat_element_bitlen, lwe_dimension>::setup(seed_μ, pub_matM_bytes_span);
auto client = client_t::setup(seed_μ, pub_matM_bytes_span);

benchmark::DoNotOptimize(client);
benchmark::DoNotOptimize(seed_μ);
Expand Down
7 changes: 5 additions & 2 deletions benches/bench_server_respond.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ template<size_t λ, size_t db_entry_count, size_t db_entry_byte_len, size_t mat_
static void
bench_server_respond(benchmark::State& state)
{
using server_t = frodoPIR_server::server_t<λ, db_entry_count, db_entry_byte_len, mat_element_bitlen, lwe_dimension>;
using client_t = frodoPIR_client::client_t<λ, db_entry_count, db_entry_byte_len, mat_element_bitlen, lwe_dimension>;

constexpr size_t db_byte_len = db_entry_count * db_entry_byte_len;
constexpr size_t parsed_db_column_count = frodoPIR_matrix::get_required_num_columns(db_entry_byte_len, mat_element_bitlen);
constexpr size_t pub_matM_byte_len = frodoPIR_matrix::matrix_t<lwe_dimension, parsed_db_column_count>::get_byte_len();
Expand All @@ -31,10 +34,10 @@ bench_server_respond(benchmark::State& state)
prng.read(seed_μ_span);
prng.read(db_bytes_span);

auto [server, M] = frodoPIR_server::server_t<λ, db_entry_count, db_entry_byte_len, mat_element_bitlen, lwe_dimension>::setup(seed_μ_span, db_bytes_span);
auto [server, M] = server_t::setup(seed_μ_span, db_bytes_span);

M.to_le_bytes(pub_matM_bytes_span);
auto client = frodoPIR_client::client_t<λ, db_entry_count, db_entry_byte_len, mat_element_bitlen, lwe_dimension>::setup(seed_μ, pub_matM_bytes_span);
auto client = client_t::setup(seed_μ, pub_matM_bytes_span);

const size_t rand_db_row_index = [&]() {
size_t buffer = 0;
Expand Down
4 changes: 3 additions & 1 deletion benches/bench_server_setup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ template<size_t λ, size_t db_entry_count, size_t db_entry_byte_len, size_t mat_
static void
bench_server_setup(benchmark::State& state)
{
using server_t = frodoPIR_server::server_t<λ, db_entry_count, db_entry_byte_len, mat_element_bitlen, lwe_dimension>;

constexpr size_t db_byte_len = db_entry_count * db_entry_byte_len;

std::array<uint8_t, λ / std::numeric_limits<uint8_t>::digits> seed_μ{};
Expand All @@ -20,7 +22,7 @@ bench_server_setup(benchmark::State& state)
prng.read(db_bytes_span);

for (auto _ : state) {
auto [server, M] = frodoPIR_server::server_t<λ, db_entry_count, db_entry_byte_len, mat_element_bitlen, lwe_dimension>::setup(seed_μ_span, db_bytes_span);
auto [server, M] = server_t::setup(seed_μ_span, db_bytes_span);

benchmark::DoNotOptimize(seed_μ_span);
benchmark::DoNotOptimize(db_bytes_span);
Expand Down
59 changes: 31 additions & 28 deletions include/frodoPIR/client.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,22 @@ template<size_t λ, size_t db_entry_count, size_t db_entry_byte_len, size_t mat_
struct client_t
{
public:
// Compile-time computable values.
static constexpr auto parsed_db_num_cols = frodoPIR_matrix::get_required_num_columns(db_entry_byte_len, mat_element_bitlen);
static constexpr auto pub_mat_M_byte_len = lwe_dimension * parsed_db_num_cols * sizeof(frodoPIR_matrix::zq_t);
static constexpr auto query_byte_len = db_entry_count * sizeof(frodoPIR_matrix::zq_t);
static constexpr auto response_byte_len = parsed_db_num_cols * sizeof(frodoPIR_matrix::zq_t);

// Type aliases.
using pub_mat_A_t = frodoPIR_matrix::matrix_t<lwe_dimension, db_entry_count>;
using pub_mat_M_t = frodoPIR_matrix::matrix_t<lwe_dimension, parsed_db_num_cols>;
using secret_vec_t = frodoPIR_vector::row_vector_t<lwe_dimension>;
using error_vec_t = frodoPIR_vector::row_vector_t<db_entry_count>;
using query_t = client_query_t<db_entry_count, db_entry_byte_len, mat_element_bitlen>;
using response_t = frodoPIR_vector::row_vector_t<parsed_db_num_cols>;

// Constructor(s)
explicit constexpr client_t(
frodoPIR_matrix::matrix_t<lwe_dimension, db_entry_count> pub_matA,
frodoPIR_matrix::matrix_t<lwe_dimension, frodoPIR_matrix::get_required_num_columns(db_entry_byte_len, mat_element_bitlen)> pub_matM)
explicit constexpr client_t(auto pub_matA, auto pub_matM)
: A(std::move(pub_matA))
, M(std::move(pub_matM))
{
Expand All @@ -53,15 +65,10 @@ struct client_t

// Given a `λ` -bit seed and a byte serialized public matrix M, computed by frodoPIR server, this routine can be used
// for setting up FrodoPIR client, ready to generate queries and process server response.
static forceinline constexpr client_t setup(
std::span<const uint8_t, λ / std::numeric_limits<uint8_t>::digits> seed_μ,
std::span<const uint8_t, lwe_dimension * frodoPIR_matrix::get_required_num_columns(db_entry_byte_len, mat_element_bitlen) * sizeof(frodoPIR_matrix::zq_t)>
pub_matM_bytes)
static forceinline constexpr client_t setup(std::span<const uint8_t, λ / std::numeric_limits<uint8_t>::digits> seed_μ,
std::span<const uint8_t, pub_mat_M_byte_len> pub_matM_bytes)
{
constexpr auto cols = frodoPIR_matrix::get_required_num_columns(db_entry_byte_len, mat_element_bitlen);

return client_t(frodoPIR_matrix::matrix_t<lwe_dimension, db_entry_count>::template generate<λ>(seed_μ),
frodoPIR_matrix::matrix_t<lwe_dimension, cols>::from_le_bytes(pub_matM_bytes));
return client_t(pub_mat_A_t::template generate<λ>(seed_μ), pub_mat_M_t::from_le_bytes(pub_matM_bytes));
}

// Given `n` -many database row indices, this routine prepares `n` -many queries, for enquiring their values,
Expand Down Expand Up @@ -90,13 +97,13 @@ struct client_t
return false;
}

const auto s = frodoPIR_vector::row_vector_t<lwe_dimension>::sample_from_uniform_ternary_distribution(prng); // secret vector
const auto e = frodoPIR_vector::row_vector_t<db_entry_count>::sample_from_uniform_ternary_distribution(prng); // error vector
const auto s = secret_vec_t::sample_from_uniform_ternary_distribution(prng); // secret vector
const auto e = error_vec_t::sample_from_uniform_ternary_distribution(prng); // error vector

const auto b = s * this->A + e;
const auto c = s * this->M;

this->queries[db_row_index] = client_query_t<db_entry_count, db_entry_byte_len, mat_element_bitlen>{
this->queries[db_row_index] = query_t{
.status = query_status_t::prepared,
.db_index = db_row_index,
.b = b,
Expand All @@ -112,8 +119,7 @@ struct client_t
//
// (a) Query is not yet prepared for requested database row index.
// (b) Query is already sent to server for requested database row index.
[[nodiscard("Must use status of query finalization")]] constexpr bool query(const size_t db_row_index,
std::span<uint8_t, db_entry_count * sizeof(frodoPIR_matrix::zq_t)> query_bytes)
[[nodiscard("Must use status of query finalization")]] constexpr bool query(const size_t db_row_index, std::span<uint8_t, query_byte_len> query_bytes)
{
if (!this->queries.contains(db_row_index)) {
return false;
Expand All @@ -139,10 +145,9 @@ struct client_t
//
// (a) Query is not yet prepared for requested database row index.
// (b) Query has not yet been sent to server, so can't process response for it.
[[nodiscard("Must use status of response decoding")]] constexpr bool process_response(
const size_t db_row_index,
std::span<const uint8_t, frodoPIR_matrix::get_required_num_columns(db_entry_byte_len, mat_element_bitlen) * sizeof(frodoPIR_matrix::zq_t)> response_bytes,
std::span<uint8_t, db_entry_byte_len> db_row_bytes)
[[nodiscard("Must use status of response decoding")]] constexpr bool process_response(const size_t db_row_index,
std::span<const uint8_t, response_byte_len> response_bytes,
std::span<uint8_t, db_entry_byte_len> db_row_bytes)
{
if (!this->queries.contains(db_row_index)) {
return false;
Expand All @@ -155,12 +160,10 @@ struct client_t
constexpr auto rounding_factor = static_cast<frodoPIR_matrix::zq_t>(frodoPIR_matrix::Q / rho);
constexpr auto rounding_floor = rounding_factor / 2;

constexpr size_t db_matrix_row_width = frodoPIR_matrix::get_required_num_columns(db_entry_byte_len, mat_element_bitlen);

frodoPIR_vector::row_vector_t<db_matrix_row_width> db_matrix_row{};
auto c_tilda = frodoPIR_vector::row_vector_t<db_matrix_row_width>::from_le_bytes(response_bytes);
response_t db_matrix_row{};
auto c_tilda = response_t::from_le_bytes(response_bytes);

for (size_t idx = 0; idx < db_matrix_row_width; idx++) {
for (size_t idx = 0; idx < parsed_db_num_cols; idx++) {
const auto unscaled_res = c_tilda[idx] - this->queries[db_row_index].c[idx];

const auto scaled_res = unscaled_res / rounding_factor;
Expand All @@ -181,9 +184,9 @@ struct client_t
}

private:
frodoPIR_matrix::matrix_t<lwe_dimension, db_entry_count> A{};
frodoPIR_matrix::matrix_t<lwe_dimension, frodoPIR_matrix::get_required_num_columns(db_entry_byte_len, mat_element_bitlen)> M{};
std::unordered_map<size_t, client_query_t<db_entry_count, db_entry_byte_len, mat_element_bitlen>> queries{};
pub_mat_A_t A{};
pub_mat_M_t M{};
std::unordered_map<size_t, query_t> queries{};
};

}
55 changes: 53 additions & 2 deletions include/frodoPIR/internals/matrix/matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,33 @@ struct matrix_t
requires(std::endian::native == std::endian::little)
{
auto elements_ptr = reinterpret_cast<const uint8_t*>(this->elements.data());
memcpy(bytes.data(), elements_ptr, bytes.size());
auto bytes_ptr = reinterpret_cast<uint8_t*>(bytes.data());

constexpr size_t min_num_threads = 1;
const size_t hw_hinted_max_num_threads = std::thread::hardware_concurrency();
const size_t spawnable_num_threads = std::max(min_num_threads, hw_hinted_max_num_threads);

constexpr size_t total_num_bytes = bytes.size();
const size_t num_bytes_per_thread = total_num_bytes / spawnable_num_threads;
const size_t num_bytes_distributed = num_bytes_per_thread * spawnable_num_threads;
const size_t remaining_num_bytes = total_num_bytes - num_bytes_distributed;

std::vector<std::thread> threads;
threads.reserve(spawnable_num_threads);

for (size_t t_idx = 0; t_idx < spawnable_num_threads; t_idx++) {
const size_t byte_idx_begin = t_idx * num_bytes_per_thread;

auto thread = std::thread([=, &elements_ptr, &bytes_ptr]() { memcpy(bytes_ptr + byte_idx_begin, elements_ptr + byte_idx_begin, num_bytes_per_thread); });
threads.push_back(std::move(thread));
}

if (remaining_num_bytes > 0) {
const size_t final_thread_byte_idx_begin = num_bytes_distributed;
memcpy(bytes_ptr + final_thread_byte_idx_begin, elements_ptr + final_thread_byte_idx_begin, remaining_num_bytes);
}

std::ranges::for_each(threads, [](auto& handle) { handle.join(); });
}

// Given a byte array of length `rows * cols * 4`, this routine can be used for deserializing it as a matrix of dimension
Expand All @@ -329,8 +355,33 @@ struct matrix_t
matrix_t res{};

auto elements_ptr = reinterpret_cast<uint8_t*>(res.elements.data());
memcpy(elements_ptr, bytes.data(), bytes.size());
auto bytes_ptr = reinterpret_cast<const uint8_t*>(bytes.data());

constexpr size_t min_num_threads = 1;
const size_t hw_hinted_max_num_threads = std::thread::hardware_concurrency();
const size_t spawnable_num_threads = std::max(min_num_threads, hw_hinted_max_num_threads);

constexpr size_t total_num_bytes = bytes.size();
const size_t num_bytes_per_thread = total_num_bytes / spawnable_num_threads;
const size_t num_bytes_distributed = num_bytes_per_thread * spawnable_num_threads;
const size_t remaining_num_bytes = total_num_bytes - num_bytes_distributed;

std::vector<std::thread> threads;
threads.reserve(spawnable_num_threads);

for (size_t t_idx = 0; t_idx < spawnable_num_threads; t_idx++) {
const size_t byte_idx_begin = t_idx * num_bytes_per_thread;

auto thread = std::thread([=, &elements_ptr, &bytes_ptr]() { memcpy(elements_ptr + byte_idx_begin, bytes_ptr + byte_idx_begin, num_bytes_per_thread); });
threads.push_back(std::move(thread));
}

if (remaining_num_bytes > 0) {
const size_t final_thread_byte_idx_begin = num_bytes_distributed;
memcpy(elements_ptr + final_thread_byte_idx_begin, bytes_ptr + final_thread_byte_idx_begin, remaining_num_bytes);
}

std::ranges::for_each(threads, [](auto& handle) { handle.join(); });
return res;
}

Expand Down
Loading