Skip to content

Commit

Permalink
chore(frontend-python): simplify levenshtein example
Browse files Browse the repository at this point in the history
  • Loading branch information
rudy-6-4 committed Jul 30, 2024
1 parent eda323a commit f01c7a8
Showing 1 changed file with 45 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -140,44 +140,42 @@ def _encode_and_encrypt_strings(self, a: str, b: str) -> tuple:
a_as_int = self.alphabet.encode(a)
b_as_int = self.alphabet.encode(b)

a_enc = tuple(self.module.equal.encrypt(ai, None)[0] for ai in a_as_int) # type: ignore
b_enc = tuple(self.module.equal.encrypt(None, bi)[1] for bi in b_as_int) # type: ignore
a_enc = tuple(self.module.char_cost_align.encrypt(ai, None)[0] for ai in a_as_int) # type: ignore
b_enc = tuple(self.module.char_cost_align.encrypt(None, bi)[1] for bi in b_as_int) # type: ignore

return a_enc, b_enc

def _compile_module(self, args):
"""Compile the FHE module."""
assert len(self.alphabet.mapping_to_int) > 0, "Mapping not defined"

inputset_equal = [
inputset_char_cost_align = [
(
self.alphabet.random_pick_in_values(),
self.alphabet.random_pick_in_values(),
)
for _ in range(1000)
]
inputset_mix = [
inputset_min_cost = [
(
np.random.randint(2),
np.random.randint(args.max_string_length),
np.random.randint(args.max_string_length),
np.random.randint(args.max_string_length),
np.random.randint(args.max_string_length),
)
for _ in range(1000)
]

# pylint: disable-next=no-member
self.module = LevenshsteinModule.compile(
{
"equal": inputset_equal,
"mix": inputset_mix,
"char_cost_align": inputset_char_cost_align,
"min_cost": inputset_min_cost,
"constant": [i for i in range(len(self.alphabet.mapping_to_int))],
},
show_mlir=args.show_mlir,
p_error=10**-20,
show_optimizer=args.show_optimizer,
comparison_strategy_preference=fhe.ComparisonStrategy.ONE_TLU_PROMOTED,
min_max_strategy_preference=fhe.MinMaxStrategy.ONE_TLU_PROMOTED,
)

Expand All @@ -189,7 +187,7 @@ def _compute_in_simulation(self, list_patterns: list):
a_as_int = self.alphabet.encode(a)
b_as_int = self.alphabet.encode(b)

l1_simulate = levenshtein_simulate(self.module, a_as_int, b_as_int)
l1_simulate = levenshtein_fhe(self.module, a_as_int, b_as_int)
l1_clear = levenshtein_clear(a_as_int, b_as_int)

assert l1_simulate == l1_clear, f" {l1_simulate=} and {l1_clear=} are different"
Expand All @@ -209,7 +207,7 @@ def _compute_in_fhe(self, list_patterns: list, show_distance: bool = False):
l1_fhe_enc = levenshtein_fhe(self.module, a_enc, b_enc)
time_end = time.time()

l1_fhe = self.module.mix.decrypt(l1_fhe_enc) # type: ignore
l1_fhe = self.module.min_cost.decrypt(l1_fhe_enc) # type: ignore

l1_clear = levenshtein_clear(a, b)

Expand All @@ -225,50 +223,45 @@ def _compute_in_fhe(self, list_patterns: list, show_distance: bool = False):
@fhe.module()
class LevenshsteinModule:
@fhe.function({"x": "encrypted", "y": "encrypted"})
def equal(x, y):
"""Assert equality between two chars of the alphabet."""
return x == y
def char_cost_align(x, y):
"""Cost of having two chars of the alphabet aligned."""
return x != y

@fhe.function({"x": "clear"})
def constant(x):
return fhe.zero() + x

@fhe.function(
{
"is_equal": "encrypted",
"if_equal": "encrypted",
"case_1": "encrypted",
"case_2": "encrypted",
"case_3": "encrypted",
"char_cost_align": "encrypted",
"align": "encrypted",
"insertion_1": "encrypted",
"insertion_2": "encrypted",
}
)
def mix(is_equal, if_equal, case_1, case_2, case_3):
"""Compute the min of (case_1, case_2, case_3), and then return `if_equal` if `is_equal` is
True, or the min in the other case."""
min_12 = np.minimum(case_1, case_2)
min_123 = np.minimum(min_12, case_3)

return fhe.if_then_else(is_equal, if_equal, 1 + min_123)

# There is a single output in mix: it can go to
# - input 1 of mix
# - input 2 of mix
# - input 3 of mix
# - input 4 of mix
def min_cost(char_cost_align, align, insertion_1, insertion_2):
"""Compute `min(align + char_cost_align, insertion_1 + 1, insertion_2 + 1)`."""
align = align + char_cost_align
insertion_min = 1 + np.minimum(insertion_1, insertion_2)
result = np.minimum(align, insertion_min)
return fhe.refresh(result) # to have at least one operation reducing the noise

# There is a single output in min_cost: it can go to
# - input 1 of min_cost
# - input 2 of min_cost
# - input 3 of min_cost
# or just be the final output
#
# There is a single output of equal, it goes to input 0 of mix
# There is a single output of char_cost_align, it goes to input 0 of min_cost
composition = fhe.Wired(
[
fhe.Wire(fhe.AllOutputs(equal), fhe.Input(mix, 0)),
fhe.Wire(fhe.AllOutputs(mix), fhe.Input(mix, 1)),
fhe.Wire(fhe.AllOutputs(mix), fhe.Input(mix, 2)),
fhe.Wire(fhe.AllOutputs(mix), fhe.Input(mix, 3)),
fhe.Wire(fhe.AllOutputs(mix), fhe.Input(mix, 4)),
fhe.Wire(fhe.AllOutputs(constant), fhe.Input(mix, 1)),
fhe.Wire(fhe.AllOutputs(constant), fhe.Input(mix, 2)),
fhe.Wire(fhe.AllOutputs(constant), fhe.Input(mix, 3)),
fhe.Wire(fhe.AllOutputs(constant), fhe.Input(mix, 4)),
fhe.Wire(fhe.AllOutputs(char_cost_align), fhe.Input(min_cost, 0)),
fhe.Wire(fhe.AllOutputs(min_cost), fhe.Input(min_cost, 1)),
fhe.Wire(fhe.AllOutputs(min_cost), fhe.Input(min_cost, 2)),
fhe.Wire(fhe.AllOutputs(min_cost), fhe.Input(min_cost, 3)),
fhe.Wire(fhe.AllOutputs(constant), fhe.Input(min_cost, 1)),
fhe.Wire(fhe.AllOutputs(constant), fhe.Input(min_cost, 2)),
fhe.Wire(fhe.AllOutputs(constant), fhe.Input(min_cost, 3)),
]
)

Expand All @@ -281,34 +274,13 @@ def levenshtein_clear(x: str, y: str):
if len(y) == 0:
return len(x)

if x[0] == y[0]:
return levenshtein_clear(x[1:], y[1:])

case_1 = levenshtein_clear(x[1:], y)
case_2 = levenshtein_clear(x, y[1:])
case_3 = levenshtein_clear(x[1:], y[1:])

return 1 + min(case_1, case_2, case_3)


@lru_cache
def levenshtein_simulate(module: fhe.module, x: str, y: str):
"""Compute the distance in simulation."""
if len(x) == 0:
return len(y)
if len(y) == 0:
return len(x)

if_equal = levenshtein_simulate(module, x[1:], y[1:])
case_1 = levenshtein_simulate(module, x[1:], y)
case_2 = levenshtein_simulate(module, x, y[1:])
case_3 = if_equal

is_equal = module.equal(x[0], y[0]) # type: ignore
returned_value = module.mix(is_equal, if_equal, case_1, case_2, case_3) # type: ignore

return returned_value
align = levenshtein_clear(x[1:], y[1:])
insertion_1 = levenshtein_clear(x[1:], y)
insertion_2 = levenshtein_clear(x, y[1:])

char_cost_align = int(x[0] != y[0])
cost_insert = 1
return min(align + char_cost_align, insertion_1 + cost_insert, insertion_2 + cost_insert)

@lru_cache
def levenshtein_fhe(module: fhe.module, x: str, y: str):
Expand All @@ -318,13 +290,12 @@ def levenshtein_fhe(module: fhe.module, x: str, y: str):
if len(y) == 0:
return module.constant.run(module.constant.encrypt(len(x))) # type: ignore

if_equal = levenshtein_fhe(module, x[1:], y[1:])
case_1 = levenshtein_fhe(module, x[1:], y)
case_2 = levenshtein_fhe(module, x, y[1:])
case_3 = if_equal
align = levenshtein_fhe(module, x[1:], y[1:])
insertion_1 = levenshtein_fhe(module, x[1:], y)
insertion_2 = levenshtein_fhe(module, x, y[1:])

is_equal = module.equal.run(x[0], y[0]) # type: ignore
returned_value = module.mix.run(is_equal, if_equal, case_1, case_2, case_3) # type: ignore
char_cost_align = module.char_cost_align.run(x[0], y[0]) # type: ignore
returned_value = module.min_cost.run(char_cost_align, align, insertion_1, insertion_2) # type: ignore

return returned_value

Expand Down

0 comments on commit f01c7a8

Please sign in to comment.