From 39bb0320fc65f096f67c09d6ac046e8f58b52d32 Mon Sep 17 00:00:00 2001 From: Umut Date: Mon, 5 Aug 2024 11:39:27 +0300 Subject: [PATCH] feat(frontend): refactor key value database example to use modules and benchmark it --- .../key_value_database/static_size.py | 287 ++++++++++++ .../concrete-python/concrete/fhe/__init__.py | 4 + .../{dynamic-size.py => dynamic_size.py} | 0 .../key_value_database/static-size.py | 317 ------------- .../key_value_database/static_size.py | 415 ++++++++++++++++++ 5 files changed, 706 insertions(+), 317 deletions(-) create mode 100644 frontends/concrete-python/benchmarks/key_value_database/static_size.py rename frontends/concrete-python/examples/key_value_database/{dynamic-size.py => dynamic_size.py} (100%) delete mode 100644 frontends/concrete-python/examples/key_value_database/static-size.py create mode 100644 frontends/concrete-python/examples/key_value_database/static_size.py diff --git a/frontends/concrete-python/benchmarks/key_value_database/static_size.py b/frontends/concrete-python/benchmarks/key_value_database/static_size.py new file mode 100644 index 0000000000..0f8126abaf --- /dev/null +++ b/frontends/concrete-python/benchmarks/key_value_database/static_size.py @@ -0,0 +1,287 @@ +import random +from pathlib import Path + +import numpy as np +import py_progress_tracker as progress + +from concrete import fhe +from examples.key_value_database.static_size import StaticKeyValueDatabase + +targets = [] + +for number_of_entries in [8, 16]: + for key_size in [8, 16]: + for value_size in [8, 16]: + for chunk_size in [2, 4]: + targets.append( + { + "id": ( + f"static-kvdb-insert :: " + f"{number_of_entries} * {key_size}->{value_size} ^ {chunk_size}" + ), + "name": ( + f"Insertion to " + f"static key-value database " + f"from {key_size}b to {value_size}b " + f"with chunk size of {chunk_size} " + f"on {number_of_entries} entries" + ), + "parameters": { + "operation": "insert", + "number_of_entries": number_of_entries, + "key_size": key_size, + "value_size": value_size, + "chunk_size": chunk_size, + }, + } + ) + targets.append( + { + "id": ( + f"static-kvdb-replace :: " + f"{number_of_entries} * {key_size}->{value_size} ^ {chunk_size}" + ), + "name": ( + f"Replacement in " + f"static key-value database " + f"from {key_size}b to {value_size}b " + f"with chunk size of {chunk_size} " + f"on {number_of_entries} entries" + ), + "parameters": { + "operation": "replace", + "number_of_entries": number_of_entries, + "key_size": key_size, + "value_size": value_size, + "chunk_size": chunk_size, + }, + } + ) + targets.append( + { + "id": ( + f"static-kvdb-query :: " + f"{number_of_entries} * {key_size}->{value_size} ^ {chunk_size}" + ), + "name": ( + f"Query of " + f"static key-value database " + f"from {key_size}b to {value_size}b " + f"with chunk size of {chunk_size} " + f"on {number_of_entries} entries" + ), + "parameters": { + "operation": "query", + "number_of_entries": number_of_entries, + "key_size": key_size, + "value_size": value_size, + "chunk_size": chunk_size, + }, + } + ) + + +def benchmark_insert(db: StaticKeyValueDatabase, client: fhe.Client, server: fhe.Server): + print("Warming up...") + + sample_key = random.randint(0, 2**db.key_size - 1) + sample_value = random.randint(0, 2**db.value_size - 1) + + # Initial state only contains odd keys for benchmarks. + # To avoid collisions, we'll make sure that sample_key is even. + if sample_key % 2 == 1: + sample_key -= 1 + + encoded_sample_key = db.encode_key(sample_key) + encoded_sample_value = db.encode_value(sample_value) + + _, encrypted_sample_key, encrypted_sample_value = client.encrypt( + None, + encoded_sample_key, + encoded_sample_value, + function_name="insert", + ) + ran = server.run( # noqa: F841 + db.state, + encrypted_sample_key, + encrypted_sample_value, + function_name="insert", + evaluation_keys=client.evaluation_keys, + ) + + for i in range(5): + print(f"Running subsample {i + 1} out of 5...") + + sample_key = random.randint(0, 2**db.key_size - 1) + sample_value = random.randint(0, 2**db.value_size - 1) + + if sample_key % 2 == 1: + sample_key -= 1 + + encoded_sample_key = db.encode_key(sample_key) + encoded_sample_value = db.encode_value(sample_value) + + _, encrypted_sample_key, encrypted_sample_value = client.encrypt( + None, + encoded_sample_key, + encoded_sample_value, + function_name="insert", + ) + + with progress.measure(id="evaluation-time-ms", label="Evaluation Time (ms)"): + ran = server.run( # noqa: F841 + db.state, + encrypted_sample_key, + encrypted_sample_value, + function_name="insert", + evaluation_keys=client.evaluation_keys, + ) + + +def benchmark_replace(db: StaticKeyValueDatabase, client: fhe.Client, server: fhe.Server): + print("Warming up...") + + sample_key = random.randint(0, db.number_of_entries // 2) * 2 + sample_value = random.randint(0, db.number_of_entries // 2) * 2 + + # Initial state only contains odd keys for benchmarks. + # To actually replace, we'll make sure that sample_key is odd. + if sample_key % 2 == 0: + sample_key += 1 + + encoded_sample_key = db.encode_key(sample_key) + encoded_sample_value = db.encode_value(sample_value) + + _, encrypted_sample_key, encrypted_sample_value = client.encrypt( + None, + encoded_sample_key, + encoded_sample_value, + function_name="replace", + ) + ran = server.run( # noqa: F841 + db.state, + encrypted_sample_key, + encrypted_sample_value, + function_name="replace", + evaluation_keys=client.evaluation_keys, + ) + + for i in range(5): + print(f"Running subsample {i + 1} out of 5...") + + sample_key = random.randint(0, db.number_of_entries - 1) + sample_value = random.randint(0, db.number_of_entries - 1) + + if sample_key % 2 == 0: + sample_key += 1 + + encoded_sample_key = db.encode_key(sample_key) + encoded_sample_value = db.encode_value(sample_value) + + _, encrypted_sample_key, encrypted_sample_value = client.encrypt( + None, + encoded_sample_key, + encoded_sample_value, + function_name="replace", + ) + + with progress.measure(id="evaluation-time-ms", label="Evaluation Time (ms)"): + ran = server.run( # noqa: F841 + db.state, + encrypted_sample_key, + encrypted_sample_value, + function_name="replace", + evaluation_keys=client.evaluation_keys, + ) + + +def benchmark_query(db: StaticKeyValueDatabase, client: fhe.Client, server: fhe.Server): + print("Warming up...") + + sample_key = random.randint(0, db.number_of_entries - 1) + encoded_sample_key = db.encode_key(sample_key) + + _, encrypted_sample_key = client.encrypt( + None, + encoded_sample_key, + function_name="query", + ) + ran = server.run( # noqa: F841 + db.state, + encrypted_sample_key, + function_name="query", + evaluation_keys=client.evaluation_keys, + ) + + for i in range(5): + print(f"Running subsample {i + 1} out of 5...") + + sample_key = random.randint(0, db.number_of_entries - 1) + encoded_sample_key = db.encode_key(sample_key) + + _, encrypted_sample_key = client.encrypt(None, encoded_sample_key, function_name="query") + with progress.measure(id="evaluation-time-ms", label="Evaluation Time (ms)"): + ran = server.run( # noqa: F841 + db.state, + encrypted_sample_key, + function_name="query", + evaluation_keys=client.evaluation_keys, + ) + + +@progress.track(targets) +def main(operation, number_of_entries, key_size, value_size, chunk_size): + print("Compiling...") + cached_server = Path(f"{number_of_entries}.{key_size}.{value_size}.{chunk_size}.server.zip") + if cached_server.exists(): + db = StaticKeyValueDatabase( + # database parameters + number_of_entries, + key_size, + value_size, + chunk_size, + compiled=False, + ) + server = fhe.Server.load(cached_server) + client = fhe.Client(server.client_specs, keyset_cache_directory=".keys") + else: + configuration = {} + db = StaticKeyValueDatabase( + # database parameters + number_of_entries, + key_size, + value_size, + chunk_size, + # configuration overrides + enable_unsafe_features=True, + use_insecure_key_cache=True, + insecure_key_cache_location=".keys", + ) + db.module.server.save(cached_server) + + server = db.module.server + client = db.module.client + + db.state = server.run( + client.encrypt( + [ + np.array([1] + db.encode_key(i).tolist() + db.encode_value(i).tolist()) * (i % 2) + for i in range(db.number_of_entries) + ], + function_name="reset", + ), + function_name="reset", + evaluation_keys=client.evaluation_keys, + ) + + print("Generating keys...") + client.keygen() + + if operation == "insert": + benchmark_insert(db, client, server) + elif operation == "replace": + benchmark_replace(db, client, server) + elif operation == "query": + benchmark_query(db, client, server) + else: + raise ValueError(f"Invalid operation: {operation}") diff --git a/frontends/concrete-python/concrete/fhe/__init__.py b/frontends/concrete-python/concrete/fhe/__init__.py index 723616cb49..f82430727c 100644 --- a/frontends/concrete-python/concrete/fhe/__init__.py +++ b/frontends/concrete-python/concrete/fhe/__init__.py @@ -24,6 +24,10 @@ DebugArtifacts, EncryptionStatus, Exactness, +) +from .compilation import FheFunction as Function +from .compilation import FheModule as Module +from .compilation import ( FunctionDebugArtifacts, Input, Keys, diff --git a/frontends/concrete-python/examples/key_value_database/dynamic-size.py b/frontends/concrete-python/examples/key_value_database/dynamic_size.py similarity index 100% rename from frontends/concrete-python/examples/key_value_database/dynamic-size.py rename to frontends/concrete-python/examples/key_value_database/dynamic_size.py diff --git a/frontends/concrete-python/examples/key_value_database/static-size.py b/frontends/concrete-python/examples/key_value_database/static-size.py deleted file mode 100644 index 61f7c5f920..0000000000 --- a/frontends/concrete-python/examples/key_value_database/static-size.py +++ /dev/null @@ -1,317 +0,0 @@ -import time - -import numpy as np - -from concrete import fhe - -NUMBER_OF_ENTRIES = 5 -CHUNK_SIZE = 4 - -KEY_SIZE = 32 -VALUE_SIZE = 32 - -assert KEY_SIZE % CHUNK_SIZE == 0 -assert VALUE_SIZE % CHUNK_SIZE == 0 - -NUMBER_OF_KEY_CHUNKS = KEY_SIZE // CHUNK_SIZE -NUMBER_OF_VALUE_CHUNKS = VALUE_SIZE // CHUNK_SIZE - -STATE_SHAPE = (NUMBER_OF_ENTRIES, 1 + NUMBER_OF_KEY_CHUNKS + NUMBER_OF_VALUE_CHUNKS) - -FLAG = 0 -KEY = slice(1, 1 + NUMBER_OF_KEY_CHUNKS) -VALUE = slice(1 + NUMBER_OF_KEY_CHUNKS, None) - - -def encode(number: int, width: int) -> np.ndarray: - binary_repr = np.binary_repr(number, width=width) - blocks = [binary_repr[i : (i + CHUNK_SIZE)] for i in range(0, len(binary_repr), CHUNK_SIZE)] - return np.array([int(block, 2) for block in blocks]) - - -def encode_key(number: int) -> np.ndarray: - return encode(number, width=KEY_SIZE) - - -def encode_value(number: int) -> np.ndarray: - return encode(number, width=VALUE_SIZE) - - -def decode(encoded_number: np.ndarray) -> int: - result = 0 - for i in range(len(encoded_number)): - result += 2 ** (CHUNK_SIZE * i) * encoded_number[(len(encoded_number) - i) - 1] - return result - - -keep_selected_lut = fhe.LookupTable([0 for _ in range(16)] + [i for i in range(16)]) - - -def _insert_impl(state, key, value): - flags = state[:, FLAG] - - selection = fhe.zeros(NUMBER_OF_ENTRIES) - - found = fhe.zero() - for i in range(NUMBER_OF_ENTRIES): - packed_flag_and_already_found = (found * 2) + flags[i] - is_selected = packed_flag_and_already_found == 0 - - selection[i] = is_selected - found += is_selected - - state_update = fhe.zeros(STATE_SHAPE) - state_update[:, FLAG] = selection - - selection = selection.reshape((-1, 1)) - - packed_selection_and_key = (selection * (2**CHUNK_SIZE)) + key - key_update = keep_selected_lut[packed_selection_and_key] - - packed_selection_and_value = selection * (2**CHUNK_SIZE) + value - value_update = keep_selected_lut[packed_selection_and_value] - - state_update[:, KEY] = key_update - state_update[:, VALUE] = value_update - - new_state = state + state_update - return new_state - - -def _replace_impl(state, key, value): - flags = state[:, FLAG] - keys = state[:, KEY] - values = state[:, VALUE] - - number_of_matching_chunks = np.sum((keys - key) == 0, axis=1) - fhe.hint(number_of_matching_chunks, can_store=NUMBER_OF_KEY_CHUNKS) - - equal_rows = number_of_matching_chunks == NUMBER_OF_KEY_CHUNKS - selection = (flags * 2 + equal_rows == 3).reshape((-1, 1)) - - packed_selection_and_value = selection * (2**CHUNK_SIZE) + value - set_value = keep_selected_lut[packed_selection_and_value] - - inverse_selection = 1 - selection - packed_inverse_selection_and_values = inverse_selection * (2**CHUNK_SIZE) + values - kept_values = keep_selected_lut[packed_inverse_selection_and_values] - - new_values = kept_values + set_value - state[:, VALUE] = new_values - - return state - - -def _query_impl(state, key): - keys = state[:, KEY] - values = state[:, VALUE] - - number_of_matching_chunks = np.sum((keys - key) == 0, axis=1) - fhe.hint(number_of_matching_chunks, can_store=NUMBER_OF_KEY_CHUNKS) - - selection = (number_of_matching_chunks == NUMBER_OF_KEY_CHUNKS).reshape((-1, 1)) - found = np.sum(selection) - fhe.hint(found, can_store=NUMBER_OF_ENTRIES) - - packed_selection_and_values = selection * (2**CHUNK_SIZE) + values - value_selection = keep_selected_lut[packed_selection_and_values] - value = np.sum(value_selection, axis=0) - fhe.hint(value, bit_width=CHUNK_SIZE) - - return fhe.array([found, *value]) - - -class KeyValueDatabase: - _state: np.ndarray - - _insert_circuit: fhe.Circuit - _replace_circuit: fhe.Circuit - _query_circuit: fhe.Circuit - - def __init__(self): - self._state = np.zeros(STATE_SHAPE, dtype=np.int64) - - sample_state = np.array( - [ - [i % 2] + encode_key(i).tolist() + encode_value(i).tolist() - for i in range(STATE_SHAPE[0]) - ] - ) - - insert_replace_inputset = [ - ( - # state - sample_state, - # key - encode_key(i), - # value - encode_key(i), - ) - for i in range(20) - ] - query_inputset = [ - ( - # state - sample_state, - # key - encode_key(i), - ) - for i in range(20) - ] - - configuration = fhe.Configuration( - enable_unsafe_features=True, - use_insecure_key_cache=True, - insecure_key_cache_location=".keys", - ) - - insert_compiler = fhe.Compiler( - _insert_impl, - {"state": "encrypted", "key": "encrypted", "value": "encrypted"}, - ) - replace_compiler = fhe.Compiler( - _replace_impl, - {"state": "encrypted", "key": "encrypted", "value": "encrypted"}, - ) - query_compiler = fhe.Compiler( - _query_impl, - {"state": "encrypted", "key": "encrypted"}, - ) - - print() - - print("Compiling insertion circuit...") - start = time.time() - self._insert_circuit = insert_compiler.compile(insert_replace_inputset, configuration) - end = time.time() - print(f"(took {end - start:.3f} seconds)") - - print() - - print("Compiling replacement circuit...") - start = time.time() - self._replace_circuit = replace_compiler.compile(insert_replace_inputset, configuration) - end = time.time() - print(f"(took {end - start:.3f} seconds)") - - print() - - print("Compiling query circuit...") - start = time.time() - self._query_circuit = query_compiler.compile(query_inputset, configuration) - end = time.time() - print(f"(took {end - start:.3f} seconds)") - - print() - - print("Generating insertion keys...") - start = time.time() - self._insert_circuit.keygen() - end = time.time() - print(f"(took {end - start:.3f} seconds)") - - print() - - print("Generating replacement keys...") - start = time.time() - self._replace_circuit.keygen() - end = time.time() - print(f"(took {end - start:.3f} seconds)") - - print() - - print("Generating query keys...") - start = time.time() - self._query_circuit.keygen() - end = time.time() - print(f"(took {end - start:.3f} seconds)") - - def insert(self, key, value): - print() - print("Inserting...") - start = time.time() - self._state = self._insert_circuit.encrypt_run_decrypt( - self._state, encode_key(key), encode_value(value) - ) - end = time.time() - print(f"(took {end - start:.3f} seconds)") - - def replace(self, key, value): - print() - print("Replacing...") - start = time.time() - self._state = self._replace_circuit.encrypt_run_decrypt( - self._state, encode_key(key), encode_value(value) - ) - end = time.time() - print(f"(took {end - start:.3f} seconds)") - - def query(self, key): - print() - print("Querying...") - start = time.time() - result = self._query_circuit.encrypt_run_decrypt(self._state, encode_key(key)) - end = time.time() - print(f"(took {end - start:.3f} seconds)") - - if result[0] == 0: - return None - - return decode(result[1:]) - - -db = KeyValueDatabase() - -# Test: Insert/Query -db.insert(3, 4) -assert db.query(3) == 4 - -db.replace(3, 1) -assert db.query(3) == 1 - -# Test: Insert/Query -db.insert(25, 40) -assert db.query(25) == 40 - -# Test: Query Not Found -assert db.query(4) is None - -# Test: Replace/Query -db.replace(3, 5) -assert db.query(3) == 5 - -# Define lower/upper bounds for the key -minimum_key = 1 -maximum_key = 2**KEY_SIZE - 1 -# Define lower/upper bounds for the value -minimum_value = 1 -maximum_value = 2**VALUE_SIZE - 1 - -# Test: Insert/Replace/Query Bounds -# Insert (key: minimum_key , value: minimum_value) into the database -db.insert(minimum_key, minimum_value) - -# Query the database for the key=minimum_key -# The value minimum_value should be returned -assert db.query(minimum_key) == minimum_value - -# Insert (key: maximum_key , value: maximum_value) into the database -db.insert(maximum_key, maximum_value) - -# Query the database for the key=maximum_key -# The value maximum_value should be returned -assert db.query(maximum_key) == maximum_value - -# Replace the value of key=minimum_key with maximum_value -db.replace(minimum_key, maximum_value) - -# Query the database for the key=minimum_key -# The value maximum_value should be returned -assert db.query(minimum_key) == maximum_value - -# Replace the value of key=maximum_key with minimum_value -db.replace(maximum_key, minimum_value) - -# Query the database for the key=maximum_key -# The value minimum_value should be returned -assert db.query(maximum_key) == minimum_value diff --git a/frontends/concrete-python/examples/key_value_database/static_size.py b/frontends/concrete-python/examples/key_value_database/static_size.py new file mode 100644 index 0000000000..74dc96abe3 --- /dev/null +++ b/frontends/concrete-python/examples/key_value_database/static_size.py @@ -0,0 +1,415 @@ +import time +from typing import List, Optional, Tuple, Union + +import numpy as np + +from concrete import fhe + + +class StaticKeyValueDatabase: + number_of_entries: int + key_size: int + value_size: int + chunk_size: int + + _number_of_key_chunks: int + _number_of_value_chunks: int + _state_shape: Tuple[int, ...] + + module: fhe.Module + state: Optional[fhe.Value] + + def __init__( + self, + number_of_entries: int, + key_size: int = 32, + value_size: int = 32, + chunk_size: int = 4, + compiled: bool = True, + **kwargs, + ): + self.number_of_entries = number_of_entries + self.key_size = key_size + self.value_size = value_size + self.chunk_size = chunk_size + + self._number_of_key_chunks = key_size // chunk_size + self._number_of_value_chunks = value_size // chunk_size + self._state_shape = ( + number_of_entries, + 1 + self._number_of_key_chunks + self._number_of_value_chunks, + ) + + configuration = fhe.Configuration( + multivariate_strategy_preference=fhe.MultivariateStrategy.PROMOTED, + ).fork(**kwargs) + + if compiled: + self.module = self._module(configuration) + + self.state = None + + def _encode(self, number: int, width: int) -> np.ndarray: + binary_repr = np.binary_repr(number, width=width) + blocks = [ + binary_repr[i : (i + self.chunk_size)] + for i in range(0, len(binary_repr), self.chunk_size) + ] + return np.array([int(block, 2) for block in blocks]) + + def _decode(self, encoded_number: np.ndarray) -> int: + result = 0 + for i in range(len(encoded_number)): + result += 2 ** (self.chunk_size * i) * encoded_number[(len(encoded_number) - i) - 1] + return result + + def encode_key(self, key: int) -> np.ndarray: + return self._encode(key, width=self.key_size) + + def decode_key(self, encoded_key: np.ndarray) -> int: + return self._decode(encoded_key) + + def encode_value(self, value: int) -> np.ndarray: + return self._encode(value, width=self.value_size) + + def decode_value(self, encoded_value: np.ndarray) -> int: + return self._decode(encoded_value) + + def _module(self, configuration: fhe.Configuration) -> fhe.Module: + flag_slice = 0 + key_slice = slice(1, 1 + self._number_of_key_chunks) + value_slice = slice(1 + self._number_of_key_chunks, None) + + chunk_size = self.chunk_size + number_of_entries = self.number_of_entries + number_of_key_chunks = self._number_of_key_chunks + state_shape = self._state_shape + + @fhe.module() + class StaticKeyValueDatabaseModule: + @fhe.function({"state": "clear"}) + def reset(state): + return state + fhe.zero() + + @fhe.function({"state": "encrypted", "key": "encrypted", "value": "encrypted"}) + def insert(state, key, value): + flags = state[:, flag_slice] + + selection = fhe.zeros(number_of_entries) + + found = fhe.zero() + for i in range(number_of_entries): + is_selected = fhe.multivariate( + lambda found, flag: int(found == 0 and flag == 0) + )(found, flags[i]) + + selection[i] = is_selected + found += is_selected + + state_update = fhe.zeros(state_shape) + state_update[:, flag_slice] = selection + + selection = selection.reshape((-1, 1)) + + key_update = fhe.multivariate(lambda selection, key: selection * key)( + selection, key + ) + value_update = fhe.multivariate(lambda selection, value: selection * value)( + selection, value + ) + + state_update[:, key_slice] = key_update + state_update[:, value_slice] = value_update + + new_state = state + state_update + return fhe.refresh(new_state) + + @fhe.function({"state": "encrypted", "key": "encrypted", "value": "encrypted"}) + def replace(state, key, value): + flags = state[:, flag_slice] + keys = state[:, key_slice] + values = state[:, value_slice] + + number_of_matching_chunks = np.sum((keys - key) == 0, axis=1) + fhe.hint(number_of_matching_chunks, can_store=number_of_key_chunks) + + equal_rows = number_of_matching_chunks == number_of_key_chunks + selection = (flags * 2 + equal_rows == 3).reshape((-1, 1)) + + set_value = fhe.multivariate(lambda selection, value: selection * value)( + selection, value + ) + + kept_values = fhe.multivariate( + lambda inverse_selection, values: inverse_selection * values + )(1 - selection, values) + + new_values = kept_values + set_value + state[:, value_slice] = new_values + + return fhe.refresh(state) + + @fhe.function({"state": "encrypted", "key": "encrypted"}) + def query(state, key): + keys = state[:, key_slice] + values = state[:, value_slice] + + number_of_matching_chunks = np.sum((keys - key) == 0, axis=1) + fhe.hint(number_of_matching_chunks, can_store=number_of_key_chunks) + + selection = (number_of_matching_chunks == number_of_key_chunks).reshape((-1, 1)) + found = np.sum(selection) + fhe.hint(found, can_store=number_of_entries) + + value_selection = fhe.multivariate(lambda selection, values: selection * values)( + selection, values + ) + + value = np.sum(value_selection, axis=0) + fhe.hint(value, bit_width=chunk_size) + + return found, value + + @fhe.function({"state": "encrypted"}) + def inspect(state): + return state + + composition = fhe.Wired( + [ + # from reset + fhe.Wire(fhe.Output(reset, 0), fhe.Input(insert, 0)), + fhe.Wire(fhe.Output(reset, 0), fhe.Input(replace, 0)), + fhe.Wire(fhe.Output(reset, 0), fhe.Input(query, 0)), + fhe.Wire(fhe.Output(reset, 0), fhe.Input(inspect, 0)), + # from insert + fhe.Wire(fhe.Output(insert, 0), fhe.Input(insert, 0)), + fhe.Wire(fhe.Output(insert, 0), fhe.Input(replace, 0)), + fhe.Wire(fhe.Output(insert, 0), fhe.Input(query, 0)), + fhe.Wire(fhe.Output(insert, 0), fhe.Input(inspect, 0)), + # from replace + fhe.Wire(fhe.Output(replace, 0), fhe.Input(insert, 0)), + fhe.Wire(fhe.Output(replace, 0), fhe.Input(replace, 0)), + fhe.Wire(fhe.Output(replace, 0), fhe.Input(query, 0)), + fhe.Wire(fhe.Output(replace, 0), fhe.Input(inspect, 0)), + # from inspect + fhe.Wire(fhe.Output(inspect, 0), fhe.Input(insert, 0)), + fhe.Wire(fhe.Output(inspect, 0), fhe.Input(replace, 0)), + fhe.Wire(fhe.Output(inspect, 0), fhe.Input(query, 0)), + fhe.Wire(fhe.Output(inspect, 0), fhe.Input(inspect, 0)), + ] + ) + + sample_state = np.array( + [ + [i % 2] + self.encode_key(i).tolist() + self.encode_value(i).tolist() + for i in range(self.number_of_entries) + ] + ) + + insert_replace_inputset = [ + ( + # state + sample_state, + # key + self.encode_key(i), + # value + self.encode_value(i), + ) + for i in range(20) + ] + query_inputset = [ + ( + # state + sample_state, + # key + self.encode_key(i), + ) + for i in range(20) + ] + + return StaticKeyValueDatabaseModule.compile( + { + "reset": [sample_state], + "insert": insert_replace_inputset, + "replace": insert_replace_inputset, + "query": query_inputset, + "inspect": [sample_state], + }, + configuration, + ) + + def keygen(self, force: bool = False): + self.module.keygen(force=force) + + def initialize(self, initial_state: Optional[Union[List, np.ndarray]] = None): + if isinstance(initial_state, list): + initial_state = np.array(initial_state) + + if initial_state is None: + initial_state = np.zeros(self._state_shape, dtype=np.int64) + + if initial_state.shape != self._state_shape: + message = ( + f"Expected initial state to be of shape {self._state_shape} " + f"but it's {initial_state.shape}" + ) + raise ValueError(message) + + initial_state_clear = self.module.reset.encrypt(initial_state) + initial_state_encrypted = self.module.reset.run(initial_state_clear) + + self.state = initial_state_encrypted + + def decode_entry(self, entry: np.ndarray) -> Optional[Tuple[int, int]]: + if entry[0] == 0: + return None + + encoded_key = entry[1 : (self._number_of_key_chunks + 1)] + encoded_value = entry[(self._number_of_key_chunks + 1) :] + + return self.decode_key(encoded_key), self.decode_value(encoded_value) + + @property + def insert(self) -> fhe.Function: + return self.module.insert + + @property + def replace(self) -> fhe.Function: + return self.module.replace + + @property + def query(self) -> fhe.Function: + return self.module.query + + @property + def inspect(self) -> fhe.Function: + return self.module.inspect + + +def inspect(db: StaticKeyValueDatabase): + encrypted_state = db.inspect.run(db.state) + clear_state = db.inspect.decrypt(encrypted_state) + print(clear_state) + + +def insert(db: StaticKeyValueDatabase, key: int, value: int): + encoded_key, encoded_value = db.encode_key(key), db.encode_value(value) + _, encrypted_key, encoded_value = db.insert.encrypt(None, encoded_key, encoded_value) + + print() + + print("Inserting...") + start = time.time() + db.state = db.insert.run(db.state, encrypted_key, encoded_value) + end = time.time() + print(f"(took {end - start:.3f} seconds)") + + +def replace(db: StaticKeyValueDatabase, key: int, value: int): + encoded_key, encoded_value = db.encode_key(key), db.encode_value(value) + _, encrypted_key, encoded_value = db.replace.encrypt(None, encoded_key, encoded_value) + + print() + + print("Replacing...") + start = time.time() + db.state = db.replace.run(db.state, encrypted_key, encoded_value) + end = time.time() + print(f"(took {end - start:.3f} seconds)") + + +def query(db: StaticKeyValueDatabase, key: int) -> Optional[int]: + encoded_key = db.encode_key(key) + _, encrypted_key = db.query.encrypt(None, encoded_key) + + print() + + print("Querying...") + start = time.time() + encrypted_found, encrypted_value = db.query.run(db.state, encrypted_key) + end = time.time() + print(f"(took {end - start:.3f} seconds)") + + found, value = db.query.decrypt(encrypted_found, encrypted_value) + if not found: + return None + + return db.decode_value(value) + + +if __name__ == "__main__": + print("Compiling...") + start = time.time() + db = StaticKeyValueDatabase(number_of_entries=10) + end = time.time() + print(f"(took {end - start:.3f} seconds)") + + print() + + print("Generating keys...") + start = time.time() + db.keygen() + end = time.time() + print(f"(took {end - start:.3f} seconds)") + + print() + + print("Initializing...") + start = time.time() + db.initialize() + end = time.time() + print(f"(took {end - start:.3f} seconds)") + + # Test: Insert/Query + insert(db, 3, 4) + assert query(db, 3) == 4 + + replace(db, 3, 1) + assert query(db, 3) == 1 + + # Test: Insert/Query + insert(db, 25, 40) + assert query(db, 25) == 40 + + # Test: Query Not Found + assert query(db, 4) is None + + # Test: Replace/Query + replace(db, 3, 5) + assert query(db, 3) == 5 + + # Define lower/upper bounds for the key + minimum_key = 1 + maximum_key = 2**db.key_size - 1 + # Define lower/upper bounds for the value + minimum_value = 1 + maximum_value = 2**db.value_size - 1 + + # Test: Insert/Replace/Query Bounds + # Insert (key: minimum_key , value: minimum_value) into the database + insert(db, minimum_key, minimum_value) + + # Query the database for the key=minimum_key + # The value minimum_value should be returned + assert query(db, minimum_key) == minimum_value + + # Insert (key: maximum_key , value: maximum_value) into the database + insert(db, maximum_key, maximum_value) + + # Query the database for the key=maximum_key + # The value maximum_value should be returned + assert query(db, maximum_key) == maximum_value + + # Replace the value of key=minimum_key with maximum_value + replace(db, minimum_key, maximum_value) + + # Query the database for the key=minimum_key + # The value maximum_value should be returned + assert query(db, minimum_key) == maximum_value + + # Replace the value of key=maximum_key with minimum_value + replace(db, maximum_key, minimum_value) + + # Query the database for the key=maximum_key + # The value minimum_value should be returned + assert query(db, maximum_key) == minimum_value