From d3c5d64c54746716ebcc80e075ed595801af983f Mon Sep 17 00:00:00 2001 From: Umut Date: Mon, 5 Aug 2024 11:39:27 +0300 Subject: [PATCH] feat(frontend): refactor key value database example to use modules, test and benchmark it --- frontends/concrete-python/.pylintrc | 1 + frontends/concrete-python/.ruff.toml | 2 +- frontends/concrete-python/Makefile | 2 +- .../concrete-python/benchmarks/static_kvdb.py | 335 ++++++++++++++ .../concrete-python/concrete/fhe/__init__.py | 4 + .../key_value_database/dynamic-size.py | 268 ----------- .../key_value_database.ipynb | 7 +- .../key_value_database/static-size.py | 317 ------------- .../key_value_database/static_size.py | 426 ++++++++++++++++++ .../tests/execution/test_examples.py | 183 ++++++++ 10 files changed, 952 insertions(+), 593 deletions(-) create mode 100644 frontends/concrete-python/benchmarks/static_kvdb.py delete mode 100644 frontends/concrete-python/examples/key_value_database/dynamic-size.py delete mode 100644 frontends/concrete-python/examples/key_value_database/static-size.py create mode 100644 frontends/concrete-python/examples/key_value_database/static_size.py create mode 100644 frontends/concrete-python/tests/execution/test_examples.py diff --git a/frontends/concrete-python/.pylintrc b/frontends/concrete-python/.pylintrc index bd89d34277..e9f7524b8a 100644 --- a/frontends/concrete-python/.pylintrc +++ b/frontends/concrete-python/.pylintrc @@ -193,6 +193,7 @@ good-names=i, xs, on, of, + db, _ # Good variable names regexes, separated by a comma. If names match any regex, diff --git a/frontends/concrete-python/.ruff.toml b/frontends/concrete-python/.ruff.toml index 24124f7d96..c43ae0d0f6 100644 --- a/frontends/concrete-python/.ruff.toml +++ b/frontends/concrete-python/.ruff.toml @@ -9,7 +9,7 @@ ignore = [ "A", "D", "FBT", "T20", "ANN", "N806", "ARG001", "S101", "BLE001", "RUF100", "ERA001", "SIM105", "RET504", "TID252", "PD011", "I001", "UP015", "C901", "A001", "SIM118", "PGH003", "PLW2901", "PLR0915", "C416", "PLR0911", "PLR0912", "PLR0913", "RUF005", "PLR2004", "S110", "PLC1901", - "E731", "RET507", "SIM102" + "E731", "RET507", "SIM102", "N805", ] [per-file-ignores] diff --git a/frontends/concrete-python/Makefile b/frontends/concrete-python/Makefile index 52607561fe..a6a5526f4f 100644 --- a/frontends/concrete-python/Makefile +++ b/frontends/concrete-python/Makefile @@ -162,7 +162,7 @@ check-sanitize-notebooks: mypy: eval $(shell make silent_cp_activate) - mypy concrete examples scripts tests benchmarks --ignore-missing-imports + mypy concrete examples scripts tests benchmarks --ignore-missing-imports --explicit-package-bases pydocstyle: eval $(shell make silent_cp_activate) diff --git a/frontends/concrete-python/benchmarks/static_kvdb.py b/frontends/concrete-python/benchmarks/static_kvdb.py new file mode 100644 index 0000000000..1ea6a111e9 --- /dev/null +++ b/frontends/concrete-python/benchmarks/static_kvdb.py @@ -0,0 +1,335 @@ +""" +Benchmarks of the static key value database example. +""" + +import random +from pathlib import Path + +import numpy as np +import py_progress_tracker as progress + +from concrete import fhe +from examples.key_value_database.static_size import StaticKeyValueDatabase + + +def benchmark_insert(db: StaticKeyValueDatabase, client: fhe.Client, server: fhe.Server): + """ + Benchmark inserting an entry to the database. + """ + + print("Warming up...") + + sample_key = random.randint(0, 2**db.key_size - 1) + sample_value = random.randint(0, 2**db.value_size - 1) + + # Initial state only contains odd keys for benchmarks. + # To avoid collisions, we'll make sure that sample_key is even. + if sample_key % 2 == 1: + sample_key -= 1 + + encoded_sample_key = db.encode_key(sample_key) + encoded_sample_value = db.encode_value(sample_value) + + _, encrypted_sample_key, encrypted_sample_value = client.encrypt( # type: ignore + None, + encoded_sample_key, + encoded_sample_value, + function_name="insert", + ) + ran = server.run( # noqa: F841 # pylint: disable=unused-variable + db.state, + encrypted_sample_key, + encrypted_sample_value, + function_name="insert", + evaluation_keys=client.evaluation_keys, + ) + + for i in range(5): + print(f"Running subsample {i + 1} out of 5...") + + sample_key = random.randint(0, 2**db.key_size - 1) + sample_value = random.randint(0, 2**db.value_size - 1) + + if sample_key % 2 == 1: + sample_key -= 1 + + encoded_sample_key = db.encode_key(sample_key) + encoded_sample_value = db.encode_value(sample_value) + + _, encrypted_sample_key, encrypted_sample_value = client.encrypt( # type: ignore + None, + encoded_sample_key, + encoded_sample_value, + function_name="insert", + ) + + with progress.measure(id="evaluation-time-ms", label="Evaluation Time (ms)"): + ran = server.run( # noqa: F841 + db.state, + encrypted_sample_key, + encrypted_sample_value, + function_name="insert", + evaluation_keys=client.evaluation_keys, + ) + + +def benchmark_replace(db: StaticKeyValueDatabase, client: fhe.Client, server: fhe.Server): + """ + Benchmark replacing an entry in the database. + """ + + print("Warming up...") + + sample_key = random.randint(0, db.number_of_entries // 2) * 2 + sample_value = random.randint(0, db.number_of_entries // 2) * 2 + + # Initial state only contains odd keys for benchmarks. + # To actually replace, we'll make sure that sample_key is odd. + if sample_key % 2 == 0: + sample_key += 1 + + encoded_sample_key = db.encode_key(sample_key) + encoded_sample_value = db.encode_value(sample_value) + + _, encrypted_sample_key, encrypted_sample_value = client.encrypt( # type: ignore + None, + encoded_sample_key, + encoded_sample_value, + function_name="replace", + ) + ran = server.run( # noqa: F841 # pylint: disable=unused-variable + db.state, + encrypted_sample_key, + encrypted_sample_value, + function_name="replace", + evaluation_keys=client.evaluation_keys, + ) + + for i in range(5): + print(f"Running subsample {i + 1} out of 5...") + + sample_key = random.randint(0, db.number_of_entries - 1) + sample_value = random.randint(0, db.number_of_entries - 1) + + if sample_key % 2 == 0: + sample_key += 1 + + encoded_sample_key = db.encode_key(sample_key) + encoded_sample_value = db.encode_value(sample_value) + + _, encrypted_sample_key, encrypted_sample_value = client.encrypt( # type: ignore + None, + encoded_sample_key, + encoded_sample_value, + function_name="replace", + ) + + with progress.measure(id="evaluation-time-ms", label="Evaluation Time (ms)"): + ran = server.run( # noqa: F841 + db.state, + encrypted_sample_key, + encrypted_sample_value, + function_name="replace", + evaluation_keys=client.evaluation_keys, + ) + + +def benchmark_query(db: StaticKeyValueDatabase, client: fhe.Client, server: fhe.Server): + """ + Benchmark querying a key in the database. + """ + + print("Warming up...") + + sample_key = random.randint(0, db.number_of_entries - 1) + encoded_sample_key = db.encode_key(sample_key) + + _, encrypted_sample_key = client.encrypt( # type: ignore + None, + encoded_sample_key, + function_name="query", + ) + ran = server.run( # noqa: F841 # pylint: disable=unused-variable + db.state, + encrypted_sample_key, + function_name="query", + evaluation_keys=client.evaluation_keys, + ) + + for i in range(5): + print(f"Running subsample {i + 1} out of 5...") + + sample_key = random.randint(0, db.number_of_entries - 1) + encoded_sample_key = db.encode_key(sample_key) + + _, encrypted_sample_key = client.encrypt( # type: ignore + None, + encoded_sample_key, + function_name="query", + ) + with progress.measure(id="evaluation-time-ms", label="Evaluation Time (ms)"): + ran = server.run( # noqa: F841 + db.state, + encrypted_sample_key, + function_name="query", + evaluation_keys=client.evaluation_keys, + ) + + +def targets(): + """ + Generates targets to benchmark. + """ + + result = [] + for number_of_entries in [8, 16]: + for key_size in [8, 16]: + for value_size in [8, 16]: + for chunk_size in [2, 4]: + result.append( + { + "id": ( + f"static-kvdb-insert :: " + f"{number_of_entries} * {key_size}->{value_size} ^ {chunk_size}" + ), + "name": ( + f"Insertion to " + f"static key-value database " + f"from {key_size}b to {value_size}b " + f"with chunk size of {chunk_size} " + f"on {number_of_entries} entries" + ), + "parameters": { + "operation": "insert", + "number_of_entries": number_of_entries, + "key_size": key_size, + "value_size": value_size, + "chunk_size": chunk_size, + }, + } + ) + result.append( + { + "id": ( + f"static-kvdb-replace :: " + f"{number_of_entries} * {key_size}->{value_size} ^ {chunk_size}" + ), + "name": ( + f"Replacement in " + f"static key-value database " + f"from {key_size}b to {value_size}b " + f"with chunk size of {chunk_size} " + f"on {number_of_entries} entries" + ), + "parameters": { + "operation": "replace", + "number_of_entries": number_of_entries, + "key_size": key_size, + "value_size": value_size, + "chunk_size": chunk_size, + }, + } + ) + result.append( + { + "id": ( + f"static-kvdb-query :: " + f"{number_of_entries} * {key_size}->{value_size} ^ {chunk_size}" + ), + "name": ( + f"Query of " + f"static key-value database " + f"from {key_size}b to {value_size}b " + f"with chunk size of {chunk_size} " + f"on {number_of_entries} entries" + ), + "parameters": { + "operation": "query", + "number_of_entries": number_of_entries, + "key_size": key_size, + "value_size": value_size, + "chunk_size": chunk_size, + }, + } + ) + return result + + +@progress.track(targets()) +def main(operation, number_of_entries, key_size, value_size, chunk_size): + """ + Benchmark a target. + + Args: + operation: + operation to benchmark + + number_of_entries: + size of the database + + key_size: + size of the keys of the database + + value_size: + size of the values of the database + + chunk_size: + chunks size of the database + """ + + print("Compiling...") + cached_server = Path( + f"static_kvdb.{number_of_entries}.{key_size}.{value_size}.{chunk_size}.server.zip" + ) + if cached_server.exists(): + db = StaticKeyValueDatabase( + number_of_entries, + key_size, + value_size, + chunk_size, + compiled=False, + ) + server = fhe.Server.load(cached_server) + client = fhe.Client(server.client_specs, keyset_cache_directory=".keys") + else: + db = StaticKeyValueDatabase( + number_of_entries, + key_size, + value_size, + chunk_size, + compiled=True, + configuration=fhe.Configuration( + enable_unsafe_features=True, + use_insecure_key_cache=True, + insecure_key_cache_location=".keys", + ), + ) + db.module.server.save(cached_server) + + server = db.module.server + client = db.module.client + + db.state = server.run( + client.encrypt( + [ + np.array([1] + db.encode_key(i).tolist() + db.encode_value(i).tolist()) * (i % 2) + for i in range(db.number_of_entries) + ], + function_name="reset", + ), + function_name="reset", + evaluation_keys=client.evaluation_keys, + ) + + print("Generating keys...") + client.keygen() + + if operation == "insert": + benchmark_insert(db, client, server) + elif operation == "replace": + benchmark_replace(db, client, server) + elif operation == "query": + benchmark_query(db, client, server) + else: + message = f"Invalid operation: {operation}" + raise ValueError(message) diff --git a/frontends/concrete-python/concrete/fhe/__init__.py b/frontends/concrete-python/concrete/fhe/__init__.py index 723616cb49..f82430727c 100644 --- a/frontends/concrete-python/concrete/fhe/__init__.py +++ b/frontends/concrete-python/concrete/fhe/__init__.py @@ -24,6 +24,10 @@ DebugArtifacts, EncryptionStatus, Exactness, +) +from .compilation import FheFunction as Function +from .compilation import FheModule as Module +from .compilation import ( FunctionDebugArtifacts, Input, Keys, diff --git a/frontends/concrete-python/examples/key_value_database/dynamic-size.py b/frontends/concrete-python/examples/key_value_database/dynamic-size.py deleted file mode 100644 index 478b5b3597..0000000000 --- a/frontends/concrete-python/examples/key_value_database/dynamic-size.py +++ /dev/null @@ -1,268 +0,0 @@ -import time -from typing import List - -import numpy as np - -from concrete import fhe - -CHUNK_SIZE = 4 - -KEY_SIZE = 32 -VALUE_SIZE = 32 - -assert KEY_SIZE % CHUNK_SIZE == 0 -assert VALUE_SIZE % CHUNK_SIZE == 0 - -NUMBER_OF_KEY_CHUNKS = KEY_SIZE // CHUNK_SIZE -NUMBER_OF_VALUE_CHUNKS = VALUE_SIZE // CHUNK_SIZE - - -def encode(number, width): - binary_repr = np.binary_repr(number, width=width) - blocks = [binary_repr[i : (i + CHUNK_SIZE)] for i in range(0, len(binary_repr), CHUNK_SIZE)] - return np.array([int(block, 2) for block in blocks]) - - -def encode_key(number): - return encode(number, width=KEY_SIZE) - - -def encode_value(number): - return encode(number, width=VALUE_SIZE) - - -def decode(encoded_number): - result = 0 - for i in range(len(encoded_number)): - result += 2 ** (CHUNK_SIZE * i) * encoded_number[(len(encoded_number) - i) - 1] - return result - - -keep_if_match_lut = fhe.LookupTable([0 for _ in range(16)] + [i for i in range(16)]) -keep_if_no_match_lut = fhe.LookupTable([i for i in range(16)] + [0 for _ in range(16)]) - - -def _replace_impl(key, value, candidate_key, candidate_value): - number_of_matching_chunks = np.sum((candidate_key - key) == 0) - fhe.hint(number_of_matching_chunks, can_store=NUMBER_OF_KEY_CHUNKS) - - match = number_of_matching_chunks == NUMBER_OF_KEY_CHUNKS - - packed_match_and_value = (2**CHUNK_SIZE) * match + value - value_if_match_else_zeros = keep_if_match_lut[packed_match_and_value] - - packed_match_and_candidate_value = (2**CHUNK_SIZE) * match + candidate_value - zeros_if_match_else_candidate_value = keep_if_no_match_lut[packed_match_and_candidate_value] - - return value_if_match_else_zeros + zeros_if_match_else_candidate_value - - -def _query_impl(key, candidate_key, candidate_value): - number_of_matching_chunks = np.sum((candidate_key - key) == 0) - fhe.hint(number_of_matching_chunks, can_store=NUMBER_OF_KEY_CHUNKS) - - match = number_of_matching_chunks == NUMBER_OF_KEY_CHUNKS - - packed_match_and_candidate_value = (2**CHUNK_SIZE) * match + candidate_value - candidate_value_if_match_else_zeros = keep_if_match_lut[packed_match_and_candidate_value] - - return fhe.array([match, *candidate_value_if_match_else_zeros]) - - -class KeyValueDatabase: - _state: List[np.ndarray] - - _replace_circuit: fhe.Circuit - _query_circuit: fhe.Circuit - - def __init__(self): - self._state = [] - - sample_state = [[encode_key(i), encode_value(i * 2)] for i in range(10)] - replace_inputset = [ - ( - # key - encode_key(i), - # value - encode_value(i), - # candidate_key - entry[0], - # candidate_value - entry[1], - ) - for i in range(10) - for entry in sample_state - ] - query_inputset = [ - ( - # key - encode_key(i), - # candidate_key - entry[0], - # candidate_value - entry[1], - ) - for i in range(10) - for entry in sample_state - ] - - configuration = fhe.Configuration( - enable_unsafe_features=True, - use_insecure_key_cache=True, - insecure_key_cache_location=".keys", - ) - - replace_compiler = fhe.Compiler( - _replace_impl, - { - "key": "encrypted", - "value": "encrypted", - "candidate_key": "encrypted", - "candidate_value": "encrypted", - }, - ) - query_compiler = fhe.Compiler( - _query_impl, - { - "key": "encrypted", - "candidate_key": "encrypted", - "candidate_value": "encrypted", - }, - ) - - print() - - print("Compiling replacement circuit...") - start = time.time() - self._replace_circuit = replace_compiler.compile(replace_inputset, configuration) - end = time.time() - print(f"(took {end - start:.3f} seconds)") - - print() - - print("Compiling query circuit...") - start = time.time() - self._query_circuit = query_compiler.compile(query_inputset, configuration) - end = time.time() - print(f"(took {end - start:.3f} seconds)") - - print() - - print("Generating replacement keys...") - start = time.time() - self._replace_circuit.keygen() - end = time.time() - print(f"(took {end - start:.3f} seconds)") - - print() - - print("Generating query keys...") - start = time.time() - self._query_circuit.keygen() - end = time.time() - print(f"(took {end - start:.3f} seconds)") - - def insert(self, key, value): - print() - print("Inserting...") - start = time.time() - - self._state.append([encode_key(key), encode_value(value)]) - - end = time.time() - print(f"(took {end - start:.3f} seconds)") - - def replace(self, key, value): - print() - print("Replacing...") - start = time.time() - - encoded_key = encode_key(key) - encoded_value = encode_value(value) - - for entry in self._state: - entry[1] = self._replace_circuit.encrypt_run_decrypt(encoded_key, encoded_value, *entry) - - end = time.time() - print(f"(took {end - start:.3f} seconds)") - - def query(self, key): - print() - print("Querying...") - start = time.time() - - encoded_key = encode_key(key) - - accumulation = np.zeros(1 + NUMBER_OF_VALUE_CHUNKS, dtype=np.int64) - for entry in self._state: - contribution = self._query_circuit.encrypt_run_decrypt(encoded_key, *entry) - accumulation += contribution - - match_count = accumulation[0] - if match_count > 1: - message = "Key inserted multiple times" - raise RuntimeError(message) - - result = decode(accumulation[1:]) if match_count == 1 else None - - end = time.time() - print(f"(took {end - start:.3f} seconds)") - - return result - - -db = KeyValueDatabase() - -# Test: Insert/Query -db.insert(3, 4) -assert db.query(3) == 4 - -db.replace(3, 1) -assert db.query(3) == 1 - -# Test: Insert/Query -db.insert(25, 40) -assert db.query(25) == 40 - -# Test: Query Not Found -assert db.query(4) is None - -# Test: Replace/Query -db.replace(3, 5) -assert db.query(3) == 5 - -# Define lower/upper bounds for the key -minimum_key = 0 -maximum_key = 2**KEY_SIZE - 1 -# Define lower/upper bounds for the value -minimum_value = 0 -maximum_value = 2**VALUE_SIZE - 1 - -# Test: Insert/Replace/Query Bounds -# Insert (key: minimum_key , value: minimum_value) into the database -db.insert(minimum_key, minimum_value) - -# Query the database for the key=minimum_key -# The value minimum_value should be returned -assert db.query(minimum_key) == minimum_value - -# Insert (key: maximum_key , value: maximum_value) into the database -db.insert(maximum_key, maximum_value) - -# Query the database for the key=maximum_key -# The value maximum_value should be returned -assert db.query(maximum_key) == maximum_value - -# Replace the value of key=minimum_key with maximum_value -db.replace(minimum_key, maximum_value) - -# Query the database for the key=minimum_key -# The value maximum_value should be returned -assert db.query(minimum_key) == maximum_value - -# Replace the value of key=maximum_key with minimum_value -db.replace(maximum_key, minimum_value) - -# Query the database for the key=maximum_key -# The value minimum_value should be returned -assert db.query(maximum_key) == minimum_value diff --git a/frontends/concrete-python/examples/key_value_database/key_value_database.ipynb b/frontends/concrete-python/examples/key_value_database/key_value_database.ipynb index 06aa25a77e..54f34fbb3c 100644 --- a/frontends/concrete-python/examples/key_value_database/key_value_database.ipynb +++ b/frontends/concrete-python/examples/key_value_database/key_value_database.ipynb @@ -8,12 +8,7 @@ "\n", "This is an interactive tutorial of an Encrypted Key Value Database. The database allows for three operations, **Insert, Replace, and Query**. All the operations are implemented as fully-homomorphic encrypted circuits.\n", "\n", - "In `frontends/concrete-python/examples/key_value_database/`, you will find the following files:\n", - "\n", - "- `static-size.py`: This file contains a static size database implementation, meaning that the number of entries is given as a parameter at the beginning.\n", - "- `dynamic-size.py`: This file contains a dynamic size database implementation, meaning that the database starts as a zero entry database, and is grown as needed.\n", - "\n", - "This tutorial goes over the statically-sized database implementation. The dynamic database implementation is very similar, and the reader is encouraged to look at the code to see the differences.\n" + "In `frontends/concrete-python/examples/key_value_database/static_size.py`, you can find the full implementation\n" ] }, { diff --git a/frontends/concrete-python/examples/key_value_database/static-size.py b/frontends/concrete-python/examples/key_value_database/static-size.py deleted file mode 100644 index 61f7c5f920..0000000000 --- a/frontends/concrete-python/examples/key_value_database/static-size.py +++ /dev/null @@ -1,317 +0,0 @@ -import time - -import numpy as np - -from concrete import fhe - -NUMBER_OF_ENTRIES = 5 -CHUNK_SIZE = 4 - -KEY_SIZE = 32 -VALUE_SIZE = 32 - -assert KEY_SIZE % CHUNK_SIZE == 0 -assert VALUE_SIZE % CHUNK_SIZE == 0 - -NUMBER_OF_KEY_CHUNKS = KEY_SIZE // CHUNK_SIZE -NUMBER_OF_VALUE_CHUNKS = VALUE_SIZE // CHUNK_SIZE - -STATE_SHAPE = (NUMBER_OF_ENTRIES, 1 + NUMBER_OF_KEY_CHUNKS + NUMBER_OF_VALUE_CHUNKS) - -FLAG = 0 -KEY = slice(1, 1 + NUMBER_OF_KEY_CHUNKS) -VALUE = slice(1 + NUMBER_OF_KEY_CHUNKS, None) - - -def encode(number: int, width: int) -> np.ndarray: - binary_repr = np.binary_repr(number, width=width) - blocks = [binary_repr[i : (i + CHUNK_SIZE)] for i in range(0, len(binary_repr), CHUNK_SIZE)] - return np.array([int(block, 2) for block in blocks]) - - -def encode_key(number: int) -> np.ndarray: - return encode(number, width=KEY_SIZE) - - -def encode_value(number: int) -> np.ndarray: - return encode(number, width=VALUE_SIZE) - - -def decode(encoded_number: np.ndarray) -> int: - result = 0 - for i in range(len(encoded_number)): - result += 2 ** (CHUNK_SIZE * i) * encoded_number[(len(encoded_number) - i) - 1] - return result - - -keep_selected_lut = fhe.LookupTable([0 for _ in range(16)] + [i for i in range(16)]) - - -def _insert_impl(state, key, value): - flags = state[:, FLAG] - - selection = fhe.zeros(NUMBER_OF_ENTRIES) - - found = fhe.zero() - for i in range(NUMBER_OF_ENTRIES): - packed_flag_and_already_found = (found * 2) + flags[i] - is_selected = packed_flag_and_already_found == 0 - - selection[i] = is_selected - found += is_selected - - state_update = fhe.zeros(STATE_SHAPE) - state_update[:, FLAG] = selection - - selection = selection.reshape((-1, 1)) - - packed_selection_and_key = (selection * (2**CHUNK_SIZE)) + key - key_update = keep_selected_lut[packed_selection_and_key] - - packed_selection_and_value = selection * (2**CHUNK_SIZE) + value - value_update = keep_selected_lut[packed_selection_and_value] - - state_update[:, KEY] = key_update - state_update[:, VALUE] = value_update - - new_state = state + state_update - return new_state - - -def _replace_impl(state, key, value): - flags = state[:, FLAG] - keys = state[:, KEY] - values = state[:, VALUE] - - number_of_matching_chunks = np.sum((keys - key) == 0, axis=1) - fhe.hint(number_of_matching_chunks, can_store=NUMBER_OF_KEY_CHUNKS) - - equal_rows = number_of_matching_chunks == NUMBER_OF_KEY_CHUNKS - selection = (flags * 2 + equal_rows == 3).reshape((-1, 1)) - - packed_selection_and_value = selection * (2**CHUNK_SIZE) + value - set_value = keep_selected_lut[packed_selection_and_value] - - inverse_selection = 1 - selection - packed_inverse_selection_and_values = inverse_selection * (2**CHUNK_SIZE) + values - kept_values = keep_selected_lut[packed_inverse_selection_and_values] - - new_values = kept_values + set_value - state[:, VALUE] = new_values - - return state - - -def _query_impl(state, key): - keys = state[:, KEY] - values = state[:, VALUE] - - number_of_matching_chunks = np.sum((keys - key) == 0, axis=1) - fhe.hint(number_of_matching_chunks, can_store=NUMBER_OF_KEY_CHUNKS) - - selection = (number_of_matching_chunks == NUMBER_OF_KEY_CHUNKS).reshape((-1, 1)) - found = np.sum(selection) - fhe.hint(found, can_store=NUMBER_OF_ENTRIES) - - packed_selection_and_values = selection * (2**CHUNK_SIZE) + values - value_selection = keep_selected_lut[packed_selection_and_values] - value = np.sum(value_selection, axis=0) - fhe.hint(value, bit_width=CHUNK_SIZE) - - return fhe.array([found, *value]) - - -class KeyValueDatabase: - _state: np.ndarray - - _insert_circuit: fhe.Circuit - _replace_circuit: fhe.Circuit - _query_circuit: fhe.Circuit - - def __init__(self): - self._state = np.zeros(STATE_SHAPE, dtype=np.int64) - - sample_state = np.array( - [ - [i % 2] + encode_key(i).tolist() + encode_value(i).tolist() - for i in range(STATE_SHAPE[0]) - ] - ) - - insert_replace_inputset = [ - ( - # state - sample_state, - # key - encode_key(i), - # value - encode_key(i), - ) - for i in range(20) - ] - query_inputset = [ - ( - # state - sample_state, - # key - encode_key(i), - ) - for i in range(20) - ] - - configuration = fhe.Configuration( - enable_unsafe_features=True, - use_insecure_key_cache=True, - insecure_key_cache_location=".keys", - ) - - insert_compiler = fhe.Compiler( - _insert_impl, - {"state": "encrypted", "key": "encrypted", "value": "encrypted"}, - ) - replace_compiler = fhe.Compiler( - _replace_impl, - {"state": "encrypted", "key": "encrypted", "value": "encrypted"}, - ) - query_compiler = fhe.Compiler( - _query_impl, - {"state": "encrypted", "key": "encrypted"}, - ) - - print() - - print("Compiling insertion circuit...") - start = time.time() - self._insert_circuit = insert_compiler.compile(insert_replace_inputset, configuration) - end = time.time() - print(f"(took {end - start:.3f} seconds)") - - print() - - print("Compiling replacement circuit...") - start = time.time() - self._replace_circuit = replace_compiler.compile(insert_replace_inputset, configuration) - end = time.time() - print(f"(took {end - start:.3f} seconds)") - - print() - - print("Compiling query circuit...") - start = time.time() - self._query_circuit = query_compiler.compile(query_inputset, configuration) - end = time.time() - print(f"(took {end - start:.3f} seconds)") - - print() - - print("Generating insertion keys...") - start = time.time() - self._insert_circuit.keygen() - end = time.time() - print(f"(took {end - start:.3f} seconds)") - - print() - - print("Generating replacement keys...") - start = time.time() - self._replace_circuit.keygen() - end = time.time() - print(f"(took {end - start:.3f} seconds)") - - print() - - print("Generating query keys...") - start = time.time() - self._query_circuit.keygen() - end = time.time() - print(f"(took {end - start:.3f} seconds)") - - def insert(self, key, value): - print() - print("Inserting...") - start = time.time() - self._state = self._insert_circuit.encrypt_run_decrypt( - self._state, encode_key(key), encode_value(value) - ) - end = time.time() - print(f"(took {end - start:.3f} seconds)") - - def replace(self, key, value): - print() - print("Replacing...") - start = time.time() - self._state = self._replace_circuit.encrypt_run_decrypt( - self._state, encode_key(key), encode_value(value) - ) - end = time.time() - print(f"(took {end - start:.3f} seconds)") - - def query(self, key): - print() - print("Querying...") - start = time.time() - result = self._query_circuit.encrypt_run_decrypt(self._state, encode_key(key)) - end = time.time() - print(f"(took {end - start:.3f} seconds)") - - if result[0] == 0: - return None - - return decode(result[1:]) - - -db = KeyValueDatabase() - -# Test: Insert/Query -db.insert(3, 4) -assert db.query(3) == 4 - -db.replace(3, 1) -assert db.query(3) == 1 - -# Test: Insert/Query -db.insert(25, 40) -assert db.query(25) == 40 - -# Test: Query Not Found -assert db.query(4) is None - -# Test: Replace/Query -db.replace(3, 5) -assert db.query(3) == 5 - -# Define lower/upper bounds for the key -minimum_key = 1 -maximum_key = 2**KEY_SIZE - 1 -# Define lower/upper bounds for the value -minimum_value = 1 -maximum_value = 2**VALUE_SIZE - 1 - -# Test: Insert/Replace/Query Bounds -# Insert (key: minimum_key , value: minimum_value) into the database -db.insert(minimum_key, minimum_value) - -# Query the database for the key=minimum_key -# The value minimum_value should be returned -assert db.query(minimum_key) == minimum_value - -# Insert (key: maximum_key , value: maximum_value) into the database -db.insert(maximum_key, maximum_value) - -# Query the database for the key=maximum_key -# The value maximum_value should be returned -assert db.query(maximum_key) == maximum_value - -# Replace the value of key=minimum_key with maximum_value -db.replace(minimum_key, maximum_value) - -# Query the database for the key=minimum_key -# The value maximum_value should be returned -assert db.query(minimum_key) == maximum_value - -# Replace the value of key=maximum_key with minimum_value -db.replace(maximum_key, minimum_value) - -# Query the database for the key=maximum_key -# The value minimum_value should be returned -assert db.query(maximum_key) == minimum_value diff --git a/frontends/concrete-python/examples/key_value_database/static_size.py b/frontends/concrete-python/examples/key_value_database/static_size.py new file mode 100644 index 0000000000..c82bdb2d92 --- /dev/null +++ b/frontends/concrete-python/examples/key_value_database/static_size.py @@ -0,0 +1,426 @@ +import time +from typing import List, Optional, Tuple, Union + +import numpy as np + +from concrete import fhe + + +class StaticKeyValueDatabase: + number_of_entries: int + key_size: int + value_size: int + chunk_size: int + + _number_of_key_chunks: int + _number_of_value_chunks: int + _state_shape: Tuple[int, ...] + + module: fhe.Module + state: Optional[fhe.Value] + + def __init__( + self, + number_of_entries: int, + key_size: int = 32, + value_size: int = 32, + chunk_size: int = 4, + compiled: bool = True, + configuration: Optional[fhe.Configuration] = None, + ): + self.number_of_entries = number_of_entries + self.key_size = key_size + self.value_size = value_size + self.chunk_size = chunk_size + + self._number_of_key_chunks = key_size // chunk_size + self._number_of_value_chunks = value_size // chunk_size + self._state_shape = ( + number_of_entries, + 1 + self._number_of_key_chunks + self._number_of_value_chunks, + ) + + if compiled: + if configuration is None: + configuration = fhe.Configuration() + + self.module = self._module( + configuration.fork( + multivariate_strategy_preference=fhe.MultivariateStrategy.PROMOTED, + ) + ) + + self.state = None + + def _encode(self, number: int, width: int) -> np.ndarray: + binary_repr = np.binary_repr(number, width=width) + blocks = [ + binary_repr[i : (i + self.chunk_size)] + for i in range(0, len(binary_repr), self.chunk_size) + ] + return np.array([int(block, 2) for block in blocks]) + + def _decode(self, encoded_number: np.ndarray) -> int: + result = 0 + for i in range(len(encoded_number)): + result += 2 ** (self.chunk_size * i) * encoded_number[(len(encoded_number) - i) - 1] + return result + + def encode_key(self, key: int) -> np.ndarray: + return self._encode(key, width=self.key_size) + + def decode_key(self, encoded_key: np.ndarray) -> int: + return self._decode(encoded_key) + + def encode_value(self, value: int) -> np.ndarray: + return self._encode(value, width=self.value_size) + + def decode_value(self, encoded_value: np.ndarray) -> int: + return self._decode(encoded_value) + + def _module(self, configuration: fhe.Configuration) -> fhe.Module: + flag_slice = 0 + key_slice = slice(1, 1 + self._number_of_key_chunks) + value_slice = slice(1 + self._number_of_key_chunks, None) + + chunk_size = self.chunk_size + number_of_entries = self.number_of_entries + number_of_key_chunks = self._number_of_key_chunks + state_shape = self._state_shape + + @fhe.module() + class StaticKeyValueDatabaseModule: + @fhe.function({"state": "clear"}) + def reset(state): + return state + fhe.zero() + + @fhe.function({"state": "encrypted", "key": "encrypted", "value": "encrypted"}) + def insert(state, key, value): + flags = state[:, flag_slice] + + selection = fhe.zeros(number_of_entries) + + found = fhe.zero() + for i in range(number_of_entries): + is_selected = fhe.multivariate( + lambda found, flag: int(found == 0 and flag == 0) + )(found, flags[i]) + + selection[i] = is_selected + found += is_selected + + state_update = fhe.zeros(state_shape) + state_update[:, flag_slice] = selection + + selection = selection.reshape((-1, 1)) + + key_update = fhe.multivariate(lambda selection, key: selection * key)( + selection, key + ) + value_update = fhe.multivariate(lambda selection, value: selection * value)( + selection, value + ) + + state_update[:, key_slice] = key_update + state_update[:, value_slice] = value_update + + new_state = state + state_update + return fhe.refresh(new_state) + + @fhe.function({"state": "encrypted", "key": "encrypted", "value": "encrypted"}) + def replace(state, key, value): + flags = state[:, flag_slice] + keys = state[:, key_slice] + values = state[:, value_slice] + + number_of_matching_chunks = np.sum((keys - key) == 0, axis=1) + fhe.hint(number_of_matching_chunks, can_store=number_of_key_chunks) + + equal_rows = number_of_matching_chunks == number_of_key_chunks + selection = (flags * 2 + equal_rows == 3).reshape((-1, 1)) + + set_value = fhe.multivariate(lambda selection, value: selection * value)( + selection, value + ) + + kept_values = fhe.multivariate( + lambda inverse_selection, values: inverse_selection * values + )(1 - selection, values) + + new_values = kept_values + set_value + state[:, value_slice] = new_values + + return fhe.refresh(state) + + @fhe.function({"state": "encrypted", "key": "encrypted"}) + def query(state, key): + keys = state[:, key_slice] + values = state[:, value_slice] + + number_of_matching_chunks = np.sum((keys - key) == 0, axis=1) + fhe.hint(number_of_matching_chunks, can_store=number_of_key_chunks) + + selection = (number_of_matching_chunks == number_of_key_chunks).reshape((-1, 1)) + found = np.sum(selection) + fhe.hint(found, can_store=number_of_entries) + + value_selection = fhe.multivariate(lambda selection, values: selection * values)( + selection, values + ) + + value = np.sum(value_selection, axis=0) + fhe.hint(value, bit_width=chunk_size) + + return found, value + + @fhe.function({"state": "encrypted"}) + def inspect(state): + return state + + composition = fhe.Wired( + { + # from reset + fhe.Wire(fhe.Output(reset, 0), fhe.Input(insert, 0)), + fhe.Wire(fhe.Output(reset, 0), fhe.Input(replace, 0)), + fhe.Wire(fhe.Output(reset, 0), fhe.Input(query, 0)), + fhe.Wire(fhe.Output(reset, 0), fhe.Input(inspect, 0)), + # from insert + fhe.Wire(fhe.Output(insert, 0), fhe.Input(insert, 0)), + fhe.Wire(fhe.Output(insert, 0), fhe.Input(replace, 0)), + fhe.Wire(fhe.Output(insert, 0), fhe.Input(query, 0)), + fhe.Wire(fhe.Output(insert, 0), fhe.Input(inspect, 0)), + # from replace + fhe.Wire(fhe.Output(replace, 0), fhe.Input(insert, 0)), + fhe.Wire(fhe.Output(replace, 0), fhe.Input(replace, 0)), + fhe.Wire(fhe.Output(replace, 0), fhe.Input(query, 0)), + fhe.Wire(fhe.Output(replace, 0), fhe.Input(inspect, 0)), + # from inspect + fhe.Wire(fhe.Output(inspect, 0), fhe.Input(insert, 0)), + fhe.Wire(fhe.Output(inspect, 0), fhe.Input(replace, 0)), + fhe.Wire(fhe.Output(inspect, 0), fhe.Input(query, 0)), + fhe.Wire(fhe.Output(inspect, 0), fhe.Input(inspect, 0)), + } + ) + + sample_state = np.array( + [ + [i % 2] + self.encode_key(i).tolist() + self.encode_value(i).tolist() + for i in range(self.number_of_entries) + ] + ) + + insert_replace_inputset = [ + ( + # state + sample_state, + # key + self.encode_key(i), + # value + self.encode_value(i), + ) + for i in range(20) + ] + query_inputset = [ + ( + # state + sample_state, + # key + self.encode_key(i), + ) + for i in range(20) + ] + + return StaticKeyValueDatabaseModule.compile( # type: ignore + { + "reset": [sample_state], + "insert": insert_replace_inputset, + "replace": insert_replace_inputset, + "query": query_inputset, + "inspect": [sample_state], + }, + configuration, + ) + + def keygen(self, force: bool = False): + self.module.keygen(force=force) + + def initialize(self, initial_state: Optional[Union[List, np.ndarray]] = None): + if initial_state is None: + initial_state = np.zeros(self._state_shape, dtype=np.int64) + + if isinstance(initial_state, list): + initial_state = np.array(initial_state) + + if initial_state.shape != self._state_shape: + message = ( + f"Expected initial state to be of shape {self._state_shape} " + f"but it's {initial_state.shape}" + ) + raise ValueError(message) + + initial_state_clear = self.module.reset.encrypt(initial_state) + initial_state_encrypted = self.module.reset.run(initial_state_clear) + + self.state = initial_state_encrypted + + def decode_entry(self, entry: np.ndarray) -> Optional[Tuple[int, int]]: + if entry[0] == 0: + return None + + encoded_key = entry[1 : (self._number_of_key_chunks + 1)] + encoded_value = entry[(self._number_of_key_chunks + 1) :] + + return self.decode_key(encoded_key), self.decode_value(encoded_value) + + @property + def insert(self) -> fhe.Function: + return self.module.insert + + @property + def replace(self) -> fhe.Function: + return self.module.replace + + @property + def query(self) -> fhe.Function: + return self.module.query + + @property + def inspect(self) -> fhe.Function: + return self.module.inspect + + +def inspect(db: StaticKeyValueDatabase): + encrypted_state = db.inspect.run(db.state) + clear_state = db.inspect.decrypt(encrypted_state) + print(clear_state) + + +def insert(db: StaticKeyValueDatabase, key: int, value: int): + encoded_key, encoded_value = db.encode_key(key), db.encode_value(value) + _, encrypted_key, encoded_value = db.insert.encrypt( # type: ignore + None, + encoded_key, + encoded_value, + ) + + print() + + print("Inserting...") + start = time.time() + db.state = db.insert.run(db.state, encrypted_key, encoded_value) # type: ignore + end = time.time() + print(f"(took {end - start:.3f} seconds)") + + +def replace(db: StaticKeyValueDatabase, key: int, value: int): + encoded_key, encoded_value = db.encode_key(key), db.encode_value(value) + _, encrypted_key, encoded_value = db.replace.encrypt( # type: ignore + None, + encoded_key, + encoded_value, + ) + + print() + + print("Replacing...") + start = time.time() + db.state = db.replace.run(db.state, encrypted_key, encoded_value) # type: ignore + end = time.time() + print(f"(took {end - start:.3f} seconds)") + + +def query(db: StaticKeyValueDatabase, key: int) -> Optional[int]: + encoded_key = db.encode_key(key) + _, encrypted_key = db.query.encrypt(None, encoded_key) # type: ignore + + print() + + print("Querying...") + start = time.time() + encrypted_found, encrypted_value = db.query.run(db.state, encrypted_key) # type: ignore + end = time.time() + print(f"(took {end - start:.3f} seconds)") + + found, value = db.query.decrypt(encrypted_found, encrypted_value) # type: ignore + if not found: + return None + + return db.decode_value(value) # type: ignore + + +if __name__ == "__main__": + print("Compiling...") + start = time.time() + db = StaticKeyValueDatabase(number_of_entries=10) + end = time.time() + print(f"(took {end - start:.3f} seconds)") + + print() + + print("Generating keys...") + start = time.time() + db.keygen() + end = time.time() + print(f"(took {end - start:.3f} seconds)") + + print() + + print("Initializing...") + start = time.time() + db.initialize() + end = time.time() + print(f"(took {end - start:.3f} seconds)") + + # Test: Insert/Query + insert(db, 3, 4) + assert query(db, 3) == 4 + + replace(db, 3, 1) + assert query(db, 3) == 1 + + # Test: Insert/Query + insert(db, 25, 40) + assert query(db, 25) == 40 + + # Test: Query Not Found + assert query(db, 4) is None + + # Test: Replace/Query + replace(db, 3, 5) + assert query(db, 3) == 5 + + # Define lower/upper bounds for the key + minimum_key = 1 + maximum_key = 2**db.key_size - 1 + # Define lower/upper bounds for the value + minimum_value = 1 + maximum_value = 2**db.value_size - 1 + + # Test: Insert/Replace/Query Bounds + # Insert (key: minimum_key , value: minimum_value) into the database + insert(db, minimum_key, minimum_value) + + # Query the database for the key=minimum_key + # The value minimum_value should be returned + assert query(db, minimum_key) == minimum_value + + # Insert (key: maximum_key , value: maximum_value) into the database + insert(db, maximum_key, maximum_value) + + # Query the database for the key=maximum_key + # The value maximum_value should be returned + assert query(db, maximum_key) == maximum_value + + # Replace the value of key=minimum_key with maximum_value + replace(db, minimum_key, maximum_value) + + # Query the database for the key=minimum_key + # The value maximum_value should be returned + assert query(db, minimum_key) == maximum_value + + # Replace the value of key=maximum_key with minimum_value + replace(db, maximum_key, minimum_value) + + # Query the database for the key=maximum_key + # The value minimum_value should be returned + assert query(db, maximum_key) == minimum_value diff --git a/frontends/concrete-python/tests/execution/test_examples.py b/frontends/concrete-python/tests/execution/test_examples.py new file mode 100644 index 0000000000..e26eb0d46a --- /dev/null +++ b/frontends/concrete-python/tests/execution/test_examples.py @@ -0,0 +1,183 @@ +""" +Tests of the examples. +""" + +from typing import Optional + +import numpy as np +import pytest + +from examples.key_value_database.static_size import StaticKeyValueDatabase +from examples.levenshtein_distance.levenshtein_distance import Alphabet, LevenshteinDistance + + +def test_static_kvdb(helpers): + """ + Test static key-value database example. + """ + + configuration = helpers.configuration() + + def inspect(db: StaticKeyValueDatabase) -> np.ndarray: + encrypted_state = db.inspect.run(db.state) + clear_state = db.inspect.decrypt(encrypted_state) + return clear_state # type: ignore + + def insert(db: StaticKeyValueDatabase, key: int, value: int): + encoded_key, encoded_value = db.encode_key(key), db.encode_value(value) + _, encrypted_key, encoded_value = db.insert.encrypt( # type: ignore + None, + encoded_key, + encoded_value, + ) + db.state = db.insert.run(db.state, encrypted_key, encoded_value) # type: ignore + + def replace(db: StaticKeyValueDatabase, key: int, value: int): + encoded_key, encoded_value = db.encode_key(key), db.encode_value(value) + _, encrypted_key, encoded_value = db.replace.encrypt( # type: ignore + None, + encoded_key, + encoded_value, + ) + db.state = db.replace.run(db.state, encrypted_key, encoded_value) # type: ignore + + def query(db: StaticKeyValueDatabase, key: int) -> Optional[int]: + encoded_key = db.encode_key(key) + _, encrypted_key = db.query.encrypt(None, encoded_key) # type: ignore + encrypted_found, encrypted_value = db.query.run(db.state, encrypted_key) # type: ignore + + found, value = db.query.decrypt(encrypted_found, encrypted_value) # type: ignore + if not found: + return None + + return db.decode_value(value) # type: ignore + + db = StaticKeyValueDatabase( + number_of_entries=4, + key_size=8, + value_size=8, + chunk_size=2, + configuration=configuration, + ) + db.keygen() + + db.initialize() + assert np.array_equal( + inspect(db), + [ + [0] + [0, 0, 0, 0] + [0, 0, 0, 0], + [0] + [0, 0, 0, 0] + [0, 0, 0, 0], + [0] + [0, 0, 0, 0] + [0, 0, 0, 0], + [0] + [0, 0, 0, 0] + [0, 0, 0, 0], + ], + ) + + assert query(db, 3) is None + + insert(db, 3, 4) + assert np.array_equal( + inspect(db), + [ + [1] + [0, 0, 0, 3] + [0, 0, 1, 0], + [0] + [0, 0, 0, 0] + [0, 0, 0, 0], + [0] + [0, 0, 0, 0] + [0, 0, 0, 0], + [0] + [0, 0, 0, 0] + [0, 0, 0, 0], + ], + ) + + assert query(db, 3) == 4 + + replace(db, 3, 1) + assert np.array_equal( + inspect(db), + [ + [1] + [0, 0, 0, 3] + [0, 0, 0, 1], + [0] + [0, 0, 0, 0] + [0, 0, 0, 0], + [0] + [0, 0, 0, 0] + [0, 0, 0, 0], + [0] + [0, 0, 0, 0] + [0, 0, 0, 0], + ], + ) + + assert query(db, 3) == 1 + + insert(db, 25, 40) + assert np.array_equal( + inspect(db), + [ + [1] + [0, 0, 0, 3] + [0, 0, 0, 1], + [1] + [0, 1, 2, 1] + [0, 2, 2, 0], + [0] + [0, 0, 0, 0] + [0, 0, 0, 0], + [0] + [0, 0, 0, 0] + [0, 0, 0, 0], + ], + ) + + assert query(db, 25) == 40 + + minimum_key = 0 + maximum_key = 2**db.key_size - 1 + + minimum_value = 0 + maximum_value = 2**db.value_size - 1 + + db.initialize() + assert np.array_equal( + inspect(db), + [ + [0] + [0, 0, 0, 0] + [0, 0, 0, 0], + [0] + [0, 0, 0, 0] + [0, 0, 0, 0], + [0] + [0, 0, 0, 0] + [0, 0, 0, 0], + [0] + [0, 0, 0, 0] + [0, 0, 0, 0], + ], + ) + + insert(db, minimum_key, minimum_value) + assert np.array_equal( + inspect(db), + [ + [1] + [0, 0, 0, 0] + [0, 0, 0, 0], + [0] + [0, 0, 0, 0] + [0, 0, 0, 0], + [0] + [0, 0, 0, 0] + [0, 0, 0, 0], + [0] + [0, 0, 0, 0] + [0, 0, 0, 0], + ], + ) + + assert query(db, minimum_key) == minimum_value + + replace(db, minimum_key, maximum_value) + assert np.array_equal( + inspect(db), + [ + [1] + [0, 0, 0, 0] + [3, 3, 3, 3], + [0] + [0, 0, 0, 0] + [0, 0, 0, 0], + [0] + [0, 0, 0, 0] + [0, 0, 0, 0], + [0] + [0, 0, 0, 0] + [0, 0, 0, 0], + ], + ) + + assert query(db, minimum_key) == maximum_value + + insert(db, maximum_key, maximum_value) + assert np.array_equal( + inspect(db), + [ + [1] + [0, 0, 0, 0] + [3, 3, 3, 3], + [1] + [3, 3, 3, 3] + [3, 3, 3, 3], + [0] + [0, 0, 0, 0] + [0, 0, 0, 0], + [0] + [0, 0, 0, 0] + [0, 0, 0, 0], + ], + ) + + assert query(db, maximum_key) == maximum_value + + replace(db, maximum_key, minimum_value) + assert np.array_equal( + inspect(db), + [ + [1] + [0, 0, 0, 0] + [3, 3, 3, 3], + [1] + [3, 3, 3, 3] + [0, 0, 0, 0], + [0] + [0, 0, 0, 0] + [0, 0, 0, 0], + [0] + [0, 0, 0, 0] + [0, 0, 0, 0], + ], + ) + + assert query(db, maximum_key) == minimum_value