From 3f52db55ebdb267631ee363c14a12e3ec525d99a Mon Sep 17 00:00:00 2001 From: youben11 Date: Thu, 7 Nov 2024 16:49:13 +0100 Subject: [PATCH] feat(frontend): load TFHE-rs integer type from saved params (JSON) --- .../concrete/fhe/tfhers/__init__.py | 58 +++++++++++++++++++ .../concrete-python/examples/tfhers/README.md | 6 ++ .../examples/tfhers/example.py | 38 ++---------- .../examples/tfhers/tfhers_params.json | 1 + .../tests/tfhers-utils/Cargo.lock | 31 ++++++++++ .../tests/tfhers-utils/Cargo.toml | 1 + .../tests/tfhers-utils/src/main.rs | 37 +++++++++--- 7 files changed, 133 insertions(+), 39 deletions(-) create mode 100644 frontends/concrete-python/examples/tfhers/tfhers_params.json diff --git a/frontends/concrete-python/concrete/fhe/tfhers/__init__.py b/frontends/concrete-python/concrete/fhe/tfhers/__init__.py index 05c2f28701..a604ee00a0 100644 --- a/frontends/concrete-python/concrete/fhe/tfhers/__init__.py +++ b/frontends/concrete-python/concrete/fhe/tfhers/__init__.py @@ -2,6 +2,9 @@ tfhers module to represent, and compute on tfhers integer values. """ +import json +from math import log2 + from .bridge import new_bridge from .dtypes import ( CryptoParams, @@ -18,3 +21,58 @@ ) from .tracing import from_native, to_native from .values import TFHERSInteger + + +def get_type_from_params( + path_to_params_json: str, is_signed: bool, precision: int +) -> TFHERSIntegerType: + """Get a TFHE-rs integer type from TFHE-rs parameters in JSON format. + + Args: + path_to_params_json (str): path to the TFHE-rs parameters (JSON format) + is_signed (bool): sign of the result type + precision (int): precision of the result type + + Returns: + TFHERSIntegerType: constructed type from the loaded parameters + """ + + # Read crypto parameters from TFHE-rs in the json file + with open(path_to_params_json) as f: + crypto_param_dict = json.load(f) + + lwe_dim = crypto_param_dict["lwe_dimension"] + glwe_dim = crypto_param_dict["glwe_dimension"] + poly_size = crypto_param_dict["polynomial_size"] + pbs_base_log = crypto_param_dict["pbs_base_log"] + pbs_level = crypto_param_dict["pbs_level"] + msg_width = int(log2(crypto_param_dict["message_modulus"])) + carry_width = int(log2(crypto_param_dict["carry_modulus"])) + lwe_noise_distr = crypto_param_dict["lwe_noise_distribution"]["Gaussian"]["std"] + glwe_noise_distr = crypto_param_dict["glwe_noise_distribution"]["Gaussian"]["std"] + encryption_key_choice = ( + EncryptionKeyChoice.BIG + if crypto_param_dict["encryption_key_choice"] == "Big" + else EncryptionKeyChoice.SMALL + ) + + assert glwe_dim == 1, "glwe dim must be 1" + assert encryption_key_choice == EncryptionKeyChoice.BIG, "encryption_key_choice must be BIG" + + tfhers_params = CryptoParams( + lwe_dimension=lwe_dim, + glwe_dimension=glwe_dim, + polynomial_size=poly_size, + pbs_base_log=pbs_base_log, + pbs_level=pbs_level, + lwe_noise_distribution=lwe_noise_distr, + glwe_noise_distribution=glwe_noise_distr, + encryption_key_choice=encryption_key_choice, + ) + return TFHERSIntegerType( + is_signed=is_signed, + bit_width=precision, + carry_width=carry_width, + msg_width=msg_width, + params=tfhers_params, + ) diff --git a/frontends/concrete-python/examples/tfhers/README.md b/frontends/concrete-python/examples/tfhers/README.md index 78c5f9fc0b..9d4971dce9 100644 --- a/frontends/concrete-python/examples/tfhers/README.md +++ b/frontends/concrete-python/examples/tfhers/README.md @@ -57,6 +57,12 @@ python example.py keygen -s $TDIR/tfhers_sk -o $TDIR/concrete_sk -k $TDIR/concre ../../tests/tfhers-utils/target/release/tfhers_utils encrypt-with-key --value 73 --ciphertext $TDIR/tfhers_ct_2 --client-key $TDIR/tfhers_client_key ``` +{% hint style="info" %} + +If you have tensor inputs, then you can encrypt by passing your flat tensor in `--value`. Concrete will take care of reshaping the values to the corresponding shape. For example `--value=1,2,3,4` can represent a 2 by 2 tensor, or a flat vector of 4 values. + +{% endhint %} + ## Compute in TFHE-rs ```sh diff --git a/frontends/concrete-python/examples/tfhers/example.py b/frontends/concrete-python/examples/tfhers/example.py index a808d4d350..942b7340cf 100644 --- a/frontends/concrete-python/examples/tfhers/example.py +++ b/frontends/concrete-python/examples/tfhers/example.py @@ -7,42 +7,16 @@ from concrete import fhe from concrete.fhe import tfhers -########## Params ##################### -LWE_DIM = 909 -GLWE_DIM = 1 -POLY_SIZE = 4096 -PBS_BASE_LOG = 15 -PBS_LEVEL = 2 -MSG_WIDTH = 2 -CARRY_WIDTH = 3 -ENCRYPTION_KEY_CHOICE = tfhers.EncryptionKeyChoice.BIG -LWE_NOISE_DISTR = 0 -GLWE_NOISE_DISTR = 2.168404344971009e-19 -####################################### - -assert GLWE_DIM == 1, "glwe dim must be 1" - ### Options ########################### +# These parameters were saved by running the tfhers_utils utility: +# tfhers_utils save-params tfhers_params.json +TFHERS_PARAMS_FILE = "tfhers_params.json" FHEUINT_PRECISION = 8 +IS_SIGNED = False ####################################### - -tfhers_params = tfhers.CryptoParams( - lwe_dimension=LWE_DIM, - glwe_dimension=GLWE_DIM, - polynomial_size=POLY_SIZE, - pbs_base_log=PBS_BASE_LOG, - pbs_level=PBS_LEVEL, - lwe_noise_distribution=LWE_NOISE_DISTR, - glwe_noise_distribution=GLWE_NOISE_DISTR, - encryption_key_choice=ENCRYPTION_KEY_CHOICE, -) -tfhers_type = tfhers.TFHERSIntegerType( - is_signed=False, - bit_width=FHEUINT_PRECISION, - carry_width=CARRY_WIDTH, - msg_width=MSG_WIDTH, - params=tfhers_params, +tfhers_type = tfhers.get_type_from_params( + TFHERS_PARAMS_FILE, IS_SIGNED, FHEUINT_PRECISION ) tfhers_int = partial(tfhers.TFHERSInteger, tfhers_type) diff --git a/frontends/concrete-python/examples/tfhers/tfhers_params.json b/frontends/concrete-python/examples/tfhers/tfhers_params.json new file mode 100644 index 0000000000..bf62c959f2 --- /dev/null +++ b/frontends/concrete-python/examples/tfhers/tfhers_params.json @@ -0,0 +1 @@ +{"lwe_dimension":902,"glwe_dimension":1,"polynomial_size":4096,"lwe_noise_distribution":{"Gaussian":{"std":1.0994794733558207e-6,"mean":0.0}},"glwe_noise_distribution":{"Gaussian":{"std":2.168404344971009e-19,"mean":0.0}},"pbs_base_log":15,"pbs_level":2,"ks_base_log":3,"ks_level":6,"message_modulus":4,"carry_modulus":8,"max_noise_level":10,"log2_p_fail":-64.084,"ciphertext_modulus":{"modulus":0,"scalar_bits":64},"encryption_key_choice":"Big"} diff --git a/frontends/concrete-python/tests/tfhers-utils/Cargo.lock b/frontends/concrete-python/tests/tfhers-utils/Cargo.lock index 2c54fd07cc..e07be46402 100644 --- a/frontends/concrete-python/tests/tfhers-utils/Cargo.lock +++ b/frontends/concrete-python/tests/tfhers-utils/Cargo.lock @@ -326,6 +326,12 @@ dependencies = [ "either", ] +[[package]] +name = "itoa" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" + [[package]] name = "js-sys" version = "0.3.70" @@ -362,6 +368,12 @@ version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +[[package]] +name = "memchr" +version = "2.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" + [[package]] name = "num-complex" version = "0.4.6" @@ -459,6 +471,12 @@ version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" +[[package]] +name = "ryu" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" + [[package]] name = "serde" version = "1.0.209" @@ -479,6 +497,18 @@ dependencies = [ "syn", ] +[[package]] +name = "serde_json" +version = "1.0.132" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" +dependencies = [ + "itoa", + "memchr", + "ryu", + "serde", +] + [[package]] name = "sha3" version = "0.10.8" @@ -559,6 +589,7 @@ dependencies = [ "bincode", "clap", "serde", + "serde_json", "tfhe", ] diff --git a/frontends/concrete-python/tests/tfhers-utils/Cargo.toml b/frontends/concrete-python/tests/tfhers-utils/Cargo.toml index 71e3cf64c8..af0a7f67a8 100644 --- a/frontends/concrete-python/tests/tfhers-utils/Cargo.toml +++ b/frontends/concrete-python/tests/tfhers-utils/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" bincode = "1.3.3" serde = "1" +serde_json = "1.0.128" clap = { version = "4.5.16", features = ["derive"] } diff --git a/frontends/concrete-python/tests/tfhers-utils/src/main.rs b/frontends/concrete-python/tests/tfhers-utils/src/main.rs index 5543c90291..546a1ffae0 100644 --- a/frontends/concrete-python/tests/tfhers-utils/src/main.rs +++ b/frontends/concrete-python/tests/tfhers-utils/src/main.rs @@ -1,18 +1,23 @@ -use clap::{Arg, ArgAction, Command}; use core::panic; -use serde::de::DeserializeOwned; use std::fs; +use std::io::Write; use std::path::Path; + +use clap::{Arg, ArgAction, Command}; + use tfhe::core_crypto::prelude::LweSecretKey; +use tfhe::named::Named; use tfhe::prelude::*; +use tfhe::safe_serialization::{safe_deserialize, safe_serialize}; use tfhe::shortint::{ClassicPBSParameters, EncryptionKeyChoice}; -use tfhe::{generate_keys, set_server_key, ClientKey, ConfigBuilder, FheInt8, FheUint8, ServerKey}; +use tfhe::{ + generate_keys, set_server_key, ClientKey, ConfigBuilder, FheInt8, FheUint8, ServerKey, + Unversionize, Versionize, +}; +use serde::de::DeserializeOwned; use serde::Serialize; -use tfhe::named::Named; -use tfhe::{Unversionize, Versionize}; - -use tfhe::safe_serialization::{safe_deserialize, safe_serialize}; +use serde_json; const BLOCK_PARAMS: ClassicPBSParameters = tfhe::shortint::prelude::PARAM_MESSAGE_2_CARRY_3_KS_PBS; const SERIALIZE_SIZE_LIMIT: u64 = 1_000_000_000; @@ -405,6 +410,17 @@ fn main() { .num_args(1), ), ) + .subcommand( + Command::new("save-params") + .short_flag('p') + .long_flag("save-params") + .about("save TFHE-rs parameters used into a file (JSON)") + .arg( + Arg::new("filename") + .help("filename to save TFHE-rs parameters to") + .required(true), + ), + ) .get_matches(); match matches.subcommand() { @@ -482,6 +498,13 @@ fn main() { keygen(client_key_path, server_key_path, output_lwe_path.unwrap()) } } + Some(("save-params", save_params_mtches)) => { + let filename = save_params_mtches.get_one::("filename").unwrap(); + let json_string = serde_json::to_string(&BLOCK_PARAMS).unwrap(); + let path = Path::new(filename); + let mut file = fs::File::create(path).unwrap(); + file.write_all(json_string.as_bytes()).unwrap(); + } _ => unreachable!(), // If all subcommands are defined above, anything else is unreachable } }