-
Notifications
You must be signed in to change notification settings - Fork 147
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
176 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
47 changes: 47 additions & 0 deletions
47
frontends/concrete-python/examples/tfhers-ml/compute_error.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import json | ||
import click | ||
|
||
@click.command() | ||
@click.option('--plaintext-file', '-p', required=True, help='Path to the rescaled plaintext values file.') | ||
@click.option('--quantized-predictions-file', '-q', required=True, help='Path to the test_values.json file containing quantized predictions.') | ||
def compute_error(plaintext_file, quantized_predictions_file): | ||
"""Compute the error between decrypted rescaled values and quantized predictions.""" | ||
# Read rescaled plaintext values from plaintext_file | ||
with open(plaintext_file) as f: | ||
rescaled_plaintext_values = [int(x) for x in f.read().strip().split(',')] | ||
|
||
# Read quantized_predictions from quantized_predictions_file | ||
with open(quantized_predictions_file) as f: | ||
data = json.load(f) | ||
quantized_predictions = data['quantized_predictions'] | ||
|
||
# Flatten quantized_predictions | ||
quantized_predictions_flat = [int(x) for sublist in quantized_predictions for x in sublist] | ||
|
||
# Round down 10 bits using (x // (1 << 10)) * (1 << 10) | ||
rounded_quantized_predictions = [ | ||
(x // (1 << 10)) * (1 << 10) for x in quantized_predictions_flat | ||
] | ||
|
||
# Compare rescaled_plaintext_values with rounded_quantized_predictions | ||
num_differences = 0 | ||
total_values = len(rescaled_plaintext_values) | ||
errors = [] | ||
for i in range(total_values): | ||
a = rescaled_plaintext_values[i] | ||
b = rounded_quantized_predictions[i] | ||
print(f"output: {a}, expected: {b}") | ||
if a != b: | ||
num_differences += 1 | ||
error_in_units = round((a - b) / (1 << 10)) | ||
errors.append((i, error_in_units)) | ||
|
||
print("Number of differing values: {}".format(num_differences)) | ||
print("Total values compared: {}".format(total_values)) | ||
if num_differences > 0: | ||
print("Differences (index, error in units of 2^10):") | ||
for idx, error_in_units in errors: | ||
print("Index {}: error = {}".format(idx, error_in_units)) | ||
|
||
if __name__ == '__main__': | ||
compute_error() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
120 changes: 120 additions & 0 deletions
120
frontends/concrete-python/examples/tfhers-ml/test_values.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
{ | ||
"input_fp32_values": [ | ||
[ | ||
5.1, | ||
3.5, | ||
1.4, | ||
0.2 | ||
], | ||
[ | ||
4.9, | ||
3.0, | ||
1.4, | ||
0.2 | ||
], | ||
[ | ||
4.7, | ||
3.2, | ||
1.3, | ||
0.2 | ||
], | ||
[ | ||
4.6, | ||
3.1, | ||
1.5, | ||
0.2 | ||
], | ||
[ | ||
5.0, | ||
3.6, | ||
1.4, | ||
0.2 | ||
] | ||
], | ||
"input_quantized_values": [ | ||
[ | ||
36, | ||
-17, | ||
-85, | ||
-124 | ||
], | ||
[ | ||
29, | ||
-33, | ||
-85, | ||
-124 | ||
], | ||
[ | ||
23, | ||
-26, | ||
-88, | ||
-124 | ||
], | ||
[ | ||
19, | ||
-30, | ||
-82, | ||
-124 | ||
], | ||
[ | ||
32, | ||
-13, | ||
-85, | ||
-124 | ||
] | ||
], | ||
"quantized_predictions": [ | ||
[ | ||
50675, | ||
17162, | ||
-67716 | ||
], | ||
[ | ||
50063, | ||
17220, | ||
-67169 | ||
], | ||
[ | ||
50881, | ||
16989, | ||
-67759 | ||
], | ||
[ | ||
50035, | ||
16885, | ||
-66819 | ||
], | ||
[ | ||
50943, | ||
16998, | ||
-67824 | ||
] | ||
], | ||
"dequantized_predictions": [ | ||
[ | ||
7.317482754272291, | ||
3.3779769146644965, | ||
-10.61937344282465 | ||
], | ||
[ | ||
6.932649941209247, | ||
3.4144479982554388, | ||
-10.275413395854903 | ||
], | ||
[ | ||
7.44701798219874, | ||
3.269192475677721, | ||
-10.64641234962483 | ||
], | ||
[ | ||
6.915043211199826, | ||
3.2037960499284455, | ||
-10.055329270737149 | ||
], | ||
[ | ||
7.486004312933885, | ||
3.2748517817521776, | ||
-10.687285115718128 | ||
] | ||
] | ||
} |