Merge pull request #1151 from zama-ai/tfhers-ml-example
Tfhers ml example
youben11 authored Dec 18, 2024
2 parents 04d7fb2 + 0b93351 commit 062f5cb
Showing 13 changed files with 709 additions and 12 deletions.
108 changes: 108 additions & 0 deletions frontends/concrete-python/examples/tfhers-ml/README.md
@@ -0,0 +1,108 @@
# TFHE-rs interoperability example

This example is similar to the [first TFHE-rs example](../tfhers/), except that it operates on tensors and runs a linear ML model. It also uses quantization.

## Make tmpdir

We first set up a temporary working directory:

```sh
export TDIR=`mktemp -d`
```

## KeyGen

First we need to build the TFHE-rs utility in [this directory](../../tests/tfhers-utils/) by running the following:

```sh
cd ../../tests/tfhers-utils/
make build
cd -
```

Then we can generate keys in two different ways. You only need to run one of the following methods.

#### Generate the Secret Key in Concrete

We start by doing keygen in Concrete:

```sh
python example.py keygen -o $TDIR/concrete_sk -k $TDIR/concrete_keyset
```

Then we do a partial keygen in TFHE-rs:

```sh
../../tests/tfhers-utils/target/release/tfhers_utils keygen --lwe-sk $TDIR/concrete_sk --output-lwe-sk $TDIR/tfhers_sk -c $TDIR/tfhers_client_key -s $TDIR/tfhers_server_key
```

#### Generate the Secret Key in TFHE-rs

We start by doing keygen in TFHE-rs:

```sh
../../tests/tfhers-utils/target/release/tfhers_utils keygen --output-lwe-sk $TDIR/tfhers_sk -c $TDIR/tfhers_client_key -s $TDIR/tfhers_server_key
```

Then we do a partial keygen in Concrete:

```sh
python example.py keygen -s $TDIR/tfhers_sk -o $TDIR/concrete_sk -k $TDIR/concrete_keyset
```

## Quantize values

We need to quantize the floating-point inputs using a pre-built quantizer for our ML model:

```sh
../../tests/tfhers-utils/target/release/tfhers_utils quantize --value=5.1,3.5,1.4,0.2,4.9,3,1.4,0.2,4.7,3.2,1.3,0.2,4.6,3.1,1.5,0.2,5,3.6,1.4,0.2 --config ./input_quantizer.json -o $TDIR/quantized_values
```
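
As an illustration of what this step computes, here is a minimal sketch of uniform quantization in plain Python. It assumes the usual formula `q = round(x / scale) + zero_point` followed by clipping to the signed 8-bit range, with `scale` and `zero_point` copied from the input quantizer configuration shipped with this example. The function name `quantize` is hypothetical, and the utility's exact handling of rounding ties may differ.

```python
# Values copied from the input quantizer config of this example.
SCALE = 0.03058823529411765
ZERO_POINT = -131
N_BITS = 8


def quantize(values, scale=SCALE, zero_point=ZERO_POINT, n_bits=N_BITS):
    """Map floats to signed n_bits integers (assumed formula, see lead-in)."""
    lo = -(1 << (n_bits - 1))      # -128 for 8 bits
    hi = (1 << (n_bits - 1)) - 1   # 127 for 8 bits
    return [min(hi, max(lo, round(v / scale) + zero_point)) for v in values]


# First sample passed to the quantize command above.
quantized = quantize([5.1, 3.5, 1.4, 0.2])
print(quantized)
```

Under these assumptions, the first sample maps to the first row of the inputset used in `example.py`.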

## Encrypt in TFHE-rs

```sh
../../tests/tfhers-utils/target/release/tfhers_utils encrypt-with-key --signed --value=$(cat $TDIR/quantized_values) --ciphertext $TDIR/tfhers_ct --client-key $TDIR/tfhers_client_key
```

## Run in Concrete

```sh
python example.py run -k $TDIR/concrete_keyset -c $TDIR/tfhers_ct -o $TDIR/tfhers_ct_out
```
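
For reference, the circuit evaluated in this step is a quantized linear model. The sketch below reproduces its accumulator (the value before rounding) on cleartext integers, using the weights, bias and weight zero point defined in `example.py`. It is written in plain Python rather than NumPy so that it runs without dependencies, and the helper name `linear_accumulators` is hypothetical.

```python
# Model parameters copied from example.py.
Q_WEIGHTS = [[-25, 21, -10], [42, -20, -37], [-128, -15, 127], [-58, -51, 94]]
Q_BIAS = [35167, 9417, -44584]
WEIGHT_ZERO_POINT = -5


def linear_accumulators(q_x):
    """Cleartext equivalent of: q_X @ W - zp_w * sum(q_X, axis=1) + bias."""
    out = []
    for row in q_x:
        row_sum = sum(row)
        out.append(
            [
                sum(x * w[j] for x, w in zip(row, Q_WEIGHTS))
                - WEIGHT_ZERO_POINT * row_sum
                + Q_BIAS[j]
                for j in range(len(Q_BIAS))
            ]
        )
    return out


# First row of the inputset used in example.py.
acc = linear_accumulators([[36, -17, -85, -124]])
print(acc)
```

The encrypted circuit additionally rounds these accumulators to 8 most significant bits before converting the result back to a TFHE-rs integer.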

## Decrypt in TFHE-rs

```sh
../../tests/tfhers-utils/target/release/tfhers_utils decrypt-with-key --tensor --signed --ciphertext $TDIR/tfhers_ct_out --client-key $TDIR/tfhers_client_key --plaintext $TDIR/result_plaintext
```

## Rescale Output

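At the end of the circuit, the result is rounded to its 8 most significant bits and the remaining least significant bits are discarded. Since `lsbs_to_remove=10` in this example, we shift the decrypted values left by 10 bits to restore their original scale (the discarded bits come back as zeros).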

```sh
python -c "print(','.join(map(lambda x: str(x << 10), [$(cat $TDIR/result_plaintext)])))" > $TDIR/rescaled_plaintext
```
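
The effect of rounding and rescaling can be sketched on cleartext integers as follows. This assumes that rounding goes to the nearest multiple of `2**10` (add half a unit, then drop the low bits), which is our reading of `fhe.round_bit_pattern`, and the accumulator values are illustrative. The function names are hypothetical.

```python
LSBS_TO_REMOVE = 10  # value stated above; chosen by the AutoRounder at compile time


def round_and_shift(acc, lsbs=LSBS_TO_REMOVE):
    """Round to the nearest multiple of 2**lsbs, then drop the low bits."""
    return (acc + (1 << (lsbs - 1))) >> lsbs


def rescale(value, lsbs=LSBS_TO_REMOVE):
    """Re-introduce the dropped bits as zeros."""
    return value << lsbs


accumulators = [50675, 17162, -67716]  # illustrative pre-rounding values
rounded = [round_and_shift(a) for a in accumulators]
rescaled = [rescale(r) for r in rounded]
print(rounded, rescaled)
```

The rescaled values differ from the accumulators by at most half a unit of `2**10`, which is the precision lost by keeping only 8 bits.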


## Dequantize values

We need to dequantize integer outputs using a pre-built quantizer for our ML model:

```sh
../../tests/tfhers-utils/target/release/tfhers_utils dequantize --value=$(cat $TDIR/rescaled_plaintext) --shape=5,3 --config ./output_quantizer.json
```
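
A minimal sketch of the dequantization step, assuming the usual formula `x = scale * (q - zero_point)` with one zero point per output column. `scale` and `zero_point` are copied from the output quantizer configuration shipped with this example; the function name `dequantize` and the input values are illustrative.

```python
# Values copied from the output quantizer config of this example.
SCALE = 0.0006288117860507253
ZERO_POINT = [39038, 11790, -50828]


def dequantize(flat_values, scale=SCALE, zero_point=ZERO_POINT):
    """Reshape a flat list to rows of len(zero_point) and map back to floats."""
    n_cols = len(zero_point)
    rows = [flat_values[i : i + n_cols] for i in range(0, len(flat_values), n_cols)]
    return [[scale * (q - zp) for q, zp in zip(row, zero_point)] for row in rows]


# One illustrative row of rescaled outputs (three class scores).
scores = dequantize([50176, 17408, -67584])
predicted_class = scores[0].index(max(scores[0]))
print(scores, predicted_class)
```

With 5 samples and 3 classes, the flat list of 15 values is reshaped to 5 rows of 3 scores, which corresponds to `--shape=5,3` in the command above.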

## Compute error

We compare the output to the expected result:

```sh
python compute_error.py --plaintext-file "$TDIR/rescaled_plaintext" --quantized-predictions-file "test_values.json"
```
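
`compute_error.py` rounds the expected values down to a multiple of `2**10` and accepts a difference of at most one unit. The sketch below shows why a one-unit tolerance is needed, assuming the circuit rounds to the nearest multiple while the reference is floored. The helper name `units_off` is hypothetical, and the expected values are illustrative rather than read from `test_values.json`.

```python
LSBS = 10


def units_off(output, expected, lsbs=LSBS):
    """Difference, in units of 2**lsbs, between the output and the floored expectation."""
    floored = (expected // (1 << lsbs)) * (1 << lsbs)
    return (output - floored) // (1 << lsbs)


# Round-to-nearest went up while the floor went down: off by one unit.
off_by_one = units_off(17408, 17162)
# Round-to-nearest and floor agree: no difference.
exact = units_off(50176, 50675)
print(off_by_one, exact)
```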

## Clean tmpdir

```sh
rm -rf $TDIR
```
63 changes: 63 additions & 0 deletions frontends/concrete-python/examples/tfhers-ml/compute_error.py
@@ -0,0 +1,63 @@
import json

import click


@click.command()
@click.option(
    "--plaintext-file", "-p", required=True, help="Path to the rescaled plaintext values file."
)
@click.option(
    "--quantized-predictions-file",
    "-q",
    required=True,
    help="Path to the test_values.json file containing quantized predictions.",
)
def compute_error(plaintext_file, quantized_predictions_file):
    """Compute the error between decrypted rescaled values and quantized predictions."""
    # Read rescaled plaintext values from plaintext_file
    with open(plaintext_file) as f:
        rescaled_plaintext_values = [int(x) for x in f.read().strip().split(",")]

    # Read quantized_predictions from quantized_predictions_file
    with open(quantized_predictions_file) as f:
        data = json.load(f)
        quantized_predictions = data["quantized_predictions"]

    # Flatten quantized_predictions
    quantized_predictions_flat = [int(x) for sublist in quantized_predictions for x in sublist]

    # Round down 10 bits using (x // (1 << 10)) * (1 << 10)
    rounded_quantized_predictions = [
        (x // (1 << 10)) * (1 << 10) for x in quantized_predictions_flat
    ]

    # Compare rescaled_plaintext_values with rounded_quantized_predictions
    num_differences = 0
    total_values = len(rescaled_plaintext_values)
    errors = []
    for i in range(total_values):
        a = rescaled_plaintext_values[i]
        b = rounded_quantized_predictions[i]
        print(f"output: {a}, expected: {b}")
        if a != b:
            num_differences += 1
            error_in_units = round((a - b) / (1 << 10))
            errors.append((i, error_in_units))

    print("Number of differing values: {}".format(num_differences))
    print("Total values compared: {}".format(total_values))
    if num_differences > 0:
        print("Differences (index, error in units of 2^10):")
        for idx, error_in_units in errors:
            print("Index {}: error = {}".format(idx, error_in_units))

    # Success means no value is off by more than 1 unit, in either direction.
    # click ignores the return value of a command in standalone mode, so we
    # raise SystemExit to make the failure visible to the calling shell.
    if any(abs(error_in_units) > 1 for _, error_in_units in errors):
        raise SystemExit(1)


if __name__ == "__main__":
    compute_error()
179 changes: 179 additions & 0 deletions frontends/concrete-python/examples/tfhers-ml/example.py
@@ -0,0 +1,179 @@
import os
import typing
from functools import partial

import click
import numpy as np

import concrete.fhe as fhe
from concrete.fhe import tfhers

### Options ###########################
# These parameters were saved by running the tfhers_utils utility:
# tfhers_utils save-params tfhers_params.json
TFHERS_PARAMS_FILE = "tfhers_params.json"
FHEUINT_PRECISION = 8
IS_SIGNED = True
#######################################

tfhers_type = tfhers.get_type_from_params(
    TFHERS_PARAMS_FILE,
    is_signed=IS_SIGNED,
    precision=FHEUINT_PRECISION,
)
tfhers_int = partial(tfhers.TFHERSInteger, tfhers_type)

#### Model Parameters ##################
q_weights = np.array([[-25, 21, -10], [42, -20, -37], [-128, -15, 127], [-58, -51, 94]])
q_bias = np.array([[35167, 9417, -44584]])
weight_quantizer_zero_point = -5
########################################

rounder = fhe.AutoRounder(target_msbs=8) # We want to keep 8 MSBs


@typing.no_type_check
def ml_inference(q_X: np.ndarray) -> np.ndarray:
    y_pred = q_X @ q_weights - weight_quantizer_zero_point * np.sum(q_X, axis=1, keepdims=True)
    y_pred += q_bias
    y_pred = fhe.round_bit_pattern(y_pred, rounder)
    y_pred = y_pred >> rounder.lsbs_to_remove
    return y_pred


def compute(tfhers_x):
    ####### TFHE-rs to Concrete #########

    # tfhers_x is expected to be a TFHE-rs value.
    # to_native uses its type information to do
    # a correct conversion from TFHE-rs to Concrete
    concrete_x = tfhers.to_native(tfhers_x)
    ####### TFHE-rs to Concrete #########

    ####### Concrete Computation ########
    concrete_res = ml_inference(concrete_x)
    ####### Concrete Computation ########

    ####### Concrete to TFHE-rs #########
    tfhers_res = tfhers.from_native(
        concrete_res, tfhers_type
    )  # we have to specify the type we want to convert to
    ####### Concrete to TFHE-rs #########
    return tfhers_res


def ccompilee():
    compiler = fhe.Compiler(
        compute,
        {
            "tfhers_x": "encrypted",
        },
    )

    inputset = [
        (
            tfhers_int(
                np.array(
                    [
                        [36, -17, -85, -124],
                        [29, -33, -85, -124],
                        [23, -26, -88, -124],
                        [19, -30, -82, -124],
                        [32, -13, -85, -124],
                    ]
                )
            ),
        )
    ]

    # Add the auto-adjustment before compilation
    fhe.AutoRounder.adjust(compute, inputset)

    # Print the number of bits rounded
    print(f"lsbs_to_remove: {rounder.lsbs_to_remove}")

    circuit = compiler.compile(inputset)

    tfhers_bridge = tfhers.new_bridge(circuit=circuit)
    return circuit, tfhers_bridge


@click.group()
def cli():
    pass


@cli.command()
@click.option("-s", "--secret-key", type=str, required=False)
@click.option("-o", "--output-secret-key", type=str, required=True)
@click.option("-k", "--concrete-keyset-path", type=str, required=True)
def keygen(output_secret_key: str, secret_key: str, concrete_keyset_path: str):
    """Concrete Key Generation"""

    circuit, tfhers_bridge = ccompilee()

    if os.path.exists(concrete_keyset_path):
        print(f"removing old keyset at '{concrete_keyset_path}'")
        os.remove(concrete_keyset_path)

    if secret_key:
        print(f"partial keygen from sk at '{secret_key}'")
        # load the initial secret key to use for keygen
        with open(
            secret_key,
            "rb",
        ) as f:
            buff = f.read()
        input_idx_to_key = {0: buff}
        tfhers_bridge.keygen_with_initial_keys(input_idx_to_key_buffer=input_idx_to_key)
    else:
        print("full keygen")
        circuit.keygen()

    print("saving Concrete keyset")
    circuit.client.keys.save(concrete_keyset_path)
    print(f"saved Concrete keyset to '{concrete_keyset_path}'")

    sk: bytes = tfhers_bridge.serialize_input_secret_key(input_idx=0)
    print(f"writing secret key of size {len(sk)} to '{output_secret_key}'")
    with open(output_secret_key, "wb") as f:
        f.write(sk)


@cli.command()
@click.option("-c", "--rust-ct", type=str, required=True)
@click.option("-o", "--output-rust-ct", type=str, required=False)
@click.option("-k", "--concrete-keyset-path", type=str, required=True)
def run(rust_ct: str, output_rust_ct: str, concrete_keyset_path: str):
    """Run circuit"""
    circuit, tfhers_bridge = ccompilee()

    if not os.path.exists(concrete_keyset_path):
        raise RuntimeError("cannot find keys, you should run keygen before")
    print(f"loading keys from '{concrete_keyset_path}'")
    circuit.client.keys.load(concrete_keyset_path)

    # read the TFHE-rs integer tensor from file
    with open(rust_ct, "rb") as f:
        buff = f.read()
    # import the TFHE-rs ciphertext (signed 8-bit here, see IS_SIGNED) as circuit input 0
    tfhers_uint8_x = tfhers_bridge.import_value(buff, input_idx=0)

    print("Homomorphic evaluation...")
    encrypted_result = circuit.run(tfhers_uint8_x)

    if output_rust_ct:
        print("exporting Rust ciphertexts")
        # export the result as a TFHE-rs ciphertext
        buff = tfhers_bridge.export_value(encrypted_result, output_idx=0)
        # write it to file
        with open(output_rust_ct, "wb") as f:
            f.write(buff)
    else:
        result = circuit.decrypt(encrypted_result)
        decoded = tfhers_type.decode(result)
        print(f"Concrete decryption result: raw({result}), decoded({decoded})")


if __name__ == "__main__":
    cli()
@@ -0,0 +1 @@
{"type_name": "UniformQuantizer", "serialized_value": {"n_bits": 8, "is_signed": true, "is_symmetric": false, "is_qat": false, "is_narrow": false, "is_precomputed_qat": false, "rmax": {"type_name": "numpy_float", "serialized_value": 7.9, "dtype": "float64"}, "rmin": {"type_name": "numpy_float", "serialized_value": 0.1, "dtype": "float64"}, "scale": {"type_name": "numpy_float", "serialized_value": 0.03058823529411765, "dtype": "float64"}, "zero_point": {"type_name": "numpy_integer", "serialized_value": -131, "dtype": "int64"}, "offset": 128, "no_clipping": false}}
@@ -0,0 +1 @@
{"type_name": "UniformQuantizer", "serialized_value": {"is_signed": false, "is_symmetric": false, "is_qat": false, "is_narrow": false, "is_precomputed_qat": false, "rmax": null, "rmin": null, "scale": {"type_name": "numpy_float", "serialized_value": 0.0006288117860507253, "dtype": "float64"}, "zero_point": {"type_name": "numpy_array", "serialized_value": [[39038, 11790, -50828]], "dtype": "int64"}, "offset": 0, "no_clipping": true}}
8 changes: 8 additions & 0 deletions frontends/concrete-python/examples/tfhers-ml/test.sh
@@ -0,0 +1,8 @@
#!/bin/sh

# This file tests that the example is working

shell_blocks=$(sed -n '/^```sh/,/^```/ p' < README.md | sed '/^```sh/d' | sed '/^```/d')

set -e
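# Print the captured output and fail the test if any README command failed.
# A bare `|| echo "$output"` would return 0 and hide the failure.
output=$(eval "$shell_blocks" 2>&1) || { echo "$output"; exit 1; }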