-
Notifications
You must be signed in to change notification settings - Fork 147
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(frontend): refactor key value database example to use modules, t…
…est and benchmark it
- Loading branch information
1 parent
8477f95
commit d3c5d64
Showing
10 changed files
with
952 additions
and
593 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,335 @@ | ||
""" | ||
Benchmarks of the static key value database example. | ||
""" | ||
|
||
import random | ||
from pathlib import Path | ||
|
||
import numpy as np | ||
import py_progress_tracker as progress | ||
|
||
from concrete import fhe | ||
from examples.key_value_database.static_size import StaticKeyValueDatabase | ||
|
||
|
||
def benchmark_insert(db: StaticKeyValueDatabase, client: fhe.Client, server: fhe.Server): | ||
""" | ||
Benchmark inserting an entry to the database. | ||
""" | ||
|
||
print("Warming up...") | ||
|
||
sample_key = random.randint(0, 2**db.key_size - 1) | ||
sample_value = random.randint(0, 2**db.value_size - 1) | ||
|
||
# Initial state only contains odd keys for benchmarks. | ||
# To avoid collisions, we'll make sure that sample_key is even. | ||
if sample_key % 2 == 1: | ||
sample_key -= 1 | ||
|
||
encoded_sample_key = db.encode_key(sample_key) | ||
encoded_sample_value = db.encode_value(sample_value) | ||
|
||
_, encrypted_sample_key, encrypted_sample_value = client.encrypt( # type: ignore | ||
None, | ||
encoded_sample_key, | ||
encoded_sample_value, | ||
function_name="insert", | ||
) | ||
ran = server.run( # noqa: F841 # pylint: disable=unused-variable | ||
db.state, | ||
encrypted_sample_key, | ||
encrypted_sample_value, | ||
function_name="insert", | ||
evaluation_keys=client.evaluation_keys, | ||
) | ||
|
||
for i in range(5): | ||
print(f"Running subsample {i + 1} out of 5...") | ||
|
||
sample_key = random.randint(0, 2**db.key_size - 1) | ||
sample_value = random.randint(0, 2**db.value_size - 1) | ||
|
||
if sample_key % 2 == 1: | ||
sample_key -= 1 | ||
|
||
encoded_sample_key = db.encode_key(sample_key) | ||
encoded_sample_value = db.encode_value(sample_value) | ||
|
||
_, encrypted_sample_key, encrypted_sample_value = client.encrypt( # type: ignore | ||
None, | ||
encoded_sample_key, | ||
encoded_sample_value, | ||
function_name="insert", | ||
) | ||
|
||
with progress.measure(id="evaluation-time-ms", label="Evaluation Time (ms)"): | ||
ran = server.run( # noqa: F841 | ||
db.state, | ||
encrypted_sample_key, | ||
encrypted_sample_value, | ||
function_name="insert", | ||
evaluation_keys=client.evaluation_keys, | ||
) | ||
|
||
|
||
def benchmark_replace(db: StaticKeyValueDatabase, client: fhe.Client, server: fhe.Server): | ||
""" | ||
Benchmark replacing an entry in the database. | ||
""" | ||
|
||
print("Warming up...") | ||
|
||
sample_key = random.randint(0, db.number_of_entries // 2) * 2 | ||
sample_value = random.randint(0, db.number_of_entries // 2) * 2 | ||
|
||
# Initial state only contains odd keys for benchmarks. | ||
# To actually replace, we'll make sure that sample_key is odd. | ||
if sample_key % 2 == 0: | ||
sample_key += 1 | ||
|
||
encoded_sample_key = db.encode_key(sample_key) | ||
encoded_sample_value = db.encode_value(sample_value) | ||
|
||
_, encrypted_sample_key, encrypted_sample_value = client.encrypt( # type: ignore | ||
None, | ||
encoded_sample_key, | ||
encoded_sample_value, | ||
function_name="replace", | ||
) | ||
ran = server.run( # noqa: F841 # pylint: disable=unused-variable | ||
db.state, | ||
encrypted_sample_key, | ||
encrypted_sample_value, | ||
function_name="replace", | ||
evaluation_keys=client.evaluation_keys, | ||
) | ||
|
||
for i in range(5): | ||
print(f"Running subsample {i + 1} out of 5...") | ||
|
||
sample_key = random.randint(0, db.number_of_entries - 1) | ||
sample_value = random.randint(0, db.number_of_entries - 1) | ||
|
||
if sample_key % 2 == 0: | ||
sample_key += 1 | ||
|
||
encoded_sample_key = db.encode_key(sample_key) | ||
encoded_sample_value = db.encode_value(sample_value) | ||
|
||
_, encrypted_sample_key, encrypted_sample_value = client.encrypt( # type: ignore | ||
None, | ||
encoded_sample_key, | ||
encoded_sample_value, | ||
function_name="replace", | ||
) | ||
|
||
with progress.measure(id="evaluation-time-ms", label="Evaluation Time (ms)"): | ||
ran = server.run( # noqa: F841 | ||
db.state, | ||
encrypted_sample_key, | ||
encrypted_sample_value, | ||
function_name="replace", | ||
evaluation_keys=client.evaluation_keys, | ||
) | ||
|
||
|
||
def benchmark_query(db: StaticKeyValueDatabase, client: fhe.Client, server: fhe.Server): | ||
""" | ||
Benchmark querying a key in the database. | ||
""" | ||
|
||
print("Warming up...") | ||
|
||
sample_key = random.randint(0, db.number_of_entries - 1) | ||
encoded_sample_key = db.encode_key(sample_key) | ||
|
||
_, encrypted_sample_key = client.encrypt( # type: ignore | ||
None, | ||
encoded_sample_key, | ||
function_name="query", | ||
) | ||
ran = server.run( # noqa: F841 # pylint: disable=unused-variable | ||
db.state, | ||
encrypted_sample_key, | ||
function_name="query", | ||
evaluation_keys=client.evaluation_keys, | ||
) | ||
|
||
for i in range(5): | ||
print(f"Running subsample {i + 1} out of 5...") | ||
|
||
sample_key = random.randint(0, db.number_of_entries - 1) | ||
encoded_sample_key = db.encode_key(sample_key) | ||
|
||
_, encrypted_sample_key = client.encrypt( # type: ignore | ||
None, | ||
encoded_sample_key, | ||
function_name="query", | ||
) | ||
with progress.measure(id="evaluation-time-ms", label="Evaluation Time (ms)"): | ||
ran = server.run( # noqa: F841 | ||
db.state, | ||
encrypted_sample_key, | ||
function_name="query", | ||
evaluation_keys=client.evaluation_keys, | ||
) | ||
|
||
|
||
def targets(): | ||
""" | ||
Generates targets to benchmark. | ||
""" | ||
|
||
result = [] | ||
for number_of_entries in [8, 16]: | ||
for key_size in [8, 16]: | ||
for value_size in [8, 16]: | ||
for chunk_size in [2, 4]: | ||
result.append( | ||
{ | ||
"id": ( | ||
f"static-kvdb-insert :: " | ||
f"{number_of_entries} * {key_size}->{value_size} ^ {chunk_size}" | ||
), | ||
"name": ( | ||
f"Insertion to " | ||
f"static key-value database " | ||
f"from {key_size}b to {value_size}b " | ||
f"with chunk size of {chunk_size} " | ||
f"on {number_of_entries} entries" | ||
), | ||
"parameters": { | ||
"operation": "insert", | ||
"number_of_entries": number_of_entries, | ||
"key_size": key_size, | ||
"value_size": value_size, | ||
"chunk_size": chunk_size, | ||
}, | ||
} | ||
) | ||
result.append( | ||
{ | ||
"id": ( | ||
f"static-kvdb-replace :: " | ||
f"{number_of_entries} * {key_size}->{value_size} ^ {chunk_size}" | ||
), | ||
"name": ( | ||
f"Replacement in " | ||
f"static key-value database " | ||
f"from {key_size}b to {value_size}b " | ||
f"with chunk size of {chunk_size} " | ||
f"on {number_of_entries} entries" | ||
), | ||
"parameters": { | ||
"operation": "replace", | ||
"number_of_entries": number_of_entries, | ||
"key_size": key_size, | ||
"value_size": value_size, | ||
"chunk_size": chunk_size, | ||
}, | ||
} | ||
) | ||
result.append( | ||
{ | ||
"id": ( | ||
f"static-kvdb-query :: " | ||
f"{number_of_entries} * {key_size}->{value_size} ^ {chunk_size}" | ||
), | ||
"name": ( | ||
f"Query of " | ||
f"static key-value database " | ||
f"from {key_size}b to {value_size}b " | ||
f"with chunk size of {chunk_size} " | ||
f"on {number_of_entries} entries" | ||
), | ||
"parameters": { | ||
"operation": "query", | ||
"number_of_entries": number_of_entries, | ||
"key_size": key_size, | ||
"value_size": value_size, | ||
"chunk_size": chunk_size, | ||
}, | ||
} | ||
) | ||
return result | ||
|
||
|
||
@progress.track(targets()) | ||
def main(operation, number_of_entries, key_size, value_size, chunk_size): | ||
""" | ||
Benchmark a target. | ||
Args: | ||
operation: | ||
operation to benchmark | ||
number_of_entries: | ||
size of the database | ||
key_size: | ||
size of the keys of the database | ||
value_size: | ||
size of the values of the database | ||
chunk_size: | ||
chunks size of the database | ||
""" | ||
|
||
print("Compiling...") | ||
cached_server = Path( | ||
f"static_kvdb.{number_of_entries}.{key_size}.{value_size}.{chunk_size}.server.zip" | ||
) | ||
if cached_server.exists(): | ||
db = StaticKeyValueDatabase( | ||
number_of_entries, | ||
key_size, | ||
value_size, | ||
chunk_size, | ||
compiled=False, | ||
) | ||
server = fhe.Server.load(cached_server) | ||
client = fhe.Client(server.client_specs, keyset_cache_directory=".keys") | ||
else: | ||
db = StaticKeyValueDatabase( | ||
number_of_entries, | ||
key_size, | ||
value_size, | ||
chunk_size, | ||
compiled=True, | ||
configuration=fhe.Configuration( | ||
enable_unsafe_features=True, | ||
use_insecure_key_cache=True, | ||
insecure_key_cache_location=".keys", | ||
), | ||
) | ||
db.module.server.save(cached_server) | ||
|
||
server = db.module.server | ||
client = db.module.client | ||
|
||
db.state = server.run( | ||
client.encrypt( | ||
[ | ||
np.array([1] + db.encode_key(i).tolist() + db.encode_value(i).tolist()) * (i % 2) | ||
for i in range(db.number_of_entries) | ||
], | ||
function_name="reset", | ||
), | ||
function_name="reset", | ||
evaluation_keys=client.evaluation_keys, | ||
) | ||
|
||
print("Generating keys...") | ||
client.keygen() | ||
|
||
if operation == "insert": | ||
benchmark_insert(db, client, server) | ||
elif operation == "replace": | ||
benchmark_replace(db, client, server) | ||
elif operation == "query": | ||
benchmark_query(db, client, server) | ||
else: | ||
message = f"Invalid operation: {operation}" | ||
raise ValueError(message) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.