-
Notifications
You must be signed in to change notification settings - Fork 147
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(frontend): refactor levenshtein distance example, test and bench…
…mark it
- Loading branch information
1 parent
d3c5d64
commit 532000f
Showing
3 changed files
with
360 additions
and
79 deletions.
There are no files selected for viewing
167 changes: 167 additions & 0 deletions
167
frontends/concrete-python/benchmarks/levenshtein_distance.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
""" | ||
Benchmarks of the levenshtein distance example. | ||
""" | ||
|
||
from functools import lru_cache | ||
from pathlib import Path | ||
from typing import Tuple | ||
|
||
import py_progress_tracker as progress | ||
|
||
from concrete import fhe | ||
from examples.levenshtein_distance.levenshtein_distance import Alphabet, LevenshteinDistance | ||
|
||
|
||
@lru_cache | ||
def levenshtein_on_server( | ||
server: fhe.Server, | ||
x: Tuple[fhe.Value], | ||
y: Tuple[fhe.Value], | ||
evaluation_keys: fhe.EvaluationKeys, | ||
): | ||
""" | ||
Compute levenshtein distance on a server. | ||
""" | ||
|
||
if len(x) == 0: | ||
return server.run( | ||
len(y), | ||
function_name="constant", | ||
evaluation_keys=evaluation_keys, | ||
) | ||
|
||
if len(y) == 0: | ||
return server.run( | ||
len(x), | ||
function_name="constant", | ||
evaluation_keys=evaluation_keys, | ||
) | ||
|
||
if_equal = levenshtein_on_server(server, x[1:], y[1:], evaluation_keys) | ||
case_1 = levenshtein_on_server(server, x[1:], y, evaluation_keys) | ||
case_2 = levenshtein_on_server(server, x, y[1:], evaluation_keys) | ||
case_3 = if_equal | ||
|
||
is_equal = server.run( | ||
x[0], | ||
y[0], | ||
function_name="equal", | ||
evaluation_keys=evaluation_keys, | ||
) | ||
result = server.run( | ||
is_equal, | ||
if_equal, | ||
case_1, | ||
case_2, | ||
case_3, | ||
function_name="mix", | ||
evaluation_keys=evaluation_keys, | ||
) | ||
|
||
return result | ||
|
||
|
||
def targets(): | ||
""" | ||
Generates targets to benchmark. | ||
""" | ||
|
||
result = [] | ||
for alphabet in ["ACTG", "string"]: | ||
for max_string_length in [2, 4, 8]: | ||
result.append( | ||
{ | ||
"id": ( | ||
f"levenshtein-distance :: " | ||
f"alphabet = {alphabet} | max_string_size = {max_string_length}" | ||
), | ||
"name": ( | ||
f"Levenshtein distance between two strings " | ||
f"of length {max_string_length} " | ||
f"from {alphabet} alphabet" | ||
), | ||
"parameters": { | ||
"alphabet": alphabet, | ||
"max_string_length": max_string_length, | ||
}, | ||
} | ||
) | ||
return result | ||
|
||
|
||
@progress.track(targets()) | ||
def main(alphabet, max_string_length): | ||
""" | ||
Benchmark a target. | ||
Args: | ||
alphabet: | ||
alphabet of the inputs | ||
max_string_length: | ||
maximum size of the inputs | ||
""" | ||
|
||
cached_server = Path(f"levenshtein.{alphabet}.{max_string_length}.server.zip") | ||
alphabet = Alphabet.init_by_name(alphabet) | ||
|
||
print("Compiling...") | ||
if cached_server.exists(): | ||
server = fhe.Server.load(cached_server) | ||
client = fhe.Client(server.client_specs, keyset_cache_directory=".keys") | ||
else: | ||
levenshtein_distance = LevenshteinDistance( | ||
alphabet, | ||
max_string_length, | ||
configuration=fhe.Configuration( | ||
enable_unsafe_features=True, | ||
use_insecure_key_cache=True, | ||
insecure_key_cache_location=".keys", | ||
), | ||
) | ||
levenshtein_distance.module.server.save(cached_server) | ||
|
||
server = levenshtein_distance.module.server | ||
client = levenshtein_distance.module.client | ||
|
||
print("Generating keys...") | ||
client.keygen() | ||
|
||
print("Warming up...") | ||
|
||
sample_a = alphabet.random_string(max_string_length) | ||
sample_b = alphabet.random_string(max_string_length) | ||
|
||
encrypted_sample_a = tuple( | ||
client.encrypt(ai, None, function_name="equal")[0] for ai in alphabet.encode(sample_a) | ||
) | ||
encrypted_sample_b = tuple( | ||
client.encrypt(None, bi, function_name="equal")[1] for bi in alphabet.encode(sample_b) | ||
) | ||
|
||
levenshtein_on_server( | ||
server, | ||
encrypted_sample_a, | ||
encrypted_sample_b, | ||
client.evaluation_keys, | ||
) | ||
|
||
for i in range(5): | ||
print(f"Running subsample {i + 1} out of 5...") | ||
|
||
sample_a = alphabet.random_string(max_string_length) | ||
sample_b = alphabet.random_string(max_string_length) | ||
|
||
encrypted_sample_a = tuple( | ||
client.encrypt(ai, None, function_name="equal")[0] for ai in alphabet.encode(sample_a) | ||
) | ||
encrypted_sample_b = tuple( | ||
client.encrypt(None, bi, function_name="equal")[1] for bi in alphabet.encode(sample_b) | ||
) | ||
|
||
with progress.measure(id="evaluation-time-ms", label="Evaluation Time (ms)"): | ||
levenshtein_on_server( | ||
server, | ||
encrypted_sample_a, | ||
encrypted_sample_b, | ||
client.evaluation_keys, | ||
) |
Oops, something went wrong.