Skip to content

Commit

Permalink
feat(frontend): refactor key value database example to use modules an…
Browse files Browse the repository at this point in the history
…d benchmark it
  • Loading branch information
umut-sahin committed Aug 6, 2024
1 parent a3cc2f1 commit ace5a45
Show file tree
Hide file tree
Showing 5 changed files with 686 additions and 317 deletions.
270 changes: 270 additions & 0 deletions frontends/concrete-python/benchmarks/key_value_database/static_size.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
import random

import py_progress_tracker as progress

from examples.key_value_database.static_size import StaticKeyValueDatabase

# Benchmark target definitions: one entry per combination of database
# configuration (entry count, key/value bit sizes, chunk size) and operation.
# Consumed by the `@progress.track(targets)` decorator on `main` below.
targets = []

# (value of the "operation" parameter, human-readable name prefix)
# Order matters: it fixes the order of entries in `targets`.
_OPERATIONS = [
    ("insert", "Insertion to"),
    ("replace", "Replacement in"),
    ("query", "Query of"),
]

for number_of_entries in [10, 100, 1000]:
    for key_size in [8, 16]:
        for value_size in [8, 16]:
            for chunk_size in [2, 4]:
                for operation, name_prefix in _OPERATIONS:
                    targets.append(
                        {
                            "id": (
                                f"static-kvdb-{operation} :: "
                                f"{number_of_entries} * {key_size}->{value_size} ^ {chunk_size}"
                            ),
                            "name": (
                                f"{name_prefix} "
                                f"static key-value database "
                                f"from {key_size}b to {value_size}b "
                                f"with chunk size of {chunk_size} "
                                f"on {number_of_entries} entries"
                            ),
                            "parameters": {
                                "operation": operation,
                                "number_of_entries": number_of_entries,
                                "key_size": key_size,
                                "value_size": value_size,
                                "chunk_size": chunk_size,
                            },
                        }
                    )


def benchmark_insert(db: StaticKeyValueDatabase):
    """Benchmark the `insert` operation on an empty database.

    Performs one warm-up round (encrypt, run, decrypt), records the input
    ciphertext size, then times encryption and evaluation over five
    randomly-sampled key/value pairs.
    """
    db.initialize()

    print("Warming up...")

    key = random.randint(0, 2**db.key_size - 1)
    value = random.randint(0, 2**db.value_size - 1)

    encoded_key = db.encode_key(key)
    encoded_value = db.encode_value(value)

    _, encrypted_key, encrypted_value = db.insert.encrypt(None, encoded_key, encoded_value)
    ran = db.insert.run(db.state, encrypted_key, encrypted_value)
    decrypted = db.insert.decrypt(ran)  # noqa: F841

    def size_in_mb(ciphertexts):
        # A tuple holds several ciphertexts; sum their serialized byte lengths.
        if isinstance(ciphertexts, tuple):
            total = sum(len(ciphertext.serialize()) for ciphertext in ciphertexts)
        else:
            total = len(ciphertexts.serialize())
        return total / (1024 * 1024)

    progress.measure(
        id="input-ciphertext-size-mb",
        label="Input Ciphertext Size (MB)",
        value=size_in_mb((encrypted_key, encrypted_value)),
    )

    for subsample in range(5):
        print(f"Running subsample {subsample + 1} out of 5...")

        key = random.randint(0, 2**db.key_size - 1)
        value = random.randint(0, 2**db.value_size - 1)

        encoded_key = db.encode_key(key)
        encoded_value = db.encode_value(value)

        with progress.measure(id="encryption-time-ms", label="Encryption Time (ms)"):
            _, encrypted_key, encrypted_value = db.insert.encrypt(
                None,
                encoded_key,
                encoded_value,
            )

        with progress.measure(id="evaluation-time-ms", label="Evaluation Time (ms)"):
            ran = db.insert.run(  # noqa: F841
                db.state,
                encrypted_key,
                encrypted_value,
            )


def benchmark_replace(db: StaticKeyValueDatabase):
    """Benchmark the `replace` operation on a fully-populated database.

    Initializes every slot (flag 1, key i, value i), performs one warm-up
    round, records the input ciphertext size, then times encryption and
    evaluation over five randomly-sampled key/value pairs.
    """
    initial_state = []
    for entry in range(db.number_of_entries):
        # Each row is [used-flag] + encoded key chunks + encoded value chunks.
        initial_state.append([1] + db.encode_key(entry).tolist() + db.encode_value(entry).tolist())
    db.initialize(initial_state)

    print("Warming up...")

    key = random.randint(0, db.number_of_entries - 1)
    value = random.randint(0, db.number_of_entries - 1)

    encoded_key = db.encode_key(key)
    encoded_value = db.encode_value(value)

    _, encrypted_key, encrypted_value = db.replace.encrypt(None, encoded_key, encoded_value)
    ran = db.replace.run(db.state, encrypted_key, encrypted_value)
    decrypted = db.replace.decrypt(ran)  # noqa: F841

    def size_in_mb(ciphertexts):
        # A tuple holds several ciphertexts; sum their serialized byte lengths.
        if isinstance(ciphertexts, tuple):
            total = sum(len(ciphertext.serialize()) for ciphertext in ciphertexts)
        else:
            total = len(ciphertexts.serialize())
        return total / (1024 * 1024)

    progress.measure(
        id="input-ciphertext-size-mb",
        label="Input Ciphertext Size (MB)",
        value=size_in_mb((encrypted_key, encrypted_value)),
    )

    for subsample in range(5):
        print(f"Running subsample {subsample + 1} out of 5...")

        key = random.randint(0, db.number_of_entries - 1)
        value = random.randint(0, db.number_of_entries - 1)

        encoded_key = db.encode_key(key)
        encoded_value = db.encode_value(value)

        with progress.measure(id="encryption-time-ms", label="Encryption Time (ms)"):
            _, encrypted_key, encrypted_value = db.replace.encrypt(
                None,
                encoded_key,
                encoded_value,
            )

        with progress.measure(id="evaluation-time-ms", label="Evaluation Time (ms)"):
            ran = db.replace.run(  # noqa: F841
                db.state,
                encrypted_key,
                encrypted_value,
            )


def benchmark_query(db: StaticKeyValueDatabase):
    """Benchmark the `query` operation on a fully-populated database.

    Initializes every slot (flag 1, key i, value i), performs one warm-up
    round, records input AND output ciphertext sizes, then times encryption
    and evaluation over five randomly-sampled keys.
    """
    initial_state = []
    for entry in range(db.number_of_entries):
        # Each row is [used-flag] + encoded key chunks + encoded value chunks.
        initial_state.append([1] + db.encode_key(entry).tolist() + db.encode_value(entry).tolist())
    db.initialize(initial_state)

    print("Warming up...")

    key = random.randint(0, db.number_of_entries - 1)
    encoded_key = db.encode_key(key)
    _, encrypted_key = db.query.encrypt(None, encoded_key)
    ran = db.query.run(db.state, encrypted_key)
    decrypted = db.query.decrypt(ran)  # noqa: F841

    def size_in_mb(ciphertexts):
        # A tuple holds several ciphertexts; sum their serialized byte lengths.
        if isinstance(ciphertexts, tuple):
            total = sum(len(ciphertext.serialize()) for ciphertext in ciphertexts)
        else:
            total = len(ciphertexts.serialize())
        return total / (1024 * 1024)

    progress.measure(
        id="input-ciphertext-size-mb",
        label="Input Ciphertext Size (MB)",
        value=size_in_mb(encrypted_key),
    )
    progress.measure(
        id="output-ciphertext-size-mb",
        label="Output Ciphertext Size (MB)",
        value=size_in_mb(ran),
    )

    for subsample in range(5):
        print(f"Running subsample {subsample + 1} out of 5...")

        key = random.randint(0, db.number_of_entries - 1)
        encoded_key = db.encode_key(key)

        with progress.measure(id="encryption-time-ms", label="Encryption Time (ms)"):
            _, encrypted_key = db.query.encrypt(None, encoded_key)

        with progress.measure(id="evaluation-time-ms", label="Evaluation Time (ms)"):
            ran = db.query.run(db.state, encrypted_key)  # noqa: F841


@progress.track(targets)
def main(operation, number_of_entries, key_size, value_size, chunk_size):
    """Compile, key-generate, and benchmark one database configuration.

    Invoked once per entry in `targets`; the parameters mirror the target's
    `parameters` dict. Records compilation time, circuit complexity, key
    generation time, and evaluation key size before dispatching to the
    operation-specific benchmark.
    """
    print("Compiling...")
    with progress.measure(id="compilation-time-ms", label="Compilation Time (ms)"):
        db = StaticKeyValueDatabase(number_of_entries, key_size, value_size, chunk_size)

    progress.measure(
        id="complexity",
        label="Complexity",
        value=db.module.complexity,
    )

    print("Generating keys...")
    with progress.measure(id="key-generation-time-ms", label="Key Generation Time (ms)"):
        db.keygen(force=True)

    progress.measure(
        id="evaluation-key-size-mb",
        label="Evaluation Key Size (MB)",
        value=(len(db.module.keys.evaluation.serialize()) / (1024 * 1024)),
    )

    # Dispatch to the benchmark matching the requested operation.
    benchmarks = {
        "insert": benchmark_insert,
        "replace": benchmark_replace,
        "query": benchmark_query,
    }
    if operation not in benchmarks:
        raise ValueError(f"Invalid operation: {operation}")
    benchmarks[operation](db)
4 changes: 4 additions & 0 deletions frontends/concrete-python/concrete/fhe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
DebugArtifacts,
EncryptionStatus,
Exactness,
)
from .compilation import FheFunction as Function
from .compilation import FheModule as Module
from .compilation import (
FunctionDebugArtifacts,
Input,
Keys,
Expand Down
Loading

0 comments on commit ace5a45

Please sign in to comment.