Merge pull request #1151 from zama-ai/tfhers-ml-example
Tfhers ml example
Showing 13 changed files with 709 additions and 12 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
# TFHE-rs interoperability example | ||
|
||
This example is similar to the [first TFHE-rs example](../tfhers/), except that it uses tensors and runs a linear ML model. It also uses quantization. | ||
|
||
## Make tmpdir | ||
|
||
We first set up a temporary working directory: | ||
|
||
```sh | ||
export TDIR=`mktemp -d` | ||
``` | ||
|
||
## KeyGen | ||
|
||
First we need to build the TFHE-rs utility in [this directory](../../tests/tfhers-utils/) by running the following: | ||
|
||
```sh | ||
cd ../../tests/tfhers-utils/ | ||
make build | ||
cd - | ||
``` | ||
|
||
Then we can generate keys in two different ways. You only need to run one of the following methods. | ||
|
||
#### Generate the Secret Key in Concrete | ||
|
||
We start by doing keygen in Concrete: | ||
|
||
```sh | ||
python example.py keygen -o $TDIR/concrete_sk -k $TDIR/concrete_keyset | ||
``` | ||
|
||
Then we do a partial keygen in TFHE-rs: | ||
|
||
```sh | ||
../../tests/tfhers-utils/target/release/tfhers_utils keygen --lwe-sk $TDIR/concrete_sk --output-lwe-sk $TDIR/tfhers_sk -c $TDIR/tfhers_client_key -s $TDIR/tfhers_server_key | ||
``` | ||
|
||
#### Generate the Secret Key in TFHE-rs | ||
|
||
We start by doing keygen in TFHE-rs: | ||
|
||
```sh | ||
../../tests/tfhers-utils/target/release/tfhers_utils keygen --output-lwe-sk $TDIR/tfhers_sk -c $TDIR/tfhers_client_key -s $TDIR/tfhers_server_key | ||
``` | ||
|
||
Then we do a partial keygen in Concrete: | ||
|
||
```sh | ||
python example.py keygen -s $TDIR/tfhers_sk -o $TDIR/concrete_sk -k $TDIR/concrete_keyset | ||
``` | ||
|
||
## Quantize values | ||
|
||
We need to quantize the floating-point inputs using a pre-built quantizer for our ML model: | ||
|
||
```sh | ||
../../tests/tfhers-utils/target/release/tfhers_utils quantize --value=5.1,3.5,1.4,0.2,4.9,3,1.4,0.2,4.7,3.2,1.3,0.2,4.6,3.1,1.5,0.2,5,3.6,1.4,0.2 --config ./input_quantizer.json -o $TDIR/quantized_values | ||
``` | ||
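As a rough illustration of what this step computes, the sketch below applies uniform affine quantization with the `scale` and `zero_point` stored in `input_quantizer.json`. The formula `q = round(x / scale) + zero_point`, the clipping range, and the helper name `quantize_value` are assumptions made for illustration, not the utility's actual implementation. Under those assumptions, the first four input values map to the first row of the inputset used in `example.py`.

```python
# Values copied from input_quantizer.json.
SCALE = 0.03058823529411765
ZERO_POINT = -131
N_BITS = 8


def quantize_value(x: float) -> int:
    """Hypothetical helper: uniform affine quantization to a signed N_BITS integer."""
    q = round(x / SCALE) + ZERO_POINT
    # Assumed clipping to the signed 8-bit range.
    lo, hi = -(1 << (N_BITS - 1)), (1 << (N_BITS - 1)) - 1
    return max(lo, min(hi, q))


# First sample passed to `tfhers_utils quantize` above.
first_sample = [5.1, 3.5, 1.4, 0.2]
quantized = [quantize_value(x) for x in first_sample]
print(quantized)
```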
|
||
## Encrypt in TFHE-rs | ||
|
||
```sh | ||
../../tests/tfhers-utils/target/release/tfhers_utils encrypt-with-key --signed --value=$(cat $TDIR/quantized_values) --ciphertext $TDIR/tfhers_ct --client-key $TDIR/tfhers_client_key | ||
``` | ||
|
||
## Run in Concrete | ||
|
||
```sh | ||
python example.py run -k $TDIR/concrete_keyset -c $TDIR/tfhers_ct -o $TDIR/tfhers_ct_out | ||
``` | ||
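For reference, the accumulation that `ml_inference` in `example.py` evaluates homomorphically can be reproduced in the clear. The sketch below uses plain Python lists instead of NumPy and stops before the rounding step; the function name `linear_in_clear` is hypothetical, while the constants are copied from `example.py`.

```python
# Constants copied from example.py.
Q_WEIGHTS = [[-25, 21, -10], [42, -20, -37], [-128, -15, 127], [-58, -51, 94]]
Q_BIAS = [35167, 9417, -44584]
WEIGHT_ZERO_POINT = -5


def linear_in_clear(q_x: list) -> list:
    """Clear-text version of the accumulation in ml_inference, before rounding."""
    # Zero-point correction term: weight_quantizer_zero_point * sum(q_X).
    correction = WEIGHT_ZERO_POINT * sum(q_x)
    return [
        sum(x * row[j] for x, row in zip(q_x, Q_WEIGHTS)) - correction + Q_BIAS[j]
        for j in range(len(Q_BIAS))
    ]


# First row of the inputset in example.py.
accumulators = linear_in_clear([36, -17, -85, -124])
print(accumulators)
```

These accumulators need more than 8 bits, which is why the circuit rounds them before converting the result back to a TFHE-rs integer.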
|
||
## Decrypt in TFHE-rs | ||
|
||
```sh | ||
../../tests/tfhers-utils/target/release/tfhers_utils decrypt-with-key --tensor --signed --ciphertext $TDIR/tfhers_ct_out --client-key $TDIR/tfhers_client_key --plaintext $TDIR/result_plaintext | ||
``` | ||
|
||
## Rescale Output | ||
|
||
At the end of the circuit, the result is rounded to its 8 most significant bits and the remaining least significant bits are discarded. Since `lsbs_to_remove=10` here, we shift the decrypted values left by 10 bits to restore the original scale: | ||
|
||
```sh | ||
python -c "print(','.join(map(lambda x: str(x << 10), [$(cat $TDIR/result_plaintext)])))" > $TDIR/rescaled_plaintext | ||
``` | ||
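The arithmetic behind this step can be sketched as follows. It assumes the circuit's rounding behaves as round-to-nearest on the low 10 bits (modelled here as adding half a unit before shifting); the helper names and sample values are illustrative only. Because `compute_error.py` rounds the expected values down instead, the two can legitimately differ by one unit of 2^10.

```python
LSBS_TO_REMOVE = 10


def round_and_shift(x: int, lsbs: int = LSBS_TO_REMOVE) -> int:
    """Assumed model of the circuit's last steps: round to nearest, then drop the LSBs."""
    return (x + (1 << (lsbs - 1))) >> lsbs


def rescale(y: int, lsbs: int = LSBS_TO_REMOVE) -> int:
    """What the README one-liner does: put the dropped bits back as zeros."""
    return y << lsbs


def floor_to_multiple(x: int, lsbs: int = LSBS_TO_REMOVE) -> int:
    """How compute_error.py rounds the expected values: always downwards."""
    return (x >> lsbs) << lsbs


for value in (17162, -67716):  # illustrative accumulator values
    restored = rescale(round_and_shift(value))
    expected = floor_to_multiple(value)
    print(value, restored, expected, (restored - expected) >> LSBS_TO_REMOVE)
```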
|
||
|
||
## Dequantize values | ||
|
||
We need to dequantize integer outputs using a pre-built quantizer for our ML model: | ||
|
||
```sh | ||
../../tests/tfhers-utils/target/release/tfhers_utils dequantize --value=$(cat $TDIR/rescaled_plaintext) --shape=5,3 --config ./output_quantizer.json | ||
``` | ||
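A minimal sketch of the dequantization, assuming the usual affine formula `x = scale * (q - zero_point)` with one zero point per output column. The `scale` and `zero_point` values are copied from `output_quantizer.json`; the helper name `dequantize_row` and the sample row are illustrative and do not come from an actual run.

```python
# Values copied from output_quantizer.json.
SCALE = 0.0006288117860507253
ZERO_POINT = [39038, 11790, -50828]


def dequantize_row(q_row: list) -> list:
    """Hypothetical helper: map one row of rescaled integers back to floats."""
    return [SCALE * (q - zp) for q, zp in zip(q_row, ZERO_POINT)]


# Illustrative rescaled row (multiples of 2**10), one value per output column.
row = [50176, 17408, -67584]
logits = dequantize_row(row)
print(logits)
# The predicted class is the index of the largest dequantized value.
print(logits.index(max(logits)))
```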
|
||
## Compute error | ||
|
||
We compare the output to the expected result: | ||
|
||
```sh | ||
python compute_error.py --plaintext-file "$TDIR/rescaled_plaintext" --quantized-predictions-file "test_values.json" | ||
``` | ||
|
||
## Clean tmpdir | ||
|
||
```sh | ||
rm -rf $TDIR | ||
``` |
63 changes: 63 additions & 0 deletions
frontends/concrete-python/examples/tfhers-ml/compute_error.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import json | ||
|
||
import click | ||
|
||
|
||
@click.command() | ||
@click.option( | ||
    "--plaintext-file", "-p", required=True, help="Path to the rescaled plaintext values file." | ||
) | ||
@click.option( | ||
    "--quantized-predictions-file", | ||
    "-q", | ||
    required=True, | ||
    help="Path to the test_values.json file containing quantized predictions.", | ||
) | ||
def compute_error(plaintext_file, quantized_predictions_file): | ||
    """Compute the error between decrypted rescaled values and quantized predictions.""" | ||
    # Read rescaled plaintext values from plaintext_file | ||
    with open(plaintext_file) as f: | ||
        rescaled_plaintext_values = [int(x) for x in f.read().strip().split(",")] | ||
|
||
    # Read quantized_predictions from quantized_predictions_file | ||
    with open(quantized_predictions_file) as f: | ||
        data = json.load(f) | ||
        quantized_predictions = data["quantized_predictions"] | ||
|
||
    # Flatten quantized_predictions | ||
    quantized_predictions_flat = [int(x) for sublist in quantized_predictions for x in sublist] | ||
|
||
    # Round down 10 bits using (x // (1 << 10)) * (1 << 10) | ||
    rounded_quantized_predictions = [ | ||
        (x // (1 << 10)) * (1 << 10) for x in quantized_predictions_flat | ||
    ] | ||
|
||
    # Compare rescaled_plaintext_values with rounded_quantized_predictions | ||
    num_differences = 0 | ||
    total_values = len(rescaled_plaintext_values) | ||
    errors = [] | ||
    for i in range(total_values): | ||
        a = rescaled_plaintext_values[i] | ||
        b = rounded_quantized_predictions[i] | ||
        print(f"output: {a}, expected: {b}") | ||
        if a != b: | ||
            num_differences += 1 | ||
            error_in_units = round((a - b) / (1 << 10)) | ||
            errors.append((i, error_in_units)) | ||
|
||
    print("Number of differing values: {}".format(num_differences)) | ||
    print("Total values compared: {}".format(total_values)) | ||
    if num_differences > 0: | ||
        print("Differences (index, error in units of 2^10):") | ||
        for idx, error_in_units in errors: | ||
            print("Index {}: error = {}".format(idx, error_in_units)) | ||
|
||
    # Success means no value is off by more than one unit, in either direction. | ||
    # click discards a command's return value, so report failure via the exit code. | ||
    for _, error_in_units in errors: | ||
        if abs(error_in_units) > 1: | ||
            raise SystemExit(1) | ||
|
||
|
||
if __name__ == "__main__": | ||
    compute_error() |
179 changes: 179 additions & 0 deletions
frontends/concrete-python/examples/tfhers-ml/example.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
import os | ||
import typing | ||
from functools import partial | ||
|
||
import click | ||
import numpy as np | ||
|
||
import concrete.fhe as fhe | ||
from concrete.fhe import tfhers | ||
|
||
### Options ########################### | ||
# These parameters were saved by running the tfhers_utils utility: | ||
# tfhers_utils save-params tfhers_params.json | ||
TFHERS_PARAMS_FILE = "tfhers_params.json" | ||
FHEUINT_PRECISION = 8 | ||
IS_SIGNED = True | ||
####################################### | ||
|
||
tfhers_type = tfhers.get_type_from_params( | ||
    TFHERS_PARAMS_FILE, | ||
    is_signed=IS_SIGNED, | ||
    precision=FHEUINT_PRECISION, | ||
) | ||
tfhers_int = partial(tfhers.TFHERSInteger, tfhers_type) | ||
|
||
#### Model Parameters ################## | ||
q_weights = np.array([[-25, 21, -10], [42, -20, -37], [-128, -15, 127], [-58, -51, 94]]) | ||
q_bias = np.array([[35167, 9417, -44584]]) | ||
weight_quantizer_zero_point = -5 | ||
######################################## | ||
|
||
rounder = fhe.AutoRounder(target_msbs=8) # We want to keep 8 MSBs | ||
|
||
|
||
@typing.no_type_check | ||
def ml_inference(q_X: np.ndarray) -> np.ndarray: | ||
    y_pred = q_X @ q_weights - weight_quantizer_zero_point * np.sum(q_X, axis=1, keepdims=True) | ||
    y_pred += q_bias | ||
    y_pred = fhe.round_bit_pattern(y_pred, rounder) | ||
    y_pred = y_pred >> rounder.lsbs_to_remove | ||
    return y_pred | ||
|
||
|
||
def compute(tfhers_x): | ||
    ####### TFHE-rs to Concrete ######### | ||
|
||
    # tfhers_x is expected to be a TFHE-rs value. | ||
    # to_native uses its type information to perform | ||
    # the correct conversion from TFHE-rs to Concrete | ||
    concrete_x = tfhers.to_native(tfhers_x) | ||
    ####### TFHE-rs to Concrete ######### | ||
|
||
    ####### Concrete Computation ######## | ||
    concrete_res = ml_inference(concrete_x) | ||
    ####### Concrete Computation ######## | ||
|
||
    ####### Concrete to TFHE-rs ######### | ||
    tfhers_res = tfhers.from_native( | ||
        concrete_res, tfhers_type | ||
    )  # we have to specify the type we want to convert to | ||
    ####### Concrete to TFHE-rs ######### | ||
    return tfhers_res | ||
|
||
|
||
def ccompilee(): | ||
    compiler = fhe.Compiler( | ||
        compute, | ||
        { | ||
            "tfhers_x": "encrypted", | ||
        }, | ||
    ) | ||
|
||
    inputset = [ | ||
        ( | ||
            tfhers_int( | ||
                np.array( | ||
                    [ | ||
                        [36, -17, -85, -124], | ||
                        [29, -33, -85, -124], | ||
                        [23, -26, -88, -124], | ||
                        [19, -30, -82, -124], | ||
                        [32, -13, -85, -124], | ||
                    ] | ||
                ) | ||
            ), | ||
        ) | ||
    ] | ||
|
||
    # Add the auto-adjustment before compilation | ||
    fhe.AutoRounder.adjust(compute, inputset) | ||
|
||
    # Print the number of bits rounded | ||
    print(f"lsbs_to_remove: {rounder.lsbs_to_remove}") | ||
|
||
    circuit = compiler.compile(inputset) | ||
|
||
    tfhers_bridge = tfhers.new_bridge(circuit=circuit) | ||
    return circuit, tfhers_bridge | ||
|
||
|
||
@click.group() | ||
def cli(): | ||
    pass | ||
|
||
|
||
@cli.command() | ||
@click.option("-s", "--secret-key", type=str, required=False) | ||
@click.option("-o", "--output-secret-key", type=str, required=True) | ||
@click.option("-k", "--concrete-keyset-path", type=str, required=True) | ||
def keygen(output_secret_key: str, secret_key: str, concrete_keyset_path: str): | ||
    """Concrete Key Generation""" | ||
|
||
    circuit, tfhers_bridge = ccompilee() | ||
|
||
    if os.path.exists(concrete_keyset_path): | ||
        print(f"removing old keyset at '{concrete_keyset_path}'") | ||
        os.remove(concrete_keyset_path) | ||
|
||
    if secret_key: | ||
        print(f"partial keygen from sk at '{secret_key}'") | ||
        # load the initial secret key to use for keygen | ||
        with open( | ||
            secret_key, | ||
            "rb", | ||
        ) as f: | ||
            buff = f.read() | ||
        input_idx_to_key = {0: buff} | ||
        tfhers_bridge.keygen_with_initial_keys(input_idx_to_key_buffer=input_idx_to_key) | ||
    else: | ||
        print("full keygen") | ||
        circuit.keygen() | ||
|
||
    print("saving Concrete keyset") | ||
    circuit.client.keys.save(concrete_keyset_path) | ||
    print(f"saved Concrete keyset to '{concrete_keyset_path}'") | ||
|
||
    sk: bytes = tfhers_bridge.serialize_input_secret_key(input_idx=0) | ||
    print(f"writing secret key of size {len(sk)} to '{output_secret_key}'") | ||
    with open(output_secret_key, "wb") as f: | ||
        f.write(sk) | ||
|
||
|
||
@cli.command() | ||
@click.option("-c", "--rust-ct", type=str, required=True) | ||
@click.option("-o", "--output-rust-ct", type=str, required=False) | ||
@click.option("-k", "--concrete-keyset-path", type=str, required=True) | ||
def run(rust_ct: str, output_rust_ct: str, concrete_keyset_path: str): | ||
    """Run circuit""" | ||
    circuit, tfhers_bridge = ccompilee() | ||
|
||
    if not os.path.exists(concrete_keyset_path): | ||
        raise RuntimeError("cannot find keys, run keygen first") | ||
    print(f"loading keys from '{concrete_keyset_path}'") | ||
    circuit.client.keys.load(concrete_keyset_path) | ||
|
||
    # read the serialized TFHE-rs ciphertext from file | ||
    with open(rust_ct, "rb") as f: | ||
        buff = f.read() | ||
    # import the TFHE-rs ciphertext (signed 8-bit integers here) as input 0 | ||
    tfhers_int8_x = tfhers_bridge.import_value(buff, input_idx=0) | ||
|
||
    print("Homomorphic evaluation...") | ||
    encrypted_result = circuit.run(tfhers_int8_x) | ||
|
||
    if output_rust_ct: | ||
        print("exporting Rust ciphertexts") | ||
        # export the result as a TFHE-rs ciphertext | ||
        buff = tfhers_bridge.export_value(encrypted_result, output_idx=0) | ||
        # write it to file | ||
        with open(output_rust_ct, "wb") as f: | ||
            f.write(buff) | ||
    else: | ||
        result = circuit.decrypt(encrypted_result) | ||
        decoded = tfhers_type.decode(result) | ||
        print(f"Concrete decryption result: raw({result}), decoded({decoded})") | ||
|
||
|
||
if __name__ == "__main__": | ||
    cli() |
1 change: 1 addition & 0 deletions
frontends/concrete-python/examples/tfhers-ml/input_quantizer.json
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"type_name": "UniformQuantizer", "serialized_value": {"n_bits": 8, "is_signed": true, "is_symmetric": false, "is_qat": false, "is_narrow": false, "is_precomputed_qat": false, "rmax": {"type_name": "numpy_float", "serialized_value": 7.9, "dtype": "float64"}, "rmin": {"type_name": "numpy_float", "serialized_value": 0.1, "dtype": "float64"}, "scale": {"type_name": "numpy_float", "serialized_value": 0.03058823529411765, "dtype": "float64"}, "zero_point": {"type_name": "numpy_integer", "serialized_value": -131, "dtype": "int64"}, "offset": 128, "no_clipping": false}} |
1 change: 1 addition & 0 deletions
frontends/concrete-python/examples/tfhers-ml/output_quantizer.json
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"type_name": "UniformQuantizer", "serialized_value": {"is_signed": false, "is_symmetric": false, "is_qat": false, "is_narrow": false, "is_precomputed_qat": false, "rmax": null, "rmin": null, "scale": {"type_name": "numpy_float", "serialized_value": 0.0006288117860507253, "dtype": "float64"}, "zero_point": {"type_name": "numpy_array", "serialized_value": [[39038, 11790, -50828]], "dtype": "int64"}, "offset": 0, "no_clipping": true}} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
#!/bin/sh | ||
|
||
# This file tests that the example is working | ||
|
||
shell_blocks=$(sed -n '/^```sh/,/^```/ p' < README.md | sed '/^```sh/d' | sed '/^```/d') | ||
|
||
set -e | ||
output=$(eval "$shell_blocks" 2>&1) || { echo "$output"; exit 1; } |