diff --git a/frontends/concrete-python/examples/levenshtein_distance/levenshtein_distance.py b/frontends/concrete-python/examples/levenshtein_distance/levenshtein_distance.py index b3e3e8787e..dd80004b00 100644 --- a/frontends/concrete-python/examples/levenshtein_distance/levenshtein_distance.py +++ b/frontends/concrete-python/examples/levenshtein_distance/levenshtein_distance.py @@ -140,8 +140,8 @@ def _encode_and_encrypt_strings(self, a: str, b: str) -> tuple: a_as_int = self.alphabet.encode(a) b_as_int = self.alphabet.encode(b) - a_enc = tuple(self.module.equal.encrypt(ai, None)[0] for ai in a_as_int) # type: ignore - b_enc = tuple(self.module.equal.encrypt(None, bi)[1] for bi in b_as_int) # type: ignore + a_enc = tuple(self.module.char_cost_align.encrypt(ai, None)[0] for ai in a_as_int) # type: ignore + b_enc = tuple(self.module.char_cost_align.encrypt(None, bi)[1] for bi in b_as_int) # type: ignore return a_enc, b_enc @@ -149,20 +149,19 @@ def _compile_module(self, args): """Compile the FHE module.""" assert len(self.alphabet.mapping_to_int) > 0, "Mapping not defined" - inputset_equal = [ + inputset_char_cost_align = [ ( self.alphabet.random_pick_in_values(), self.alphabet.random_pick_in_values(), ) for _ in range(1000) ] - inputset_mix = [ + inputset_min_cost = [ ( np.random.randint(2), np.random.randint(args.max_string_length), np.random.randint(args.max_string_length), np.random.randint(args.max_string_length), - np.random.randint(args.max_string_length), ) for _ in range(1000) ] @@ -170,14 +169,13 @@ def _compile_module(self, args): # pylint: disable-next=no-member self.module = LevenshsteinModule.compile( { - "equal": inputset_equal, - "mix": inputset_mix, + "char_cost_align": inputset_char_cost_align, + "min_cost": inputset_min_cost, "constant": [i for i in range(len(self.alphabet.mapping_to_int))], }, show_mlir=args.show_mlir, p_error=10**-20, show_optimizer=args.show_optimizer, - comparison_strategy_preference=fhe.ComparisonStrategy.ONE_TLU_PROMOTED, 
min_max_strategy_preference=fhe.MinMaxStrategy.ONE_TLU_PROMOTED, ) @@ -189,7 +187,7 @@ def _compute_in_simulation(self, list_patterns: list): a_as_int = self.alphabet.encode(a) b_as_int = self.alphabet.encode(b) - l1_simulate = levenshtein_simulate(self.module, a_as_int, b_as_int) + l1_simulate = levenshtein_fhe(self.module, a_as_int, b_as_int) l1_clear = levenshtein_clear(a_as_int, b_as_int) assert l1_simulate == l1_clear, f" {l1_simulate=} and {l1_clear=} are different" @@ -209,7 +207,7 @@ def _compute_in_fhe(self, list_patterns: list, show_distance: bool = False): l1_fhe_enc = levenshtein_fhe(self.module, a_enc, b_enc) time_end = time.time() - l1_fhe = self.module.mix.decrypt(l1_fhe_enc) # type: ignore + l1_fhe = self.module.min_cost.decrypt(l1_fhe_enc) # type: ignore l1_clear = levenshtein_clear(a, b) @@ -225,9 +223,9 @@ def _compute_in_fhe(self, list_patterns: list, show_distance: bool = False): @fhe.module() class LevenshsteinModule: @fhe.function({"x": "encrypted", "y": "encrypted"}) - def equal(x, y): - """Assert equality between two chars of the alphabet.""" - return x == y + def char_cost_align(x, y): + """Cost of having two chars of the alphabet aligned.""" + return x != y @fhe.function({"x": "clear"}) def constant(x): @@ -235,40 +233,35 @@ def constant(x): @fhe.function( { - "is_equal": "encrypted", - "if_equal": "encrypted", - "case_1": "encrypted", - "case_2": "encrypted", - "case_3": "encrypted", + "char_cost_align": "encrypted", + "align": "encrypted", + "insertion_1": "encrypted", + "insertion_2": "encrypted", } ) - def mix(is_equal, if_equal, case_1, case_2, case_3): - """Compute the min of (case_1, case_2, case_3), and then return `if_equal` if `is_equal` is - True, or the min in the other case.""" - min_12 = np.minimum(case_1, case_2) - min_123 = np.minimum(min_12, case_3) - - return fhe.if_then_else(is_equal, if_equal, 1 + min_123) - - # There is a single output in mix: it can go to - # - input 1 of mix - # - input 2 of mix - # - input 3 of mix 
- # - input 4 of mix + def min_cost(char_cost_align, align, insertion_1, insertion_2): + """Compute `min(align + char_cost_align, insertion_1 + 1, insertion_2 + 1)`.""" + align = align + char_cost_align + insertion_min = 1 + np.minimum(insertion_1, insertion_2) + result = np.minimum(align, insertion_min) + return fhe.refresh(result) # to have at least one operation reducing the noise + + # There is a single output in min_cost: it can go to + # - input 1 of min_cost + # - input 2 of min_cost + # - input 3 of min_cost # or just be the final output # - # There is a single output of equal, it goes to input 0 of mix + # There is a single output of char_cost_align, it goes to input 0 of min_cost composition = fhe.Wired( [ - fhe.Wire(fhe.AllOutputs(equal), fhe.Input(mix, 0)), - fhe.Wire(fhe.AllOutputs(mix), fhe.Input(mix, 1)), - fhe.Wire(fhe.AllOutputs(mix), fhe.Input(mix, 2)), - fhe.Wire(fhe.AllOutputs(mix), fhe.Input(mix, 3)), - fhe.Wire(fhe.AllOutputs(mix), fhe.Input(mix, 4)), - fhe.Wire(fhe.AllOutputs(constant), fhe.Input(mix, 1)), - fhe.Wire(fhe.AllOutputs(constant), fhe.Input(mix, 2)), - fhe.Wire(fhe.AllOutputs(constant), fhe.Input(mix, 3)), - fhe.Wire(fhe.AllOutputs(constant), fhe.Input(mix, 4)), + fhe.Wire(fhe.AllOutputs(char_cost_align), fhe.Input(min_cost, 0)), + fhe.Wire(fhe.AllOutputs(min_cost), fhe.Input(min_cost, 1)), + fhe.Wire(fhe.AllOutputs(min_cost), fhe.Input(min_cost, 2)), + fhe.Wire(fhe.AllOutputs(min_cost), fhe.Input(min_cost, 3)), + fhe.Wire(fhe.AllOutputs(constant), fhe.Input(min_cost, 1)), + fhe.Wire(fhe.AllOutputs(constant), fhe.Input(min_cost, 2)), + fhe.Wire(fhe.AllOutputs(constant), fhe.Input(min_cost, 3)), ] ) @@ -281,34 +274,13 @@ def levenshtein_clear(x: str, y: str): if len(y) == 0: return len(x) - if x[0] == y[0]: - return levenshtein_clear(x[1:], y[1:]) - - case_1 = levenshtein_clear(x[1:], y) - case_2 = levenshtein_clear(x, y[1:]) - case_3 = levenshtein_clear(x[1:], y[1:]) - - return 1 + min(case_1, case_2, case_3) - - -@lru_cache -def 
levenshtein_simulate(module: fhe.module, x: str, y: str): - """Compute the distance in simulation.""" - if len(x) == 0: - return len(y) - if len(y) == 0: - return len(x) - - if_equal = levenshtein_simulate(module, x[1:], y[1:]) - case_1 = levenshtein_simulate(module, x[1:], y) - case_2 = levenshtein_simulate(module, x, y[1:]) - case_3 = if_equal - - is_equal = module.equal(x[0], y[0]) # type: ignore - returned_value = module.mix(is_equal, if_equal, case_1, case_2, case_3) # type: ignore - - return returned_value + align = levenshtein_clear(x[1:], y[1:]) + insertion_1 = levenshtein_clear(x[1:], y) + insertion_2 = levenshtein_clear(x, y[1:]) + char_cost_align = int(x[0] != y[0]) + cost_insert = 1 + return min(align + char_cost_align, insertion_1 + cost_insert, insertion_2 + cost_insert) @lru_cache def levenshtein_fhe(module: fhe.module, x: str, y: str): @@ -318,13 +290,12 @@ def levenshtein_fhe(module: fhe.module, x: str, y: str): if len(y) == 0: return module.constant.run(module.constant.encrypt(len(x))) # type: ignore - if_equal = levenshtein_fhe(module, x[1:], y[1:]) - case_1 = levenshtein_fhe(module, x[1:], y) - case_2 = levenshtein_fhe(module, x, y[1:]) - case_3 = if_equal + align = levenshtein_fhe(module, x[1:], y[1:]) + insertion_1 = levenshtein_fhe(module, x[1:], y) + insertion_2 = levenshtein_fhe(module, x, y[1:]) - is_equal = module.equal.run(x[0], y[0]) # type: ignore - returned_value = module.mix.run(is_equal, if_equal, case_1, case_2, case_3) # type: ignore + char_cost_align = module.char_cost_align.run(x[0], y[0]) # type: ignore + returned_value = module.min_cost.run(char_cost_align, align, insertion_1, insertion_2) # type: ignore return returned_value