diff --git a/frontends/concrete-python/examples/levenshtein_distance/levenshtein_distance.py b/frontends/concrete-python/examples/levenshtein_distance/levenshtein_distance.py index 8975a8cc67..7137aa79a7 100644 --- a/frontends/concrete-python/examples/levenshtein_distance/levenshtein_distance.py +++ b/frontends/concrete-python/examples/levenshtein_distance/levenshtein_distance.py @@ -1,6 +1,7 @@ # Computing Levenstein distance between strings, https://en.wikipedia.org/wiki/Levenshtein_distance import time +import argparse from functools import lru_cache import numpy @@ -16,7 +17,7 @@ class MyModule: @fhe.function({"x": "encrypted", "y": "encrypted"}) def equal(x, y): - return x == y + return fhe.univariate(lambda x: x == 0)(x - y) @fhe.function( { @@ -69,29 +70,6 @@ def map_string_to_int(s): return tuple(ord(si) for si in s) -# Compilation -inputset_equal = [(random_letter_as_int(), random_letter_as_int()) for _ in range(1000)] -inputset_mix = [ - ( - numpy.random.randint(2), - numpy.random.randint(max_string_length), - numpy.random.randint(max_string_length), - numpy.random.randint(max_string_length), - numpy.random.randint(max_string_length), - ) - for _ in range(100) -] - -my_module = MyModule.compile( - {"equal": inputset_equal, "mix": inputset_mix}, - show_mlir=True, - p_error=10**-20, - show_optimizer=True, - comparison_strategy_preference=fhe.ComparisonStrategy.ONE_TLU_PROMOTED, - min_max_strategy_preference=fhe.MinMaxStrategy.ONE_TLU_PROMOTED, -) - - # Function in clear, for reference and comparison @lru_cache def levenshtein_clear(x, y): @@ -112,15 +90,15 @@ def levenshtein_clear(x, y): # Function in FHE-simulate, to debug @lru_cache -def levenshtein_simulate(x, y): +def levenshtein_simulate(my_module, x, y): if len(x) == 0: return len(y) if len(y) == 0: return len(x) - if_equal = levenshtein_simulate(x[1:], y[1:]) - case_1 = levenshtein_simulate(x[1:], y) - case_2 = levenshtein_simulate(x, y[1:]) + if_equal = levenshtein_simulate(my_module, x[1:], y[1:]) + case_1 = levenshtein_simulate(my_module, x[1:], y) + case_2 = levenshtein_simulate(my_module, x, y[1:]) case_3 = if_equal is_equal = my_module.equal(x[0], y[0]) @@ -131,7 +109,7 @@ def levenshtein_simulate(x, y): # Function in FHE @lru_cache -def levenshtein_fhe(x, y): +def levenshtein_fhe(my_module, x, y): if len(x) == 0: # In clear, that's return len(y) return my_module.mix.encrypt(None, len(y), None, None, None)[1] @@ -139,9 +117,9 @@ def levenshtein_fhe(x, y): # In clear, that's return len(x) return my_module.mix.encrypt(None, len(x), None, None, None)[1] - if_equal = levenshtein_fhe(x[1:], y[1:]) - case_1 = levenshtein_fhe(x[1:], y) - case_2 = levenshtein_fhe(x, y[1:]) + if_equal = levenshtein_fhe(my_module, x[1:], y[1:]) + case_1 = levenshtein_fhe(my_module, x[1:], y) + case_2 = levenshtein_fhe(my_module, x, y[1:]) case_3 = if_equal # In FHE @@ -151,70 +129,134 @@ def levenshtein_fhe(x, y): return returned_value -# Random patterns of different lengths -list_patterns = [ - ("", ""), - ("", "a"), - ("b", ""), - ("a", "a"), - ("a", "b"), -] - -for length_1 in range(max_string_length + 1): - for length_2 in range(max_string_length + 1): - list_patterns += [ - ( - random_string(length_1), - random_string(length_2), - ) - for _ in range(1) - ] +# Manage user args +def manage_args(): + parser = argparse.ArgumentParser(description="Levenshtein distance in Concrete.") + parser.add_argument( + "--show_mlir", + dest="show_mlir", + action="store_true", + help="Show the MLIR", + ) + args = parser.parse_args() + return args + + +def compile_module(args): + # Compilation + inputset_equal = [(random_letter_as_int(), random_letter_as_int()) for _ in range(1000)] + inputset_mix = [ + ( + numpy.random.randint(2), + numpy.random.randint(max_string_length), + numpy.random.randint(max_string_length), + numpy.random.randint(max_string_length), + numpy.random.randint(max_string_length), + ) + for _ in range(100) + ] + + my_module = MyModule.compile( + {"equal": inputset_equal, "mix": inputset_mix}, + show_mlir=args.show_mlir, + p_error=10**-20, + show_optimizer=True, + comparison_strategy_preference=fhe.ComparisonStrategy.ONE_TLU_PROMOTED, + min_max_strategy_preference=fhe.MinMaxStrategy.ONE_TLU_PROMOTED, + ) + + return my_module + + +def prepare_random_patterns(): + # Random patterns of different lengths + list_patterns = [ + ("", ""), + ("", "a"), + ("b", ""), + ("a", "a"), + ("a", "b"), + ] + + for length_1 in range(max_string_length + 1): + for length_2 in range(max_string_length + 1): + list_patterns += [ + ( + random_string(length_1), + random_string(length_2), + ) + for _ in range(1) + ] -# Checks in simulation -print("Computations in simulation\n") + return list_patterns -for a, b in list_patterns: - print(f" Computing Levenshtein between strings '{a}' and '{b}'", end="") +def compute_in_simulation(my_module, list_patterns): - assert len(a) <= max_string_length - assert len(b) <= max_string_length + # Checks in simulation + print("Computations in simulation\n") - a_as_int = map_string_to_int(a) - b_as_int = map_string_to_int(b) + for a, b in list_patterns: - l1_simulate = levenshtein_simulate(a_as_int, b_as_int) - l1_clear = levenshtein_clear(a_as_int, b_as_int) + print(f" Computing Levenshtein between strings '{a}' and '{b}'", end="") - assert l1_simulate == l1_clear, f" {l1_simulate=} and {l1_clear=} are different" - print(" - OK") + assert len(a) <= max_string_length + assert len(b) <= max_string_length -# Key generation -my_module.keygen() + a_as_int = map_string_to_int(a) + b_as_int = map_string_to_int(b) -# Checks in FHE -print("\nComputations in FHE\n") + l1_simulate = levenshtein_simulate(my_module, a_as_int, b_as_int) + l1_clear = levenshtein_clear(a_as_int, b_as_int) -for a, b in list_patterns: + assert l1_simulate == l1_clear, f" {l1_simulate=} and {l1_clear=} are different" + print(" - OK") - print(f" Computing Levenshtein between strings '{a}' and '{b}'", end="") - assert len(a) <= max_string_length - assert len(b) <= max_string_length +def compute_in_fhe(my_module, list_patterns): + # Key generation + my_module.keygen() - a_as_int = map_string_to_int(a) - b_as_int = map_string_to_int(b) + # Checks in FHE + print("\nComputations in FHE\n") - a_enc = tuple(my_module.equal.encrypt(ai, None)[0] for ai in a_as_int) - b_enc = tuple(my_module.equal.encrypt(None, bi)[1] for bi in b_as_int) + for a, b in list_patterns: - time_begin = time.time() - l1_fhe_enc = levenshtein_fhe(a_enc, b_enc) - time_end = time.time() + print(f" Computing Levenshtein between strings '{a}' and '{b}'", end="") - l1_fhe = my_module.mix.decrypt(l1_fhe_enc) + assert len(a) <= max_string_length + assert len(b) <= max_string_length + + a_as_int = map_string_to_int(a) + b_as_int = map_string_to_int(b) + + a_enc = tuple(my_module.equal.encrypt(ai, None)[0] for ai in a_as_int) + b_enc = tuple(my_module.equal.encrypt(None, bi)[1] for bi in b_as_int) + + time_begin = time.time() + l1_fhe_enc = levenshtein_fhe(my_module, a_enc, b_enc) + time_end = time.time() + + l1_fhe = my_module.mix.decrypt(l1_fhe_enc) + + l1_clear = levenshtein_clear(a, b) + + assert l1_fhe == l1_clear, f" {l1_fhe=} and {l1_clear=} are different" + print(f" - OK in {time_end - time_begin:.2f} seconds") + + +def main(): + print() + + # Options by the user + args = manage_args() + + # + my_module = compile_module(args) + list_patterns = prepare_random_patterns() + compute_in_simulation(my_module, list_patterns) + compute_in_fhe(my_module, list_patterns) - l1_clear = levenshtein_clear(a, b) - assert l1_fhe == l1_clear, f" {l1_fhe=} and {l1_clear=} are different" - print(f" - OK in {time_end - time_begin:.2f} seconds") +if __name__ == "__main__": + main()