Skip to content

Commit

Permalink
repo sync (#215)
Browse files Browse the repository at this point in the history
  • Loading branch information
huocun-ant authored Dec 11, 2024
1 parent 6e67a19 commit 8caf034
Show file tree
Hide file tree
Showing 12 changed files with 69 additions and 41 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ chmod +x traceconv
```
4. Open chrome://tracing in your chrome and load JSON file.


## PSI V2 Benchamrk

Please refer to [PSI V2 Benchmark](docs/user_guide/psi_v2_benchmark.md)
Expand Down
3 changes: 1 addition & 2 deletions benchmark/docker-compose/.env
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# OPENSOURCE-CLEANUP GSUB psi:latest secretflow/psi:latest
# docker env
IMAGE_WITH_TAG=secretflow/psi-anolis8:0.4.2b0
IMAGE_WITH_TAG=secretflow/psi:latest

# network env
# LATENCY=10ms
Expand Down
3 changes: 2 additions & 1 deletion benchmark/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import json
import csv
import sys
import os
import time
from datetime import datetime

Expand All @@ -40,7 +41,7 @@ def stream_container_stats(container_name, output_file):
data = json.loads(stats)
running_time_s = int(time.time()) - start_unix_time
cpu_percent = ((data['cpu_stats']['cpu_usage']['total_usage'] - prev_cpu_total) /
(data['cpu_stats']['system_cpu_usage'] - prev_cpu_system)) * 100
(data['cpu_stats']['system_cpu_usage'] - prev_cpu_system)) * 100 * os.cpu_count()
mem_usage = (data['memory_stats']['usage'] - data['memory_stats']['stats']['inactive_file']) / 1024 / 1024
mem_limit = data['memory_stats']['limit'] / 1024 / 1024
net_tx = 0
Expand Down
2 changes: 1 addition & 1 deletion docker/entry.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ cd src_copied

conda install -y perl=5.20.3.1

bazel build psi:main -c opt --config=linux-release --repository_cache=/tmp/bazel_repo_cache
bazel build psi:main -c opt --config=linux-release --remote_timeout=300s --remote_retries=10
chmod 777 bazel-bin/psi/main
mkdir -p ../src/docker/linux/amd64
cp bazel-bin/psi/main ../src/docker/linux/amd64
1 change: 1 addition & 0 deletions experiment/pir/pps/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ psi_cc_library(
deps = [
":ggm_pset",
"@yacl//yacl/base:dynamic_bitset",
"@yacl//yacl/base:exception",
"@yacl//yacl/crypto/rand",
"@yacl//yacl/crypto/tools:prg",
],
Expand Down
41 changes: 27 additions & 14 deletions experiment/pir/pps/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

#include <spdlog/spdlog.h>

#include "yacl/base/exception.h"

namespace pir::pps {

bool PpsPirClient::Bernoulli() {
Expand All @@ -34,17 +36,22 @@ uint64_t PpsPirClient::GetRandomU64Less() {
// Generate sk and m random numbers \in [n]
void PpsPirClient::Setup(PIRKey& sk, std::set<uint64_t>& deltas) {
sk = pps_.Gen(lambda_);

size_t max_try_count = 10 * M();
size_t count = 0;

// The map.size() must be equal to SET_SIZE.
std::vector<uint64_t> rand =
yacl::crypto::PrgAesCtr<uint64_t>(yacl::crypto::RandU64(), M());
for (uint64_t i = 0; i < M(); i++) {
// The most expensive operation.
uint64_t r = LemireTrick(rand[i], universe_size_);
size_t i = 0;
while (i < M() && count < max_try_count) {
count += 1;
uint64_t r = LemireTrick(yacl::crypto::RandU64(), universe_size_);
if (!deltas.insert(r).second) {
rand[i] = yacl::crypto::RandU64();
i--;
continue;
}
++i;
}

YACL_ENFORCE(count < max_try_count);
}

// Params:
Expand Down Expand Up @@ -91,18 +98,24 @@ void PpsPirClient::Setup(std::vector<PIRKeyUnion>& ck,
std::vector<std::unordered_set<uint64_t>>& v) {
ck.resize(MM());
v.resize(MM());
std::vector<uint128_t> rand =
yacl::crypto::PrgAesCtr<uint128_t>(yacl::crypto::RandU128(), MM());
for (uint64_t i = 0; i < MM(); ++i) {
pps_.Eval(rand[i], v[i]);

size_t max_try_count = 10 * MM();
size_t count = 0;

size_t i = 0;
while (i < MM() && count < max_try_count) {
count += 1;
auto rand = yacl::crypto::RandU128();
pps_.Eval(rand, v[i]);
if (v[i].size() == set_size_) {
ck[i] = PIRKeyUnion(rand[i]);
ck[i] = PIRKeyUnion(rand);
} else {
v[i].clear();
rand[i] = yacl::crypto::RandU128();
--i;
continue;
}
++i;
}
YACL_ENFORCE(count < max_try_count);
}

void PpsPirClient::Query(uint64_t i, std::vector<PIRKeyUnion>& ck,
Expand Down
22 changes: 11 additions & 11 deletions experiment/pir/pps/pps_pir_benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,15 @@ static void BM_PpsSingleBitPir(benchmark::State& state) {
pir::pps::PpsPirServer pirOfflineServer(n * n, n);
pir::pps::PpsPirServer pirOnlineServer(n * n, n);

pir::pps::PIRKey pirKey, pirKeyOffline;
pir::pps::PIRQueryParam pirQueryParam;
pir::pps::PIRPuncKey pirPuncKey, pirPuncKeyOnline;
std::set<uint64_t> deltas, deltasOffline;
pir::pps::PIRKey pirKey{}, pirKeyOffline{};
pir::pps::PIRQueryParam pirQueryParam{};
pir::pps::PIRPuncKey pirPuncKey{}, pirPuncKeyOnline{};
std::set<uint64_t> deltas{}, deltasOffline{};
yacl::dynamic_bitset<> bits;
GenerateRandomBitString(bits, n * n);
yacl::dynamic_bitset<> h, hOffline;
uint64_t query_index = pirClient.GetRandomU64Less();
bool query_result;
bool query_result{};

constexpr int kWorldSize = 2;
const auto contextsOffline = yacl::link::test::SetupWorld(kWorldSize);
Expand Down Expand Up @@ -102,7 +102,7 @@ static void BM_PpsSingleBitPir(benchmark::State& state) {
recver_future.get();

bool a = pirOnlineServer.Answer(pirPuncKeyOnline, bits);
bool aClient;
bool aClient{};

sender_future =
std::async(std::launch::async, pir::pps::OnlineServerSendToClient,
Expand All @@ -129,13 +129,13 @@ static void BM_PpsMultiBitsPir(benchmark::State& state) {
pir::pps::PpsPirServer pirOfflineServer(n * n, n);
pir::pps::PpsPirServer pirOnlineServer(n * n, n);

std::vector<pir::pps::PIRKeyUnion> pirKey, pirKeyOffline;
std::vector<pir::pps::PIRKeyUnion> pirKey{}, pirKeyOffline{};
yacl::dynamic_bitset<> bits;
GenerateRandomBitString(bits, n * n);
yacl::dynamic_bitset<> h, hOffline;
pir::pps::PIRQueryParam pirParam;
pir::pps::PIRQueryParam pirParam{};

bool aLeft, aRight, aLeftOnline, aRightOnline, queryResult;
bool aLeft{}, aRight{}, aLeftOnline{}, aRightOnline{}, queryResult{};
std::vector<std::unordered_set<uint64_t>> v;

constexpr int kWorldSize = 2;
Expand Down Expand Up @@ -170,8 +170,8 @@ static void BM_PpsMultiBitsPir(benchmark::State& state) {
recver_future.get();

for (uint i = 0; i < n * n; ++i) {
pir::pps::PIRPuncKey pirPuncKeyL, pirPuncKeyR;
pir::pps::PIRPuncKey pirPuncKeyLOnline, pirPuncKeyROnline;
pir::pps::PIRPuncKey pirPuncKeyL{}, pirPuncKeyR{};
pir::pps::PIRPuncKey pirPuncKeyLOnline{}, pirPuncKeyROnline{};

pirClient.Query(i, pirKey, v, pirParam, pirPuncKeyL, pirPuncKeyR);

Expand Down
2 changes: 1 addition & 1 deletion experiment/pir/pps/sender.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

namespace pir::pps {
std::array<std::byte, 16> Uint128_to_bytes(PIRKey sk) {
std::array<std::byte, 16> bytes;
std::array<std::byte, 16> bytes{};
uint64_t high = static_cast<uint64_t>(sk >> 64);
uint64_t low = static_cast<uint64_t>(sk & 0xFFFFFFFFFFFFFFFF);
std::memcpy(bytes.data(), &high, sizeof(high));
Expand Down
2 changes: 1 addition & 1 deletion psi/apsi_wrapper/api/receiver_c_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ Receiver* BucketReceiverMake(size_t bucket_cnt, size_t thread_count) {
}

void BucketReceiverFree(Receiver** receiver) {
if (receiver != nullptr || *receiver == nullptr) {
if (receiver == nullptr || *receiver == nullptr) {
return;
}
(void)std::unique_ptr<ApiReceiver>(reinterpret_cast<ApiReceiver*>(*receiver));
Expand Down
13 changes: 9 additions & 4 deletions psi/rr22/rr22_psi.h
Original file line number Diff line number Diff line change
Expand Up @@ -257,13 +257,18 @@ class Rr22Runner {
futures[i] = std::async(
std::launch::async,
[&](size_t thread_idx) {
std::shared_ptr<yacl::link::Context> spawn_read_lctx =
read_lctx_->Spawn(std::to_string(thread_idx));
std::shared_ptr<yacl::link::Context> spawn_run_lctx =
run_lctx_->Spawn(std::to_string(thread_idx));
std::shared_ptr<yacl::link::Context> spawn_intersection_lctx =
intersection_lctx_->Spawn(std::to_string(thread_idx));
for (size_t j = 0; j < bucket_num_; j++) {
if (j % parallel_num == thread_idx) {
auto runner = CreateBucketRunner(j, is_sender);
runner->Prepare(read_lctx_->Spawn(std::to_string(thread_idx)));
runner->RunOprf(run_lctx_->Spawn(std::to_string(thread_idx)));
runner->GetIntersection(
intersection_lctx_->Spawn(std::to_string(thread_idx)));
runner->Prepare(spawn_read_lctx);
runner->RunOprf(spawn_run_lctx);
runner->GetIntersection(spawn_intersection_lctx);
}
}
},
Expand Down
9 changes: 5 additions & 4 deletions psi/sealpir/seal_pir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ uint32_t ComputeExpansionRatio(seal::EncryptionParameters params) {
double logqi = log2(params.coeff_modulus()[i].value());
expansion_ratio += ceil(logqi / logt);
}
YACL_ENFORCE(expansion_ratio > 0, "expansion_ratio must be greater than 0");
return expansion_ratio;
}
uint64_t CoefficientsPerElement(uint32_t logt, uint64_t ele_size) {
Expand Down Expand Up @@ -169,7 +170,7 @@ vector<seal::Plaintext> DecomposeToPlaintexts(seal::EncryptionParameters params,
const auto N = params.poly_modulus_degree();
const auto coeff_mod_count = params.coeff_modulus().size();
const uint32_t logt = log2(params.plain_modulus().value());
const uint64_t pt_bitmask = (1 << logt) - 1;
const uint64_t pt_bitmask = (1ULL << logt) - 1;

vector<seal::Plaintext> result(ComputeExpansionRatio(params) * ct.size());
auto pt_iter = result.begin();
Expand Down Expand Up @@ -750,7 +751,7 @@ inline vector<Ciphertext> SealPirServer::ExpandQuery(

for (uint32_t i = 0; i < logm - 1; ++i) {
vector<Ciphertext> new_tmp(tmp.size() << 1);
int index_raw = (N << 1) - (1 << i);
int index_raw = (N << 1) - (1ULL << i);
int index = (index_raw + N) % (N << 1);
// int index = (index_raw * galelts[i]) % (N << 1);

Expand All @@ -768,13 +769,13 @@ inline vector<Ciphertext> SealPirServer::ExpandQuery(
}

vector<Ciphertext> new_tmp(tmp.size() << 1);
int index_raw = (N << 1) - (1 << (logm - 1));
int index_raw = (N << 1) - (1ULL << (logm - 1));
int index = (index_raw + N) % (N << 1);
// int index = (index_raw * galelts[logm - 1]) % (N << 1);
Plaintext two("2");

for (uint32_t j = 0; j < tmp.size(); ++j) {
if (j < (m - (1 << (logm - 1)))) {
if (j < (m - (1ULL << (logm - 1)))) {
evaluator_->apply_galois(tmp[j], galelts[logm - 1], galkey,
tmpctxt_rotated);
evaluator_->add(tmp[j], tmpctxt_rotated, new_tmp[j]);
Expand Down
11 changes: 9 additions & 2 deletions psi/utils/ub_psi_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <utility>
#include <vector>

#include "spdlog/spdlog.h"
#include "yacl/base/byte_container_view.h"

#include "psi/utils/batch_provider.h"
Expand Down Expand Up @@ -90,8 +91,14 @@ class UbPsiCache : public IUbPsiCache {
std::vector<uint8_t> private_key);

~UbPsiCache() {
Flush();
out_stream_->Close();
try {
Flush();
if (out_stream_) {
out_stream_->Close();
}
} catch (const std::exception& e) {
SPDLOG_ERROR("UbPsiCache flush failed: {}", e.what());
}
}

void SaveData(yacl::ByteContainerView item, size_t index,
Expand Down

0 comments on commit 8caf034

Please sign in to comment.