Skip to content

Commit

Permalink
Improvedlpn (#15)
Browse files Browse the repository at this point in the history
* improved lpnb

* update

* update perf

* update perf

* reduce to on thread in test

Co-authored-by: Ubuntu <[email protected]>
  • Loading branch information
wangxiao1254 and Ubuntu authored Sep 5, 2021
1 parent 15ff0e2 commit 7190dc4
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 65 deletions.
15 changes: 7 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,16 @@ All values are for "million gates per second".
##### Boolean circuits
|Threads|10 Mbps|20 Mbps|30 Mbps|50 Mbps|Localhost|
|-------|-------|-------|-------|-------|---------|
|1|4.4|6.2|7.0|7.5|7.6|
|2|5.3|8.1|9.9|11.8|11.8|
|3|5.7|9.1|11.4|13.9|14.3|
|4|5.8|9.9|12.2|14.9|15.8|
|1|5.1|7.8|8.6|8.6|8.6|
|2|6|10|12.9|14.3|13.6|
|3|6.3|10.9|14.5|17.3|18|
|4|6.4|11.4|15.1|19|19.4|
##### Arithmetic circuits
|Threads|100 Mbps|500 Mbps|1 Gbps|2 Gbps|Localhost|
|-------|-------|-------|-------|-------|---------|
|1|1.2|3.4|4.2|4.8|4.8|
|2|1.3|4.4|6.1|7.0|7.1|
|3|1.4|4.9|7.2|8.4|8.4|
|4|1.4|5.0|7.5|8.9|8.9|
|1|1.4|4.8|6.8|7.8|7.8|
|2|1.4|5.6|8.7|10.2|10.4|
|3|1.4|5.9|9.3|11.7|12.5|


Question
Expand Down
152 changes: 96 additions & 56 deletions emp-zk/emp-vole/lpn.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,108 +12,148 @@ class LpnFp { public:
int threads;
block seed;

int round, leftover;

__uint128_t *M;
const __uint128_t *preM, *prex;
__uint128_t *K;
const __uint128_t *preK;

uint32_t k_mask;

LpnFp (int n, int k, ThreadPool * pool, int threads, block seed = zero_block) {
this->k = k;
this->n = n;
this->pool = pool;
this->threads = threads;
this->seed = seed;

round = d / 4;
leftover = d % 4;

this->k_mask = k_mask_gen(k);
k_mask = 1;
while(k_mask < (uint32_t)k) {
k_mask <<=1;
k_mask = k_mask | 0x1;
}
}

uint32_t k_mask_gen(int kin) {
int ksz = kin;
int sz = 0;
while(ksz > 1) {
sz++;
ksz = ksz>>1;
}
return (1<<sz)-1;
void add2_single(int idx1, int* idx2) {
block Midx1 = (block)M[idx1];
for(int j = 0; j < 5; ++j)
Midx1 = _mm_add_epi64(Midx1, (block)preM[idx2[j]]);
Midx1 = vec_mod(Midx1);
for(int j = 5; j < 10; ++j)
Midx1 = _mm_add_epi64(Midx1, (block)preM[idx2[j]]);
M[idx1] = (__uint128_t)vec_mod(Midx1);
}
void add1_single(int idx1, int* idx2) {
uint64_t Kidx1 = K[idx1];
for(int j = 0; j < 5; ++j)
Kidx1 = Kidx1 + preK[idx2[j]];
Kidx1 = mod(Kidx1);
for(int j = 5; j < 10; ++j)
Kidx1 = Kidx1 + preK[idx2[j]];
K[idx1] = mod(Kidx1);
}

void add2(int idx1, int* idx2, uint64_t* mult) {
__uint128_t res[2], valM[2];
void add2(int idx1, int* idx2) {
block tmp[4];
tmp[0] = (block)M[idx1];
tmp[1] = (block)M[idx1+1];
tmp[2] = (block)M[idx1+2];
tmp[3] = (block)M[idx1+3];
int * p = idx2;
for(int j = 0; j < 5; ++j) {
valM[0] = preM[idx2[2*j]];
valM[1] = preM[idx2[2*j+1]];
mult_mod_bch2((block*)res, (block*)valM, mult+2*j);
M[idx1] = (__uint128_t)add_mod((block)M[idx1], (block)res[0]);
M[idx1] = (__uint128_t)add_mod((block)M[idx1], (block)res[1]);
tmp[0] = _mm_add_epi64((block)tmp[0], (block)preM[*(p++)]);
tmp[1] = _mm_add_epi64((block)tmp[1], (block)preM[*(p++)]);
tmp[2] = _mm_add_epi64((block)tmp[2], (block)preM[*(p++)]);
tmp[3] = _mm_add_epi64((block)tmp[3], (block)preM[*(p++)]);
}
}

void add1(int idx1, int* idx2, uint64_t* mult) {
uint64_t res[2], valK[2];
tmp[0] = vec_mod(tmp[0]);
tmp[1] = vec_mod(tmp[1]);
tmp[2] = vec_mod(tmp[2]);
tmp[3] = vec_mod(tmp[3]);
for(int j = 5; j < 10; ++j) {
tmp[0] = _mm_add_epi64((block)tmp[0], (block)preM[*(p++)]);
tmp[1] = _mm_add_epi64((block)tmp[1], (block)preM[*(p++)]);
tmp[2] = _mm_add_epi64((block)tmp[2], (block)preM[*(p++)]);
tmp[3] = _mm_add_epi64((block)tmp[3], (block)preM[*(p++)]);
}
M[idx1] = (__uint128_t)vec_mod(tmp[0]);
M[idx1+1] = (__uint128_t)vec_mod(tmp[1]);
M[idx1+2] = (__uint128_t)vec_mod(tmp[2]);
M[idx1+3] = (__uint128_t)vec_mod(tmp[3]);
}

void add1(int idx1, int* idx2) {
uint64_t tmp[4];
tmp[0] = 0;
tmp[1] = 0;
tmp[2] = 0;
tmp[3] = 0;
int * p = idx2;
for(int j = 0; j < 5; ++j) {
valK[0] = preK[idx2[2*j]];
valK[1] = preK[idx2[2*j+1]];
mult_mod_bch2(res, valK, mult+2*j);
K[idx1] = (__uint128_t)add_mod(K[idx1], res[0]);
K[idx1] = (__uint128_t)add_mod(K[idx1], res[1]);
tmp[0] += preK[*(p++)];
tmp[1] += preK[*(p++)];
tmp[2] += preK[*(p++)];
tmp[3] += preK[*(p++)];
}
tmp[0] = mod(tmp[0]);
tmp[1] = mod(tmp[1]);
tmp[2] = mod(tmp[2]);
tmp[3] = mod(tmp[3]);
for(int j = 5; j < 10; ++j) {
tmp[0] += preK[*(p++)];
tmp[1] += preK[*(p++)];
tmp[2] += preK[*(p++)];
tmp[3] += preK[*(p++)];
}
K[idx1] = mod(K[idx1] + tmp[0]);
K[idx1+1] = mod(K[idx1+1] + tmp[1]);
K[idx1+2] = mod(K[idx1+2] + tmp[2]);
K[idx1+3] = mod(K[idx1+3] + tmp[3]);
}

void __compute4(int i, PRP *prp, std::function<void(int, int*, uint64_t*)> add_func) {
block tmp[30];
for(int m = 0; m < 30; ++m)

void __compute4(int i, PRP *prp, std::function<void(int, int*)> add_func) {
block tmp[10];
for(int m = 0; m < 10; ++m)
tmp[m] = makeBlock(i, m);
prp->permute_block(tmp, 30);
uint32_t* r = (uint32_t*)(tmp);
uint64_t* mult = (uint64_t*)(tmp+10);
for(int m = 0; m < 4; ++m) {
int index[d];
for (int j = 0; j < d; ++j) {
index[j] = r[m*d+j]&k_mask;
mult[m*d+j] = mod(mult[m*d+j]);
}
add_func(i+m, index, mult+m*d);
prp->permute_block(tmp, 10);
int* index = (int*)(tmp);
for(int j = 0; j < 4*d; ++j) {
index[j] = index[j]&k_mask;
index[j] = index[j] >= k? index[j]-k:index[j];
}
add_func(i, index);
}

void __compute1(int i, PRP *prp, std::function<void(int, int*, uint64_t*)> add_func) {
block tmp[8];
for(int m = 0; m < 8; ++m)
void __compute1(int i, PRP *prp, std::function<void(int, int*)> add_func) {
block tmp[3];
for(int m = 0; m < 3; ++m)
tmp[m] = makeBlock(i, m);
prp->permute_block(tmp, 8);
prp->permute_block(tmp, 3);
uint32_t* r = (uint32_t*)(tmp);
uint64_t* mult = (uint64_t*)(tmp+3);

int index[d];
for (int j = 0; j < d; ++j) {
index[j] = r[j]&k_mask;
mult[j] = mod(mult[j]);
index[j] = index[j] >= k? index[j]-k:index[j];
}
add_func(i, index, mult);
add_func(i, index);
}

void task(int start, int end) {
PRP prp(seed);
int j = start;
if(party == 1) {
std::function<void(int, int*, uint64_t*)> add_func1 = std::bind(&LpnFp::add1, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3);
std::function<void(int, int*)> add_func1 = std::bind(&LpnFp::add1, this, std::placeholders::_1, std::placeholders::_2);
std::function<void(int, int*)> add_func1s = std::bind(&LpnFp::add1_single, this, std::placeholders::_1, std::placeholders::_2);
for(; j < end-4; j+=4)
__compute4(j, &prp, add_func1);
for(; j < end; ++j)
__compute1(j, &prp, add_func1);
__compute1(j, &prp, add_func1s);
} else {
std::function<void(int, int*, uint64_t*)> add_func2 = std::bind(&LpnFp::add2, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3);
std::function<void(int, int*)> add_func2 = std::bind(&LpnFp::add2, this, std::placeholders::_1, std::placeholders::_2);
std::function<void(int, int*)> add_func2s = std::bind(&LpnFp::add2_single, this, std::placeholders::_1, std::placeholders::_2);
for(; j < end-4; j+=4)
__compute4(j, &prp, add_func2);
for(; j < end; ++j)
__compute1(j, &prp, add_func2);
__compute1(j, &prp, add_func2s);
}
}

Expand Down
3 changes: 2 additions & 1 deletion test/vole/lpn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ void test_lpn(NetIO *io, int party) {
Delta = Delta & ((__uint128_t)0xFFFFFFFFFFFFFFFFLL);
Delta = mod(Delta, pr);

//test cases reduced for github action
int test_n = 1016832/2;
int test_k = 15800;
int test_k = 158000/10;
__uint128_t *mac1 = new __uint128_t[test_n];
__uint128_t *mac2 = new __uint128_t[test_k];

Expand Down

0 comments on commit 7190dc4

Please sign in to comment.