-
Notifications
You must be signed in to change notification settings - Fork 147
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(frontend): refactor key value database example to use modules an…
…d benchmark it
- Loading branch information
1 parent
a3cc2f1
commit ace5a45
Showing
5 changed files
with
686 additions
and
317 deletions.
There are no files selected for viewing
270 changes: 270 additions & 0 deletions
270
frontends/concrete-python/benchmarks/key_value_database/static_size.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,270 @@ | ||
import random | ||
|
||
import py_progress_tracker as progress | ||
|
||
from examples.key_value_database.static_size import StaticKeyValueDatabase | ||
|
||
targets = [] | ||
|
||
for number_of_entries in [10, 100, 1000]: | ||
for key_size in [8, 16]: | ||
for value_size in [8, 16]: | ||
for chunk_size in [2, 4]: | ||
targets.append( | ||
{ | ||
"id": ( | ||
f"static-kvdb-insert :: " | ||
f"{number_of_entries} * {key_size}->{value_size} ^ {chunk_size}" | ||
), | ||
"name": ( | ||
f"Insertion to " | ||
f"static key-value database " | ||
f"from {key_size}b to {value_size}b " | ||
f"with chunk size of {chunk_size} " | ||
f"on {number_of_entries} entries" | ||
), | ||
"parameters": { | ||
"operation": "insert", | ||
"number_of_entries": number_of_entries, | ||
"key_size": key_size, | ||
"value_size": value_size, | ||
"chunk_size": chunk_size, | ||
}, | ||
} | ||
) | ||
targets.append( | ||
{ | ||
"id": ( | ||
f"static-kvdb-replace :: " | ||
f"{number_of_entries} * {key_size}->{value_size} ^ {chunk_size}" | ||
), | ||
"name": ( | ||
f"Replacement in " | ||
f"static key-value database " | ||
f"from {key_size}b to {value_size}b " | ||
f"with chunk size of {chunk_size} " | ||
f"on {number_of_entries} entries" | ||
), | ||
"parameters": { | ||
"operation": "replace", | ||
"number_of_entries": number_of_entries, | ||
"key_size": key_size, | ||
"value_size": value_size, | ||
"chunk_size": chunk_size, | ||
}, | ||
} | ||
) | ||
targets.append( | ||
{ | ||
"id": ( | ||
f"static-kvdb-query :: " | ||
f"{number_of_entries} * {key_size}->{value_size} ^ {chunk_size}" | ||
), | ||
"name": ( | ||
f"Query of " | ||
f"static key-value database " | ||
f"from {key_size}b to {value_size}b " | ||
f"with chunk size of {chunk_size} " | ||
f"on {number_of_entries} entries" | ||
), | ||
"parameters": { | ||
"operation": "query", | ||
"number_of_entries": number_of_entries, | ||
"key_size": key_size, | ||
"value_size": value_size, | ||
"chunk_size": chunk_size, | ||
}, | ||
} | ||
) | ||
|
||
|
||
def benchmark_insert(db: StaticKeyValueDatabase): | ||
db.initialize() | ||
|
||
print("Warming up...") | ||
|
||
sample_key = random.randint(0, 2**db.key_size - 1) | ||
sample_value = random.randint(0, 2**db.value_size - 1) | ||
|
||
encoded_sample_key = db.encode_key(sample_key) | ||
encoded_sample_value = db.encode_value(sample_value) | ||
|
||
_, encrypted_sample_key, encrypted_sample_value = db.insert.encrypt( | ||
None, encoded_sample_key, encoded_sample_value | ||
) | ||
ran = db.insert.run(db.state, encrypted_sample_key, encrypted_sample_value) | ||
decrypted = db.insert.decrypt(ran) # noqa: F841 | ||
|
||
def calculate_input_output_size(input_output): | ||
if isinstance(input_output, tuple): | ||
result = sum(len(value.serialize()) for value in input_output) | ||
else: | ||
result = len(input_output.serialize()) | ||
return result / (1024 * 1024) | ||
|
||
progress.measure( | ||
id="input-ciphertext-size-mb", | ||
label="Input Ciphertext Size (MB)", | ||
value=calculate_input_output_size((encrypted_sample_key, encrypted_sample_value)), | ||
) | ||
|
||
for i in range(5): | ||
print(f"Running subsample {i + 1} out of 5...") | ||
|
||
sample_key = random.randint(0, 2**db.key_size - 1) | ||
sample_value = random.randint(0, 2**db.value_size - 1) | ||
|
||
encoded_sample_key = db.encode_key(sample_key) | ||
encoded_sample_value = db.encode_value(sample_value) | ||
|
||
with progress.measure(id="encryption-time-ms", label="Encryption Time (ms)"): | ||
_, encrypted_sample_key, encrypted_sample_value = db.insert.encrypt( | ||
None, | ||
encoded_sample_key, | ||
encoded_sample_value, | ||
) | ||
|
||
with progress.measure(id="evaluation-time-ms", label="Evaluation Time (ms)"): | ||
ran = db.insert.run( # noqa: F841 | ||
db.state, | ||
encrypted_sample_key, | ||
encrypted_sample_value, | ||
) | ||
|
||
|
||
def benchmark_replace(db: StaticKeyValueDatabase): | ||
db.initialize( | ||
[ | ||
[1] + db.encode_key(i).tolist() + db.encode_value(i).tolist() | ||
for i in range(db.number_of_entries) | ||
] | ||
) | ||
|
||
print("Warming up...") | ||
|
||
sample_key = random.randint(0, db.number_of_entries - 1) | ||
sample_value = random.randint(0, db.number_of_entries - 1) | ||
|
||
encoded_sample_key = db.encode_key(sample_key) | ||
encoded_sample_value = db.encode_value(sample_value) | ||
|
||
_, encrypted_sample_key, encrypted_sample_value = db.replace.encrypt( | ||
None, encoded_sample_key, encoded_sample_value | ||
) | ||
ran = db.replace.run(db.state, encrypted_sample_key, encrypted_sample_value) | ||
decrypted = db.replace.decrypt(ran) # noqa: F841 | ||
|
||
def calculate_input_output_size(input_output): | ||
if isinstance(input_output, tuple): | ||
result = sum(len(value.serialize()) for value in input_output) | ||
else: | ||
result = len(input_output.serialize()) | ||
return result / (1024 * 1024) | ||
|
||
progress.measure( | ||
id="input-ciphertext-size-mb", | ||
label="Input Ciphertext Size (MB)", | ||
value=calculate_input_output_size((encrypted_sample_key, encrypted_sample_value)), | ||
) | ||
|
||
for i in range(5): | ||
print(f"Running subsample {i + 1} out of 5...") | ||
|
||
sample_key = random.randint(0, db.number_of_entries - 1) | ||
sample_value = random.randint(0, db.number_of_entries - 1) | ||
|
||
encoded_sample_key = db.encode_key(sample_key) | ||
encoded_sample_value = db.encode_value(sample_value) | ||
|
||
with progress.measure(id="encryption-time-ms", label="Encryption Time (ms)"): | ||
_, encrypted_sample_key, encrypted_sample_value = db.replace.encrypt( | ||
None, | ||
encoded_sample_key, | ||
encoded_sample_value, | ||
) | ||
|
||
with progress.measure(id="evaluation-time-ms", label="Evaluation Time (ms)"): | ||
ran = db.replace.run( # noqa: F841 | ||
db.state, | ||
encrypted_sample_key, | ||
encrypted_sample_value, | ||
) | ||
|
||
|
||
def benchmark_query(db: StaticKeyValueDatabase): | ||
db.initialize( | ||
[ | ||
[1] + db.encode_key(i).tolist() + db.encode_value(i).tolist() | ||
for i in range(db.number_of_entries) | ||
] | ||
) | ||
|
||
print("Warming up...") | ||
|
||
sample_key = random.randint(0, db.number_of_entries - 1) | ||
encoded_sample_key = db.encode_key(sample_key) | ||
_, encrypted_sample_key = db.query.encrypt(None, encoded_sample_key) | ||
ran = db.query.run(db.state, encrypted_sample_key) | ||
decrypted = db.query.decrypt(ran) # noqa: F841 | ||
|
||
def calculate_input_output_size(input_output): | ||
if isinstance(input_output, tuple): | ||
result = sum(len(value.serialize()) for value in input_output) | ||
else: | ||
result = len(input_output.serialize()) | ||
return result / (1024 * 1024) | ||
|
||
progress.measure( | ||
id="input-ciphertext-size-mb", | ||
label="Input Ciphertext Size (MB)", | ||
value=calculate_input_output_size(encrypted_sample_key), | ||
) | ||
progress.measure( | ||
id="output-ciphertext-size-mb", | ||
label="Output Ciphertext Size (MB)", | ||
value=calculate_input_output_size(ran), | ||
) | ||
|
||
for i in range(5): | ||
print(f"Running subsample {i + 1} out of 5...") | ||
|
||
sample_key = random.randint(0, db.number_of_entries - 1) | ||
encoded_sample_key = db.encode_key(sample_key) | ||
|
||
with progress.measure(id="encryption-time-ms", label="Encryption Time (ms)"): | ||
_, encrypted_sample_key = db.query.encrypt(None, encoded_sample_key) | ||
|
||
with progress.measure(id="evaluation-time-ms", label="Evaluation Time (ms)"): | ||
ran = db.query.run(db.state, encrypted_sample_key) # noqa: F841 | ||
|
||
|
||
@progress.track(targets) | ||
def main(operation, number_of_entries, key_size, value_size, chunk_size): | ||
print("Compiling...") | ||
with progress.measure(id="compilation-time-ms", label="Compilation Time (ms)"): | ||
db = StaticKeyValueDatabase(number_of_entries, key_size, value_size, chunk_size) | ||
|
||
progress.measure( | ||
id="complexity", | ||
label="Complexity", | ||
value=db.module.complexity, | ||
) | ||
|
||
print("Generating keys...") | ||
with progress.measure(id="key-generation-time-ms", label="Key Generation Time (ms)"): | ||
db.keygen(force=True) | ||
|
||
progress.measure( | ||
id="evaluation-key-size-mb", | ||
label="Evaluation Key Size (MB)", | ||
value=(len(db.module.keys.evaluation.serialize()) / (1024 * 1024)), | ||
) | ||
|
||
if operation == "insert": | ||
benchmark_insert(db) | ||
elif operation == "replace": | ||
benchmark_replace(db) | ||
elif operation == "query": | ||
benchmark_query(db) | ||
else: | ||
raise ValueError(f"Invalid operation: {operation}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
Oops, something went wrong.