forked from microsoft/SealPIR
-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.cpp
133 lines (109 loc) · 5.31 KB
/
main.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#include "pir.hpp"
#include "pir_client.hpp"
#include "pir_server.hpp"
#include <seal/seal.h>
#include <chrono>
#include <memory>
#include <random>
#include <cstdint>
#include <cstddef>
using namespace std::chrono;
using namespace std;
using namespace seal;
int main(int argc, char *argv[]) {
uint64_t number_of_items = 1 << 10;
uint64_t size_per_item = 288; // in bytes
uint32_t N = 4096;
// Recommended values: (logt, d) = (12, 2) or (8, 1).
uint32_t logt = 16;
uint32_t d = 2;
EncryptionParameters params(scheme_type::BFV);
PirParams pir_params;
// Generates all parameters
cout << "Main: Generating all parameters" << endl;
gen_params(number_of_items, size_per_item, N, logt, d, params, pir_params);
auto context = SEALContext::Create(params, false);
if (!context->parameters_set()) {
cout << "Main: failed to set encryption parameters: "
<< context->parameter_error_message() << endl;
}
cout << "Main: Initializing the database (this may take some time) ..." << endl;
// Create test database
auto db(make_unique<uint8_t[]>(number_of_items * size_per_item));
// Copy of the database. We use this at the end to make sure we retrieved
// the correct element.
auto db_copy(make_unique<uint8_t[]>(number_of_items * size_per_item));
random_device rd;
for (uint64_t i = 0; i < number_of_items; i++) {
for (uint64_t j = 0; j < size_per_item; j++) {
auto val = rd() % 256;
db.get()[(i * size_per_item) + j] = val;
db_copy.get()[(i * size_per_item) + j] = val;
}
}
// Initialize PIR Server
cout << "Main: Initializing server" << endl;
PIRServer server(params, pir_params);
// Initialize PIR client....
cout << "Main: Initializing client" << endl;
PIRClient client(params, pir_params);
cout << "Main: Generating Galois Keys" << endl;
GaloisKeys galois_keys = client.generate_galois_keys();
// Set galois key for client with id 0
cout << "Main: Setting Galois keys..." << endl;
server.set_galois_key(0, galois_keys);
// Measure database setup
cout << "Main: pre processing database... " << endl;
auto time_pre_s = high_resolution_clock::now();
server.set_database(move(db), number_of_items, size_per_item);
server.preprocess_database();
auto time_pre_e = high_resolution_clock::now();
auto time_pre_us = duration_cast<microseconds>(time_pre_e - time_pre_s).count();
cout << "Main: database pre processed " << endl;
// Choose an index of an element in the DB
uint64_t ele_index = rd() % number_of_items; // element in DB at random position
uint64_t index = client.get_fv_index(ele_index, size_per_item); // index of FV plaintext
uint64_t offset = client.get_fv_offset(ele_index, size_per_item); // offset in FV plaintext
cout << "Main: element index = " << ele_index << " from [0, " << number_of_items -1 << "]" << endl;
cout << "Main: FV index = " << index << ", FV offset = " << offset << endl;
// Measure query generation
auto time_query_s = high_resolution_clock::now();
PirQuery query = client.generate_query(index);
auto time_query_e = high_resolution_clock::now();
auto time_query_us = duration_cast<microseconds>(time_query_e - time_query_s).count();
cout << "Main: query generated" << endl;
//To marshall query to send over the network, you can use serialize/deserialize:
//std::string query_ser = serialize_query(query);
//PirQuery query2 = deserialize_query(d, 1, query_ser, CIPHER_SIZE);
// Measure query processing (including expansion)
auto time_server_s = high_resolution_clock::now();
PirReply reply = server.generate_reply(query, 0);
auto time_server_e = high_resolution_clock::now();
auto time_server_us = duration_cast<microseconds>(time_server_e - time_server_s).count();
// Measure response extraction
auto time_decode_s = chrono::high_resolution_clock::now();
Plaintext result = client.decode_reply(reply);
auto time_decode_e = chrono::high_resolution_clock::now();
auto time_decode_us = duration_cast<microseconds>(time_decode_e - time_decode_s).count();
// Convert from FV plaintext (polynomial) to database element at the client
vector<uint8_t> elems(N * logt / 8);
coeffs_to_bytes(logt, result, elems.data(), (N * logt) / 8);
// Check that we retrieved the correct element
for (uint32_t i = 0; i < size_per_item; i++) {
if (elems[(offset * size_per_item) + i] != db_copy.get()[(ele_index * size_per_item) + i]) {
cout << "Main: elems " << (int)elems[(offset * size_per_item) + i] << ", db "
<< (int) db_copy.get()[(ele_index * size_per_item) + i] << endl;
cout << "Main: PIR result wrong!" << endl;
return -1;
}
}
// Output results
cout << "Main: PIR result correct!" << endl;
cout << "Main: PIRServer pre-processing time: " << time_pre_us / 1000 << " ms" << endl;
cout << "Main: PIRClient query generation time: " << time_query_us / 1000 << " ms" << endl;
cout << "Main: PIRServer reply generation time: " << time_server_us / 1000 << " ms"
<< endl;
cout << "Main: PIRClient answer decode time: " << time_decode_us / 1000 << " ms" << endl;
cout << "Main: Reply num ciphertexts: " << reply.size() << endl;
return 0;
}