Skip to content

Commit

Permalink
docs(frontend): make functions and argparse
Browse files Browse the repository at this point in the history
  • Loading branch information
bcm-at-zama committed Jun 28, 2024
1 parent f4ff88a commit 9de185e
Showing 1 changed file with 123 additions and 81 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Computing Levenshtein distance between strings, https://en.wikipedia.org/wiki/Levenshtein_distance

import time
import argparse
from functools import lru_cache

import numpy
Expand All @@ -16,7 +17,7 @@
class MyModule:
@fhe.function({"x": "encrypted", "y": "encrypted"})
def equal(x, y):
return x == y
return fhe.univariate(lambda x: x == 0)(x - y)

@fhe.function(
{
Expand Down Expand Up @@ -69,29 +70,6 @@ def map_string_to_int(s):
return tuple(ord(si) for si in s)


# Compilation
inputset_equal = [(random_letter_as_int(), random_letter_as_int()) for _ in range(1000)]
inputset_mix = [
(
numpy.random.randint(2),
numpy.random.randint(max_string_length),
numpy.random.randint(max_string_length),
numpy.random.randint(max_string_length),
numpy.random.randint(max_string_length),
)
for _ in range(100)
]

my_module = MyModule.compile(
{"equal": inputset_equal, "mix": inputset_mix},
show_mlir=True,
p_error=10**-20,
show_optimizer=True,
comparison_strategy_preference=fhe.ComparisonStrategy.ONE_TLU_PROMOTED,
min_max_strategy_preference=fhe.MinMaxStrategy.ONE_TLU_PROMOTED,
)


# Function in clear, for reference and comparison
@lru_cache
def levenshtein_clear(x, y):
Expand All @@ -112,15 +90,15 @@ def levenshtein_clear(x, y):

# Function in FHE-simulate, to debug
@lru_cache
def levenshtein_simulate(x, y):
def levenshtein_simulate(my_module, x, y):
if len(x) == 0:
return len(y)
if len(y) == 0:
return len(x)

if_equal = levenshtein_simulate(x[1:], y[1:])
case_1 = levenshtein_simulate(x[1:], y)
case_2 = levenshtein_simulate(x, y[1:])
if_equal = levenshtein_simulate(my_module, x[1:], y[1:])
case_1 = levenshtein_simulate(my_module, x[1:], y)
case_2 = levenshtein_simulate(my_module, x, y[1:])
case_3 = if_equal

is_equal = my_module.equal(x[0], y[0])
Expand All @@ -131,17 +109,17 @@ def levenshtein_simulate(x, y):

# Function in FHE
@lru_cache
def levenshtein_fhe(x, y):
def levenshtein_fhe(my_module, x, y):
if len(x) == 0:
# In clear, that's return len(y)
return my_module.mix.encrypt(None, len(y), None, None, None)[1]
if len(y) == 0:
# In clear, that's return len(x)
return my_module.mix.encrypt(None, len(x), None, None, None)[1]

if_equal = levenshtein_fhe(x[1:], y[1:])
case_1 = levenshtein_fhe(x[1:], y)
case_2 = levenshtein_fhe(x, y[1:])
if_equal = levenshtein_fhe(my_module, x[1:], y[1:])
case_1 = levenshtein_fhe(my_module, x[1:], y)
case_2 = levenshtein_fhe(my_module, x, y[1:])
case_3 = if_equal

# In FHE
Expand All @@ -151,70 +129,134 @@ def levenshtein_fhe(x, y):
return returned_value


# Random patterns of different lengths
list_patterns = [
("", ""),
("", "a"),
("b", ""),
("a", "a"),
("a", "b"),
]

for length_1 in range(max_string_length + 1):
for length_2 in range(max_string_length + 1):
list_patterns += [
(
random_string(length_1),
random_string(length_2),
)
for _ in range(1)
]
# Manage user args
def manage_args(argv=None):
    """Parse the command-line options of the Levenshtein example.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads the process command line, which is
            what the original zero-argument call did.

    Returns:
        argparse.Namespace: carries ``show_mlir``, a bool that is True only
        when the ``--show_mlir`` flag was passed.
    """
    parser = argparse.ArgumentParser(description="Levenshtein distance in Concrete.")
    parser.add_argument(
        "--show_mlir",
        dest="show_mlir",
        action="store_true",
        help="Show the MLIR",
    )
    return parser.parse_args(argv)


def compile_module(args):
    """Compile ``MyModule`` against randomly drawn input sets.

    Args:
        args: Parsed command-line options; only ``args.show_mlir`` is read.

    Returns:
        The object returned by ``MyModule.compile``.
    """
    # Compilation
    # Input set for `equal`: 1000 pairs of random letters encoded as integers
    inputset_equal = [(random_letter_as_int(), random_letter_as_int()) for _ in range(1000)]

    # Input set for `mix`: 100 tuples, one value drawn with randint(2) followed by
    # four values drawn with randint(max_string_length)
    # NOTE(review): numpy.random.randint excludes its upper bound, so
    # max_string_length itself is never sampled -- confirm this range covers
    # every value `mix` receives at runtime
    inputset_mix = [
        (
            numpy.random.randint(2),
            numpy.random.randint(max_string_length),
            numpy.random.randint(max_string_length),
            numpy.random.randint(max_string_length),
            numpy.random.randint(max_string_length),
        )
        for _ in range(100)
    ]

    my_module = MyModule.compile(
        {"equal": inputset_equal, "mix": inputset_mix},
        show_mlir=args.show_mlir,
        p_error=10**-20,
        show_optimizer=True,
        comparison_strategy_preference=fhe.ComparisonStrategy.ONE_TLU_PROMOTED,
        min_max_strategy_preference=fhe.MinMaxStrategy.ONE_TLU_PROMOTED,
    )

    return my_module


def prepare_random_patterns():
# Random patterns of different lengths
list_patterns = [
("", ""),
("", "a"),
("b", ""),
("a", "a"),
("a", "b"),
]

for length_1 in range(max_string_length + 1):
for length_2 in range(max_string_length + 1):
list_patterns += [
(
random_string(length_1),
random_string(length_2),
)
for _ in range(1)
]

# Checks in simulation
print("Computations in simulation\n")
return list_patterns

for a, b in list_patterns:

print(f" Computing Levenshtein between strings '{a}' and '{b}'", end="")
def compute_in_simulation(my_module, list_patterns):

assert len(a) <= max_string_length
assert len(b) <= max_string_length
# Checks in simulation
print("Computations in simulation\n")

a_as_int = map_string_to_int(a)
b_as_int = map_string_to_int(b)
for a, b in list_patterns:

l1_simulate = levenshtein_simulate(a_as_int, b_as_int)
l1_clear = levenshtein_clear(a_as_int, b_as_int)
print(f" Computing Levenshtein between strings '{a}' and '{b}'", end="")

assert l1_simulate == l1_clear, f" {l1_simulate=} and {l1_clear=} are different"
print(" - OK")
assert len(a) <= max_string_length
assert len(b) <= max_string_length

# Key generation
my_module.keygen()
a_as_int = map_string_to_int(a)
b_as_int = map_string_to_int(b)

# Checks in FHE
print("\nComputations in FHE\n")
l1_simulate = levenshtein_simulate(my_module, a_as_int, b_as_int)
l1_clear = levenshtein_clear(a_as_int, b_as_int)

for a, b in list_patterns:
assert l1_simulate == l1_clear, f" {l1_simulate=} and {l1_clear=} are different"
print(" - OK")

print(f" Computing Levenshtein between strings '{a}' and '{b}'", end="")

assert len(a) <= max_string_length
assert len(b) <= max_string_length
def compute_in_fhe(my_module, list_patterns):
# Key generation
my_module.keygen()

a_as_int = map_string_to_int(a)
b_as_int = map_string_to_int(b)
# Checks in FHE
print("\nComputations in FHE\n")

a_enc = tuple(my_module.equal.encrypt(ai, None)[0] for ai in a_as_int)
b_enc = tuple(my_module.equal.encrypt(None, bi)[1] for bi in b_as_int)
for a, b in list_patterns:

time_begin = time.time()
l1_fhe_enc = levenshtein_fhe(a_enc, b_enc)
time_end = time.time()
print(f" Computing Levenshtein between strings '{a}' and '{b}'", end="")

l1_fhe = my_module.mix.decrypt(l1_fhe_enc)
assert len(a) <= max_string_length
assert len(b) <= max_string_length

a_as_int = map_string_to_int(a)
b_as_int = map_string_to_int(b)

a_enc = tuple(my_module.equal.encrypt(ai, None)[0] for ai in a_as_int)
b_enc = tuple(my_module.equal.encrypt(None, bi)[1] for bi in b_as_int)

time_begin = time.time()
l1_fhe_enc = levenshtein_fhe(my_module, a_enc, b_enc)
time_end = time.time()

l1_fhe = my_module.mix.decrypt(l1_fhe_enc)

l1_clear = levenshtein_clear(a, b)

assert l1_fhe == l1_clear, f" {l1_fhe=} and {l1_clear=} are different"
print(f" - OK in {time_end - time_begin:.2f} seconds")


def main():
    """Run the example: parse options, compile, then check simulation and FHE."""
    print()

    # Options by the user
    args = manage_args()

    # Compile once, then reuse the same module and patterns for both runs
    my_module = compile_module(args)
    list_patterns = prepare_random_patterns()
    compute_in_simulation(my_module, list_patterns)
    compute_in_fhe(my_module, list_patterns)

l1_clear = levenshtein_clear(a, b)

assert l1_fhe == l1_clear, f" {l1_fhe=} and {l1_clear=} are different"
print(f" - OK in {time_end - time_begin:.2f} seconds")
# Only run the example when the file is executed as a script, not on import
if __name__ == "__main__":
    main()

0 comments on commit 9de185e

Please sign in to comment.