From 6c7291cd57d497c4112f05cd9f2a4446524635aa Mon Sep 17 00:00:00 2001 From: youben11 Date: Thu, 10 Oct 2024 09:30:39 +0100 Subject: [PATCH 1/4] feat(frontend/compiler): support TFHErs fheint8 --- .../implementation/include/concrete-cpu.h | 10 + .../implementation/src/c_api/fheint.rs | 104 +++++++++- .../concretelang/ClientLib/ClientLib.h | 6 + .../lib/Bindings/Python/CompilerAPIModule.cpp | 24 +++ .../Python/concrete/compiler/tfhers_int.py | 54 +++++ .../compiler/lib/ClientLib/ClientLib.cpp | 71 +++++++ .../concrete/fhe/tfhers/bridge.py | 10 +- .../concrete/fhe/tfhers/dtypes.py | 10 +- .../tests/execution/test_tfhers.py | 195 ++++++++++++++---- .../tests/tfhers-utils/src/main.rs | 92 ++++++--- 10 files changed, 507 insertions(+), 69 deletions(-) diff --git a/backends/concrete-cpu/implementation/include/concrete-cpu.h b/backends/concrete-cpu/implementation/include/concrete-cpu.h index 3c6ced4e8f..4d9cc35d6a 100644 --- a/backends/concrete-cpu/implementation/include/concrete-cpu.h +++ b/backends/concrete-cpu/implementation/include/concrete-cpu.h @@ -369,6 +369,11 @@ void concrete_cpu_keyswitch_lwe_ciphertext_u64(uint64_t *ct_out, size_t input_dimension, size_t output_dimension); +size_t concrete_cpu_lwe_array_to_tfhers_int8(const uint64_t *lwe_vec_buffer, + uint8_t *buffer, + size_t buffer_len, + struct TfhersFheIntDescription fheint_desc); + size_t concrete_cpu_lwe_array_to_tfhers_uint8(const uint64_t *lwe_vec_buffer, uint8_t *buffer, size_t buffer_len, @@ -415,6 +420,11 @@ size_t concrete_cpu_serialize_lwe_secret_key_u64(const uint64_t *lwe_sk, size_t concrete_cpu_tfhers_fheint_buffer_size_u64(size_t lwe_size, size_t n_cts); +int64_t concrete_cpu_tfhers_int8_to_lwe_array(const uint8_t *serialized_data_ptr, + size_t serialized_data_len, + uint64_t *lwe_vec_buffer, + struct TfhersFheIntDescription desc); + int64_t concrete_cpu_tfhers_uint8_to_lwe_array(const uint8_t *buffer, size_t buffer_len, uint64_t *lwe_vec_buffer, diff --git a/backends/concrete-cpu/implementation/src/c_api/fheint.rs b/backends/concrete-cpu/implementation/src/c_api/fheint.rs index 1364b62646..3eecab1e98 100644 --- a/backends/concrete-cpu/implementation/src/c_api/fheint.rs +++ b/backends/concrete-cpu/implementation/src/c_api/fheint.rs @@ -5,7 +5,7 @@ use tfhe::integer::ciphertext::Expandable; use tfhe::integer::IntegerCiphertext; use tfhe::shortint::parameters::{Degree, NoiseLevel}; use tfhe::shortint::{CarryModulus, Ciphertext, MessageModulus}; -use tfhe::{FheUint128, FheUint8}; +use tfhe::{FheInt8, FheUint128, FheUint8}; #[repr(C)] pub struct TfhersFheIntDescription { @@ -127,6 +127,29 @@ pub fn tfhers_uint8_description(fheuint: FheUint8) -> TfhersFheIntDescription { } } +pub fn tfhers_int8_description(fheuint: FheInt8) -> TfhersFheIntDescription { + // get metadata from fheuint's ciphertext + let (radix, _, _) = fheuint.into_raw_parts(); + let blocks = radix.blocks(); + let ct = match blocks.first() { + Some(value) => &value.ct, + None => { + return TfhersFheIntDescription::zero(); + } + }; + TfhersFheIntDescription { + width: 8, + is_signed: true, + lwe_size: ct.lwe_size().0, + n_cts: blocks.len(), + degree: blocks[0].degree.get(), + noise_level: blocks[0].noise_level().get(), + message_modulus: blocks[0].message_modulus.0, + carry_modulus: blocks[0].carry_modulus.0, + ks_first: blocks[0].pbs_order == PBSOrder::KeyswitchBootstrap, + } +} + #[no_mangle] pub unsafe extern "C" fn concrete_cpu_tfhers_uint8_to_lwe_array( buffer: *const u8, @@ -161,6 +184,41 @@ pub unsafe extern "C" fn concrete_cpu_tfhers_uint8_to_lwe_array( }) } +#[no_mangle] +pub unsafe extern "C" fn concrete_cpu_tfhers_int8_to_lwe_array( + serialized_data_ptr: *const u8, + serialized_data_len: usize, + lwe_vec_buffer: *mut u64, + desc: TfhersFheIntDescription, +) -> i64 { + nounwind(|| { + let fheint: FheInt8 = + super::utils::safe_deserialize(serialized_data_ptr, serialized_data_len); + // TODO - Use conformance check + let fheint_desc = tfhers_int8_description(fheint.clone()); + if !fheint_desc.is_similar(&desc) { + return 1; + } + + // collect LWEs from fheuint + let (radix, _, _) = fheint.into_raw_parts(); + let blocks = radix.blocks(); + let first_ct = match blocks.first() { + Some(value) => &value.ct, + None => return 1, + }; + let lwe_size = first_ct.lwe_size().0; + let n_cts = blocks.len(); + // copy LWEs to C buffer. Note that lsb is cts[0] + let lwe_vector: &mut [u64] = slice::from_raw_parts_mut(lwe_vec_buffer, n_cts * lwe_size); + for (i, block) in blocks.iter().enumerate() { + lwe_vector[i * lwe_size..(i + 1) * lwe_size] + .copy_from_slice(block.ct.clone().into_container().as_slice()); + } + 0 + }) +} + #[no_mangle] pub extern "C" fn concrete_cpu_tfhers_fheint_buffer_size_u64( lwe_size: usize, @@ -198,7 +256,7 @@ pub unsafe extern "C" fn concrete_cpu_lwe_array_to_tfhers_uint8( let n_cts = fheuint_desc.n_cts; // construct fheuint from LWEs let lwe_vector: &[u64] = slice::from_raw_parts(lwe_vec_buffer, n_cts * lwe_size); - let mut blocks: Vec = Vec::new(); + let mut blocks: Vec = Vec::with_capacity(n_cts); for i in 0..n_cts { let lwe_ct = LweCiphertext::>::from_container( lwe_vector[i * lwe_size..(i + 1) * lwe_size].to_vec(), @@ -215,3 +273,45 @@ pub unsafe extern "C" fn concrete_cpu_lwe_array_to_tfhers_uint8( super::utils::safe_serialize(&fheuint, buffer, buffer_len) }) } + +#[no_mangle] +pub unsafe extern "C" fn concrete_cpu_lwe_array_to_tfhers_int8( + lwe_vec_buffer: *const u64, + buffer: *mut u8, + buffer_len: usize, + fheint_desc: TfhersFheIntDescription, +) -> usize { + nounwind(|| { + // we want to trigger a PBS on TFHErs side + assert!( + fheint_desc.noise_level == NoiseLevel::UNKNOWN.get(), + "noise_level must be unknown" + ); + // we want to use the max degree as we don't track it on Concrete side + assert!( + fheint_desc.degree == fheint_desc.message_modulus - 1, + "degree must be the max value (msg_modulus - 1)" + ); + + let lwe_size = fheint_desc.lwe_size; + let n_cts = fheint_desc.n_cts; + // construct fheuint from LWEs + let lwe_vector: &[u64] = slice::from_raw_parts(lwe_vec_buffer, n_cts * lwe_size); + let mut blocks: Vec = Vec::with_capacity(n_cts); + for i in 0..n_cts { + let lwe_ct = LweCiphertext::>::from_container( + lwe_vector[i * lwe_size..(i + 1) * lwe_size].to_vec(), + CiphertextModulus::new_native(), + ); + blocks.push(fheint_desc.ct_from_lwe(lwe_ct)); + } + let fheuint = match FheInt8::from_expanded_blocks(blocks, fheint_desc.data_kind()) { + Ok(value) => value, + Err(_) => { + return 0; + } + }; + + super::utils::safe_serialize(&fheuint, buffer, buffer_len) + }) +} diff --git a/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/ClientLib.h b/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/ClientLib.h index bf0d65c430..99f39a248a 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/ClientLib.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/ClientLib.h @@ -40,6 +40,12 @@ importTfhersFheUint8(llvm::ArrayRef serializedFheUint8, double encryptionVariance); Result> exportTfhersFheUint8(TransportValue value, TfhersFheIntDescription info); +Result +importTfhersFheInt8(llvm::ArrayRef serializedFheUint8, + TfhersFheIntDescription desc, uint32_t encryptionKeyId, + double encryptionVariance); +Result> exportTfhersFheInt8(TransportValue value, + TfhersFheIntDescription info); class ClientCircuit { diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index 54f1af4844..9e90e1f2d6 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -1905,4 +1905,28 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( } return result.value(); }); + + m.def("import_tfhers_fheint8", + [](const pybind11::bytes &serialized_fheuint, + TfhersFheIntDescription info, uint32_t encryptionKeyId, + double encryptionVariance) { + const std::string &buffer_str = serialized_fheuint; + std::vector buffer(buffer_str.begin(), buffer_str.end()); + auto arrayRef = llvm::ArrayRef(buffer); + auto valueOrError = ::concretelang::clientlib::importTfhersFheInt8( + arrayRef, info, encryptionKeyId, encryptionVariance); + if (valueOrError.has_error()) { + throw std::runtime_error(valueOrError.error().mesg); + } + return TransportValue{valueOrError.value()}; + }); + + m.def("export_tfhers_fheint8", [](TransportValue fheuint, + TfhersFheIntDescription info) { + auto result = ::concretelang::clientlib::exportTfhersFheInt8(fheuint, info); + if (result.has_error()) { + throw std::runtime_error(result.error().mesg); + } + return result.value(); + }); } diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/tfhers_int.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/tfhers_int.py index d4006b058d..89cfd38c9b 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/tfhers_int.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/tfhers_int.py @@ -5,6 +5,8 @@ from mlir._mlir_libs._concretelang._compiler import ( import_tfhers_fheuint8 as _import_tfhers_fheuint8, export_tfhers_fheuint8 as _export_tfhers_fheuint8, + import_tfhers_fheint8 as _import_tfhers_fheint8, + export_tfhers_fheint8 as _export_tfhers_fheint8, TfhersFheIntDescription as _TfhersFheIntDescription, TransportValue, ) @@ -232,3 +234,55 @@ def import_fheuint8( if not isinstance(variance, float): raise TypeError(f"variance must be of type float, not {type(variance)}") return _import_tfhers_fheuint8(buffer, info.cpp(), keyid, variance) + + @staticmethod + def export_fheint8(value: TransportValue, info: TfhersFheIntDescription) -> bytes: + """Convert Concrete value to TFHErs and serialize it. + + Args: + value (Value): value to export + info (TfhersFheIntDescription): description of the TFHErs integer to export to + + Raises: + TypeError: if wrong input types + + Returns: + bytes: converted and serialized fheuint8 + """ + if not isinstance(value, TransportValue): + raise TypeError(f"value must be of type Value, not {type(value)}") + if not isinstance(info, TfhersFheIntDescription): + raise TypeError( + f"info must be of type TfhersFheIntDescription, not {type(info)}" + ) + return bytes(_export_tfhers_fheint8(value, info.cpp())) + + @staticmethod + def import_fheint8( + buffer: bytes, info: TfhersFheIntDescription, keyid: int, variance: float + ) -> TransportValue: + """Unserialize and convert from TFHErs to Concrete value. + + Args: + buffer (bytes): serialized fheuint8 + info (TfhersFheIntDescription): description of the TFHErs integer to import + keyid (int): id of the key used for encryption + variance (float): variance used for encryption + + Raises: + TypeError: if wrong input types + + Returns: + Value: unserialized and converted value + """ + if not isinstance(buffer, bytes): + raise TypeError(f"buffer must be of type bytes, not {type(buffer)}") + if not isinstance(info, TfhersFheIntDescription): + raise TypeError( + f"info must be of type TfhersFheIntDescription, not {type(info)}" + ) + if not isinstance(keyid, int): + raise TypeError(f"keyid must be of type int, not {type(keyid)}") + if not isinstance(variance, float): + raise TypeError(f"variance must be of type float, not {type(variance)}") + return _import_tfhers_fheint8(buffer, info.cpp(), keyid, variance) diff --git a/compilers/concrete-compiler/compiler/lib/ClientLib/ClientLib.cpp b/compilers/concrete-compiler/compiler/lib/ClientLib/ClientLib.cpp index 000ec4dd1f..1ed0493a78 100644 --- a/compilers/concrete-compiler/compiler/lib/ClientLib/ClientLib.cpp +++ b/compilers/concrete-compiler/compiler/lib/ClientLib/ClientLib.cpp @@ -259,5 +259,76 @@ exportTfhersFheUint8(TransportValue value, TfhersFheIntDescription desc) { return buffer; } +Result +importTfhersFheInt8(llvm::ArrayRef serializedFheUint8, + TfhersFheIntDescription desc, uint32_t encryptionKeyId, + double encryptionVariance) { + if (desc.width != 8 || desc.is_signed == false) { + return StringError( + "trying to import FheInt8 but description doesn't match this type"); + } + + auto dims = std::vector({desc.n_cts, desc.lwe_size}); + auto outputTensor = Tensor::fromDimensions(dims); + auto err = concrete_cpu_tfhers_int8_to_lwe_array( + serializedFheUint8.data(), serializedFheUint8.size(), + outputTensor.values.data(), desc); + if (err) { + return StringError("couldn't convert fheint to lwe array"); + } + + auto value = Value{outputTensor}.intoRawTransportValue(); + auto lwe = value.asBuilder().initTypeInfo().initLweCiphertext(); + lwe.setIntegerPrecision(64); + // dimensions + lwe.initAbstractShape().setDimensions({(uint32_t)desc.n_cts}); + lwe.initConcreteShape().setDimensions( + {(uint32_t)desc.n_cts, (uint32_t)desc.lwe_size}); + // encryption + auto encryption = lwe.initEncryption(); + encryption.setLweDimension((uint32_t)desc.lwe_size - 1); + encryption.initModulus().initMod().initNative(); + encryption.setKeyId(encryptionKeyId); + encryption.setVariance(encryptionVariance); + // Encoding + auto encoding = lwe.initEncoding(); + auto integer = encoding.initInteger(); + integer.setIsSigned(false); + integer.setWidth(std::log2(desc.message_modulus * desc.carry_modulus)); + integer.initMode().initNative(); + + return value; +} + +Result> exportTfhersFheInt8(TransportValue value, + TfhersFheIntDescription desc) { + if (desc.width != 8 || desc.is_signed == false) { + return StringError( + "trying to export FheInt8 but description doesn't match this type"); + } + + auto fheuint = Value::fromRawTransportValue(value); + if (fheuint.isScalar()) { + return StringError("expected a tensor, but value is a scalar"); + } + auto tensorOrError = fheuint.getTensor(); + if (!tensorOrError.has_value()) { + return StringError("couldn't get tensor from value"); + } + size_t buffer_size = + concrete_cpu_tfhers_fheint_buffer_size_u64(desc.lwe_size, desc.n_cts); + std::vector buffer(buffer_size, 0); + auto flat_data = tensorOrError.value().values; + auto size = concrete_cpu_lwe_array_to_tfhers_int8( + flat_data.data(), buffer.data(), buffer.size(), desc); + if (size == 0) { + return StringError("couldn't convert lwe array to fheint8"); + } + // we truncate to the serialized data + assert(size <= buffer.size()); + buffer.resize(size, 0); + return buffer; +} + } // namespace clientlib } // namespace concretelang diff --git a/frontends/concrete-python/concrete/fhe/tfhers/bridge.py b/frontends/concrete-python/concrete/fhe/tfhers/bridge.py index 52ffed0af6..27a5811b05 100644 --- a/frontends/concrete-python/concrete/fhe/tfhers/bridge.py +++ b/frontends/concrete-python/concrete/fhe/tfhers/bridge.py @@ -121,11 +121,13 @@ def import_value(self, buffer: bytes, input_idx: int) -> Value: bit_width = input_type.bit_width signed = input_type.is_signed + keyid = self._input_keyid(input_idx) + variance = self._input_variance(input_idx) if bit_width == 8: if not signed: - keyid = self._input_keyid(input_idx) - variance = self._input_variance(input_idx) return Value(TfhersExporter.import_fheuint8(buffer, fheint_desc, keyid, variance)) + else: + return Value(TfhersExporter.import_fheint8(buffer, fheint_desc, keyid, variance)) msg = ( # pragma: no cover f"importing {'signed' if signed else 'unsigned'} integers of {bit_width}bits is not" @@ -157,6 +159,10 @@ def export_value(self, value: Value, output_idx: int) -> bytes: return TfhersExporter.export_fheuint8( value._inner, fheint_desc # pylint: disable=protected-access ) + else: + return TfhersExporter.export_fheint8( + value._inner, fheint_desc # pylint: disable=protected-access + ) msg = ( # pragma: no cover f"exporting value to {'signed' if signed else 'unsigned'} integers of {bit_width}bits" diff --git a/frontends/concrete-python/concrete/fhe/tfhers/dtypes.py b/frontends/concrete-python/concrete/fhe/tfhers/dtypes.py index a4fd342e52..36134526a9 100644 --- a/frontends/concrete-python/concrete/fhe/tfhers/dtypes.py +++ b/frontends/concrete-python/concrete/fhe/tfhers/dtypes.py @@ -146,7 +146,10 @@ def encode(self, value: Union[int, np.integer, list, np.ndarray]) -> np.ndarray: bit_width = self.bit_width msg_width = self.msg_width if isinstance(value, (int, np.integer)): - value_bin = bin(value)[2:].zfill(bit_width) + if self.is_signed and value < 0: + value_bin = bin(2**bit_width + value)[2:].zfill(bit_width) + else: + value_bin = bin(value)[2:].zfill(bit_width) # lsb first return np.array( [int(value_bin[i : i + msg_width], 2) for i in range(0, bit_width, msg_width)][::-1] @@ -201,7 +204,10 @@ def decode(self, value: Union[list, np.ndarray]) -> Union[int, np.ndarray]: if len(value.shape) == 1: # lsb first - return sum(int(v) << (i * msg_width) for i, v in enumerate(value)) + decoded = sum(int(v) << (i * msg_width) for i, v in enumerate(value)) + if self.is_signed and decoded >= 2 ** (bit_width - 1): + decoded = decoded - 2**bit_width + return decoded cts = value.reshape((-1, expected_ct_shape)) return np.array([self.decode(ct) for ct in cts]).reshape(value.shape[:-1]) diff --git a/frontends/concrete-python/tests/execution/test_tfhers.py b/frontends/concrete-python/tests/execution/test_tfhers.py index c56f26ee84..01301fc748 100644 --- a/frontends/concrete-python/tests/execution/test_tfhers.py +++ b/frontends/concrete-python/tests/execution/test_tfhers.py @@ -280,7 +280,7 @@ def lut_add_lut(x, y): return lut[x + y] -TFHERS_INT_8_3_2_4096 = tfhers.TFHERSIntegerType( +TFHERS_UINT_8_3_2_4096 = tfhers.TFHERSIntegerType( False, bit_width=8, carry_width=3, @@ -297,6 +297,23 @@ def lut_add_lut(x, y): ), ) +TFHERS_INT_8_3_2_4096 = tfhers.TFHERSIntegerType( + True, + bit_width=8, + carry_width=3, + msg_width=2, + params=tfhers.CryptoParams( + lwe_dimension=909, + glwe_dimension=1, + polynomial_size=4096, + pbs_base_log=15, + pbs_level=2, + lwe_noise_distribution=0, + glwe_noise_distribution=2.168404344971009e-19, + encryption_key_choice=tfhers.EncryptionKeyChoice.BIG, + ), +) + @pytest.mark.parametrize( "function, parameters, dtype", @@ -307,34 +324,61 @@ def lut_add_lut(x, y): "x": {"range": [0, 2**7 - 1], "status": "encrypted"}, "y": {"range": [0, 2**7 - 1], "status": "encrypted"}, }, - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x + y", ), + pytest.param( + lambda x, y: x + y, + { + "x": {"range": [-(2**6), -2], "status": "encrypted"}, + "y": {"range": [0, 2**6 - 1], "status": "encrypted"}, + }, + TFHERS_INT_8_3_2_4096, + id="signed(x) + signed(y)", + ), pytest.param( lambda x, y: x - y, { "x": {"range": [2**4, 2**7 - 1], "status": "encrypted"}, "y": {"range": [0, 2**4 - 1], "status": "encrypted"}, }, - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x - y", ), + pytest.param( + lambda x, y: x - y, + { + "x": {"range": [-(2**3), -2], "status": "encrypted"}, + "y": {"range": [-(2**3), -2], "status": "encrypted"}, + }, + TFHERS_INT_8_3_2_4096, + id="signed(x) - signed(y)", + ), pytest.param( lambda x, y: x * y, { "x": {"range": [0, 2**3 - 1], "status": "encrypted"}, "y": {"range": [0, 2**3 - 1], "status": "encrypted"}, }, - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x * y", ), + pytest.param( + lambda x, y: x * y, + { + "x": {"range": [-(2**3), 2**2], "status": "encrypted"}, + "y": {"range": [-(2**2), 2**3], "status": "encrypted"}, + }, + TFHERS_INT_8_3_2_4096, + id="signed(x) * signed(y)", + ), pytest.param( lut_add_lut, { "x": {"range": [0, 2**7 - 1], "status": "encrypted"}, "y": {"range": [0, 2**7 - 1], "status": "encrypted"}, }, - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="lut_add_lut", ), ], @@ -412,32 +456,32 @@ def test_tfhers_binary_encrypted_complete_circuit_concrete_keygen( # encrypt inputs and incremnt them by one in TFHErs assert ( os.system( - f"{tfhers_utils} encrypt-with-key -v 1 -c {ct_one_path} --client-key {client_key_path}" + f"{tfhers_utils} encrypt-with-key {'--signed' if dtype.is_signed else ''} --value=1 -c {ct_one_path} --client-key {client_key_path}" ) == 0 ) sample = [s + 1 for s in sample] assert ( os.system( - f"{tfhers_utils} encrypt-with-key -v {ct1} -c {ct1_path} --client-key {client_key_path}" + f"{tfhers_utils} encrypt-with-key {'--signed' if dtype.is_signed else ''} --value={ct1} -c {ct1_path} --client-key {client_key_path}" ) == 0 ) assert ( os.system( - f"{tfhers_utils} encrypt-with-key -v {ct2} -c {ct2_path} --client-key {client_key_path}" + f"{tfhers_utils} encrypt-with-key {'--signed' if dtype.is_signed else ''} --value={ct2} -c {ct2_path} --client-key {client_key_path}" ) == 0 ) assert ( os.system( - f"{tfhers_utils} add -c {ct1_path} {ct_one_path} -s {server_key_path} -o {ct1_path}" + f"{tfhers_utils} add {'--signed' if dtype.is_signed else ''} -c {ct1_path} {ct_one_path} -s {server_key_path} -o {ct1_path}" ) == 0 ) assert ( os.system( - f"{tfhers_utils} add -c {ct2_path} {ct_one_path} -s {server_key_path} -o {ct2_path}" + f"{tfhers_utils} add {'--signed' if dtype.is_signed else ''} -c {ct2_path} {ct_one_path} -s {server_key_path} -o {ct2_path}" ) == 0 ) @@ -468,7 +512,9 @@ def test_tfhers_binary_encrypted_complete_circuit_concrete_keygen( assert ( os.system( - f"{tfhers_utils} decrypt-with-key" f" -c {ct_out_path} --lwe-sk {key_path} -p {pt_path}" + f"{tfhers_utils} decrypt-with-key" + f"{' --signed ' if dtype.is_signed else ''}" + f" -c {ct_out_path} --lwe-sk {key_path} -p {pt_path}" ) == 0 ) @@ -496,25 +542,61 @@ def test_tfhers_binary_encrypted_complete_circuit_concrete_keygen( "x": {"range": [0, 2**7], "status": "encrypted"}, "y": {"range": [0, 2**7], "status": "encrypted"}, }, - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x + y", ), + pytest.param( + lambda x, y: x + y, + { + "x": {"range": [-(2**6), -2], "status": "encrypted"}, + "y": {"range": [0, 2**6 - 1], "status": "encrypted"}, + }, + TFHERS_INT_8_3_2_4096, + id="signed(-x) + signed(y)", + ), + pytest.param( + lambda x, y: x + y, + { + "x": {"range": [0, 2**6 - 1], "status": "encrypted"}, + "y": {"range": [-(2**6), -2], "status": "encrypted"}, + }, + TFHERS_INT_8_3_2_4096, + id="signed(x) + signed(-y)", + ), pytest.param( lambda x, y: x + y, { "x": {"range": [0, 2**7], "status": "encrypted"}, "y": {"range": [0, 2**7], "status": "clear"}, }, - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x + clear(y)", ), + pytest.param( + lambda x, y: x + y, + { + "x": {"range": [-(2**6), -2], "status": "encrypted"}, + "y": {"range": [0, 2**6 - 1], "status": "clear"}, + }, + TFHERS_INT_8_3_2_4096, + id="signed(-x) + clear(y)", + ), + pytest.param( + lambda x, y: x + y, + { + "x": {"range": [0, 2**6 - 1], "status": "encrypted"}, + "y": {"range": [-(2**6), -2], "status": "clear"}, + }, + TFHERS_INT_8_3_2_4096, + id="signed(x) + clear(-y)", + ), pytest.param( lambda x, y: x + y, { "x": {"range": [2**6, 2**7 - 1], "status": "encrypted"}, "y": {"range": [2**6, 2**7 - 1], "status": "encrypted"}, }, - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x + y big values", ), pytest.param( @@ -523,7 +605,7 @@ def test_tfhers_binary_encrypted_complete_circuit_concrete_keygen( "x": {"range": [2**6, 2**7 - 1], "status": "encrypted"}, "y": {"range": [2**6, 2**7 - 1], "status": "clear"}, }, - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x + clear(y) big values", ), pytest.param( @@ -532,16 +614,25 @@ def test_tfhers_binary_encrypted_complete_circuit_concrete_keygen( "x": {"range": [2**4, 2**8 - 1], "status": "encrypted"}, "y": {"range": [0, 2**4], "status": "encrypted"}, }, - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x - y", ), + pytest.param( + lambda x, y: x - y, + { + "x": {"range": [-(2**3), -2], "status": "encrypted"}, + "y": {"range": [-(2**3), -2], "status": "encrypted"}, + }, + TFHERS_INT_8_3_2_4096, + id="signed(x) - signed(y)", + ), pytest.param( lambda x, y: x - y, { "x": {"range": [2**4, 2**8 - 1], "status": "encrypted"}, "y": {"range": [0, 2**4], "status": "clear"}, }, - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x - clear(y)", ), pytest.param( @@ -550,7 +641,7 @@ def test_tfhers_binary_encrypted_complete_circuit_concrete_keygen( "x": {"range": [0, 2**3], "status": "encrypted"}, "y": {"range": [0, 2**3], "status": "encrypted"}, }, - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x * y", ), pytest.param( @@ -559,7 +650,7 @@ def test_tfhers_binary_encrypted_complete_circuit_concrete_keygen( "x": {"range": [0, 2**3], "status": "encrypted"}, "y": {"range": [0, 2**3], "status": "clear"}, }, - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x * clear(y)", ), pytest.param( @@ -568,7 +659,7 @@ def test_tfhers_binary_encrypted_complete_circuit_concrete_keygen( "x": {"range": [0, 2**7], "status": "encrypted"}, "y": {"range": [0, 2**7], "status": "encrypted"}, }, - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="lut_add_lut(x , y)", ), ], @@ -633,7 +724,9 @@ def test_tfhers_one_tfhers_one_native_complete_circuit_concrete_keygen( f"{os.path.dirname(os.path.abspath(__file__))}/../tfhers-utils/target/release/tfhers_utils" ) assert ( - os.system(f"{tfhers_utils} encrypt-with-key -v {ct1} -c {ct1_path} --lwe-sk {key_path}") + os.system( + f"{tfhers_utils} encrypt-with-key {'--signed' if dtype.is_signed else ''} --value={ct1} -c {ct1_path} --lwe-sk {key_path}" + ) == 0 ) @@ -660,7 +753,9 @@ def test_tfhers_one_tfhers_one_native_complete_circuit_concrete_keygen( assert ( os.system( - f"{tfhers_utils} decrypt-with-key" f" -c {ct_out_path} --lwe-sk {key_path} -p {pt_path}" + f"{tfhers_utils} decrypt-with-key" + f"{' --signed ' if dtype.is_signed else ''}" + f" -c {ct_out_path} --lwe-sk {key_path} -p {pt_path}" ) == 0 ) @@ -686,7 +781,7 @@ def test_tfhers_one_tfhers_one_native_complete_circuit_concrete_keygen( "y": {"range": [0, 2**6], "status": "encrypted"}, }, [0, 2**6], - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x + y", ), pytest.param( @@ -696,7 +791,7 @@ def test_tfhers_one_tfhers_one_native_complete_circuit_concrete_keygen( "y": {"range": [0, 2**4], "status": "encrypted"}, }, [0, 2**3], - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x - y", ), pytest.param( @@ -706,7 +801,7 @@ def test_tfhers_one_tfhers_one_native_complete_circuit_concrete_keygen( "y": {"range": [0, 2**3], "status": "encrypted"}, }, [0, 2**2], - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x * y", ), pytest.param( @@ -716,7 +811,7 @@ def test_tfhers_one_tfhers_one_native_complete_circuit_concrete_keygen( "y": {"range": [0, 2**6], "status": "encrypted"}, }, [0, 2**6], - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="lut_add_lut", ), ], @@ -804,10 +899,12 @@ def test_tfhers_binary_encrypted_complete_circuit_tfhers_keygen( f"{os.path.dirname(os.path.abspath(__file__))}/../tfhers-utils/target/release/tfhers_utils" ) assert ( - os.system(f"{tfhers_utils} encrypt-with-key -v {ct1} -c {ct1_path} --lwe-sk {sk_path}") == 0 + os.system(f"{tfhers_utils} encrypt-with-key --value={ct1} -c {ct1_path} --lwe-sk {sk_path}") + == 0 ) assert ( - os.system(f"{tfhers_utils} encrypt-with-key -v {ct2} -c {ct2_path} --lwe-sk {sk_path}") == 0 + os.system(f"{tfhers_utils} encrypt-with-key --value={ct2} -c {ct2_path} --lwe-sk {sk_path}") + == 0 ) # import ciphertexts and run @@ -853,7 +950,7 @@ def test_tfhers_binary_encrypted_complete_circuit_tfhers_keygen( random_value = np.random.randint(*tfhers_value_range) assert ( os.system( - f"{tfhers_utils} encrypt-with-key -v {random_value} -c {random_ct_path} --client-key {client_key_path}" + f"{tfhers_utils} encrypt-with-key --value={random_value} -c {random_ct_path} --client-key {client_key_path}" ) == 0 ) @@ -898,7 +995,7 @@ def test_tfhers_binary_encrypted_complete_circuit_tfhers_keygen( "y": {"range": [0, 2**6], "status": "encrypted"}, }, [0, 2**6], - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x + y", ), pytest.param( @@ -908,7 +1005,7 @@ def test_tfhers_binary_encrypted_complete_circuit_tfhers_keygen( "y": {"range": [0, 2**6], "status": "clear"}, }, [0, 2**6], - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x + clear(y)", ), pytest.param( @@ -918,7 +1015,7 @@ def test_tfhers_binary_encrypted_complete_circuit_tfhers_keygen( "y": {"range": [2**5, 2**6], "status": "encrypted"}, }, [0, 2**6], - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x + y big values", ), pytest.param( @@ -928,7 +1025,7 @@ def test_tfhers_binary_encrypted_complete_circuit_tfhers_keygen( "y": {"range": [2**5, 2**6], "status": "clear"}, }, [0, 2**6], - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x + clear(y) big values", ), pytest.param( @@ -938,7 +1035,7 @@ def test_tfhers_binary_encrypted_complete_circuit_tfhers_keygen( "y": {"range": [0, 2**4], "status": "encrypted"}, }, [0, 2**3], - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x - y", ), pytest.param( @@ -948,7 +1045,7 @@ def test_tfhers_binary_encrypted_complete_circuit_tfhers_keygen( "y": {"range": [0, 2**4], "status": "clear"}, }, [0, 2**3], - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x - clear(y)", ), pytest.param( @@ -958,7 +1055,7 @@ def test_tfhers_binary_encrypted_complete_circuit_tfhers_keygen( "y": {"range": [0, 2**3], "status": "encrypted"}, }, [0, 2**2], - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x * y", ), pytest.param( @@ -968,7 +1065,7 @@ def test_tfhers_binary_encrypted_complete_circuit_tfhers_keygen( "y": {"range": [0, 2**3], "status": "clear"}, }, [0, 2**2], - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x * clear(y)", ), pytest.param( @@ -978,7 +1075,7 @@ def test_tfhers_binary_encrypted_complete_circuit_tfhers_keygen( "y": {"range": [0, 2**6], "status": "encrypted"}, }, [0, 2**6], - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="lut_add_lut(x , y)", ), ], @@ -1064,7 +1161,7 @@ def test_tfhers_one_tfhers_one_native_complete_circuit_tfhers_keygen( assert ( os.system( - f"{tfhers_utils} encrypt-with-key -v {pt1} -c {ct1_path} --client-key {client_key_path}" + f"{tfhers_utils} encrypt-with-key --value={pt1} -c {ct1_path} --client-key {client_key_path}" ) == 0 ) @@ -1109,7 +1206,7 @@ def test_tfhers_one_tfhers_one_native_complete_circuit_tfhers_keygen( random_value = np.random.randint(*tfhers_value_range) assert ( os.system( - f"{tfhers_utils} encrypt-with-key -v {random_value} -c {random_ct_path} --client-key {client_key_path}" + f"{tfhers_utils} encrypt-with-key --value={random_value} -c {random_ct_path} --client-key {client_key_path}" ) == 0 ) @@ -1178,6 +1275,15 @@ def test_tfhers_integer_eq(lhs, rhs, is_equal): [2, 3, 1, 0, 0, 0, 0, 0], ], ), + pytest.param( + tfhers.int8_2_2, + [-128, 0, 127], + [ + [0, 0, 0, 2], + [0, 0, 0, 0], + [3, 3, 3, 1], + ], + ), ], ) def test_tfhers_integer_encode(dtype, value, encoded): @@ -1223,6 +1329,15 @@ def test_tfhers_integer_bad_encode(dtype, value, expected_error, expected_messag ], [10, 20, 30], ), + pytest.param( + tfhers.int8_2_2, + [ + [2, 1, 0, 2], + [0, 3, 1, 0], + [2, 1, 0, 1], + ], + [-122, 28, 70], + ), ], ) def test_tfhers_integer_decode(dtype, encoded, decoded): diff --git a/frontends/concrete-python/tests/tfhers-utils/src/main.rs b/frontends/concrete-python/tests/tfhers-utils/src/main.rs index 5f363f1bd3..07095d9e1c 100644 --- a/frontends/concrete-python/tests/tfhers-utils/src/main.rs +++ b/frontends/concrete-python/tests/tfhers-utils/src/main.rs @@ -4,9 +4,9 @@ use serde::de::DeserializeOwned; use std::fs; use std::path::Path; use tfhe::core_crypto::prelude::LweSecretKey; -use tfhe::{prelude::*}; +use tfhe::prelude::*; use tfhe::shortint::{ClassicPBSParameters, EncryptionKeyChoice}; -use tfhe::{generate_keys, set_server_key, ClientKey, ConfigBuilder, FheUint8, ServerKey}; +use tfhe::{generate_keys, set_server_key, ClientKey, ConfigBuilder, FheInt8, FheUint8, ServerKey}; use serde::Serialize; use tfhe::named::Named; @@ -48,37 +48,61 @@ fn set_server_key_from_file(path: &String) { set_server_key(sk); } -fn encrypt_with_key(value: u8, client_key: ClientKey, path: &String) { +fn encrypt_with_key_u8(value: u8, client_key: ClientKey, ciphertext_path: &String) { let ct = FheUint8::encrypt(value, &client_key); - safe_save(path, &ct) + safe_save(ciphertext_path, &ct) +} + +fn encrypt_with_key_i8(value: i8, client_key: ClientKey, ciphertext_path: &String) { + let ct = FheInt8::encrypt(value, &client_key); + safe_save(ciphertext_path, &ct) } fn decrypt_with_key( client_key: ClientKey, ciphertext_path: &String, plaintext_path: Option<&String>, + signed: bool, ) { - let fheuint: FheUint8 = safe_load(ciphertext_path); - let result: u8 = fheuint.decrypt(&client_key); + let string_result: String; + + if signed { + let fheint: FheInt8 = safe_load(ciphertext_path); + let result: i8 = fheint.decrypt(&client_key); + string_result = result.to_string(); + } else { + let fheuint: FheUint8 = safe_load(ciphertext_path); + let result: u8 = fheuint.decrypt(&client_key); + string_result = result.to_string(); + } if let Some(path) = plaintext_path { let pt_path: &Path = Path::new(path); - fs::write(pt_path, result.to_string()).unwrap(); + fs::write(pt_path, string_result).unwrap(); } else { - println!("result: {}", result); + println!("result: {}", string_result); } } -fn sum(cts_paths: Vec<&String>, out_ct_path: &String) { +fn sum(cts_paths: Vec<&String>, out_ct_path: &String, signed: bool) { if cts_paths.is_empty() { panic!("can't call sum with 0 ciphertexts"); } - let mut acc: FheUint8 = safe_load(cts_paths[0]); - for ct_path in cts_paths[1..].iter() { - let fheuint: FheUint8 = safe_load(ct_path); - acc += fheuint; + if signed { + let mut acc: FheInt8 = safe_load(cts_paths[0]); + for ct_path in cts_paths[1..].iter() { + let fheuint: FheInt8 = safe_load(ct_path); + acc += fheuint; + } + safe_save(out_ct_path, &acc) + } else { + let mut acc: FheUint8 = safe_load(cts_paths[0]); + for ct_path in cts_paths[1..].iter() { + let fheuint: FheUint8 = safe_load(ct_path); + acc += fheuint; + } + safe_save(out_ct_path, &acc) } - safe_save(out_ct_path, &acc) } fn write_keys( @@ -102,11 +126,7 @@ fn write_keys( } } -fn keygen( - client_key_path: &String, - server_key_path: &String, - output_lwe_path: &String, -) { +fn keygen(client_key_path: &String, server_key_path: &String, output_lwe_path: &String) { let config = ConfigBuilder::with_custom_parameters(BLOCK_PARAMS).build(); let (client_key, server_key) = generate_keys(config); @@ -153,6 +173,12 @@ fn main() { .required(true) .num_args(1), ) + .arg( + Arg::new("signed") + .long("signed") + .help("encrypt as a signed integer") + .action(ArgAction::SetTrue), + ) .arg( Arg::new("ciphertext") .short('c') @@ -186,6 +212,12 @@ fn main() { .short_flag('d') .long_flag("decrypt") .about("Decrypt a ciphertext with a given key.") + .arg( + Arg::new("signed") + .long("signed") + .help("decrypt as a signed integer") + .action(ArgAction::SetTrue), + ) .arg( Arg::new("ciphertext") .short('c') @@ -236,6 +268,12 @@ fn main() { .action(ArgAction::Set) .num_args(1), ) + .arg( + Arg::new("signed") + .long("signed") + .help("consider ciphertexts as signed integers") + .action(ArgAction::SetTrue), + ) .arg( Arg::new("ciphertexts") .short('c') @@ -300,8 +338,8 @@ fn main() { match matches.subcommand() { Some(("encrypt-with-key", encrypt_matches)) => { let value_str = encrypt_matches.get_one::("value").unwrap(); - let value: u8 = value_str.parse().unwrap(); let ciphertext_path = encrypt_matches.get_one::("ciphertext").unwrap(); + let signed = encrypt_matches.get_flag("signed"); let client_key: ClientKey; if let Some(lwe_sk_path) = encrypt_matches.get_one::("lwe-sk") { @@ -312,11 +350,18 @@ fn main() { panic!("no key specified"); } - encrypt_with_key(value, client_key, ciphertext_path) + if signed { + let value: i8 = value_str.parse().unwrap(); + encrypt_with_key_i8(value, client_key, ciphertext_path) + } else { + let value: u8 = value_str.parse().unwrap(); + encrypt_with_key_u8(value, client_key, ciphertext_path) + } } Some(("decrypt-with-key", decrypt_mtches)) => { let ciphertext_path = decrypt_mtches.get_one::("ciphertext").unwrap(); let plaintext_path = decrypt_mtches.get_one::("plaintext"); + let signed = decrypt_mtches.get_flag("signed"); let client_key: ClientKey; if let Some(lwe_sk_path) = decrypt_mtches.get_one::("lwe-sk") { @@ -326,16 +371,17 @@ fn main() { } else { panic!("no key specified"); } - decrypt_with_key(client_key, ciphertext_path, plaintext_path) + decrypt_with_key(client_key, ciphertext_path, plaintext_path, signed) } Some(("add", add_mtches)) => { let server_key_path = add_mtches.get_one::("server-key").unwrap(); let cts_path = add_mtches.get_many::("ciphertexts").unwrap(); let output_ct_path = add_mtches.get_one::("output-ciphertext").unwrap(); + let signed = add_mtches.get_flag("signed"); set_server_key_from_file(server_key_path); - sum(cts_path.collect(), output_ct_path) + sum(cts_path.collect(), output_ct_path, signed) } Some(("keygen", keygen_mtches)) => { let client_key_path = keygen_mtches.get_one::("client-key").unwrap(); From e5c643ec14f3642b3d92d0cf63c7e24cb5665eaf Mon Sep 17 00:00:00 2001 From: youben11 Date: Mon, 14 Oct 2024 11:39:20 +0100 Subject: [PATCH 2/4] fix(frontend): use esint when tfhers type is signed --- .../concrete/fhe/mlir/converter.py | 5 +++- .../tests/execution/test_tfhers.py | 29 +++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/frontends/concrete-python/concrete/fhe/mlir/converter.py b/frontends/concrete-python/concrete/fhe/mlir/converter.py index ff4c97e223..cbd174c2f9 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/converter.py +++ b/frontends/concrete-python/concrete/fhe/mlir/converter.py @@ -959,7 +959,10 @@ def tfhers_to_native(self, ctx: Context, node: Node, preds: List[Conversion]) -> # sum will remove the last dim which is the dim of ciphertexts result_shape = tfhers_int.shape[:-1] # if result_shape is () then ctx.tensor would return a scalar type - result_type = ctx.tensor(ctx.eint(result_bit_width), result_shape) + result_type = ctx.tensor( + ctx.esint(result_bit_width) if dtype.is_signed else ctx.eint(result_bit_width), + result_shape, + ) return ctx.sum(result_type, mapped, axes=-1) def tfhers_from_native(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion: diff --git a/frontends/concrete-python/tests/execution/test_tfhers.py b/frontends/concrete-python/tests/execution/test_tfhers.py index 01301fc748..ce0eab473c 100644 --- a/frontends/concrete-python/tests/execution/test_tfhers.py +++ b/frontends/concrete-python/tests/execution/test_tfhers.py @@ -363,6 +363,16 @@ def lut_add_lut(x, y): TFHERS_UINT_8_3_2_4096, id="x * y", ), + # FIXME: doesn't work when mul bitwidth is smaller than tfhers bitwidth, and result is negative + pytest.param( + lambda x, y: x * y, + { + "x": {"range": [-(2**3), 0], "status": "encrypted"}, + "y": {"range": [-(2**3), 0], "status": "encrypted"}, + }, + TFHERS_INT_8_3_2_4096, + id="signed(x) * signed(y)", + ), pytest.param( lambda x, y: x * y, { @@ -653,6 +663,25 @@ def test_tfhers_binary_encrypted_complete_circuit_concrete_keygen( TFHERS_UINT_8_3_2_4096, id="x * clear(y)", ), + # FIXME: doesn't work when mul bitwidth is smaller than tfhers bitwidth, and result is negative + pytest.param( + lambda x, y: x * y, + { + "x": {"range": [-(2**3), -(2**3)], "status": "encrypted"}, + "y": {"range": [2**4, 2**4], "status": "encrypted"}, + }, + TFHERS_INT_8_3_2_4096, + id="signed(x) * signed(y)", + ), + pytest.param( + lambda x, y: x * y, + { + "x": {"range": [0, 2**3], "status": "encrypted"}, + "y": {"range": [-(2**3), 0], "status": "clear"}, + }, + TFHERS_INT_8_3_2_4096, + id="signed(x) * clear(-y)", + ), pytest.param( lut_add_lut, { From b0e7c085ea109bb7935ae6366b1ef5784a441ed7 Mon Sep 17 00:00:00 2001 From: youben11 Date: Tue, 22 Oct 2024 11:41:39 +0100 Subject: [PATCH 3/4] fix(frontend): consider padding bit during to_native conversion --- .../concrete/fhe/mlir/converter.py | 26 ++++++++++++++++++- .../tests/execution/test_tfhers.py | 10 +++---- 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/frontends/concrete-python/concrete/fhe/mlir/converter.py b/frontends/concrete-python/concrete/fhe/mlir/converter.py index cbd174c2f9..f5cf6211bb 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/converter.py +++ b/frontends/concrete-python/concrete/fhe/mlir/converter.py @@ -963,7 +963,31 @@ def tfhers_to_native(self, ctx: Context, node: Node, preds: List[Conversion]) -> ctx.esint(result_bit_width) if dtype.is_signed else ctx.eint(result_bit_width), result_shape, ) - return ctx.sum(result_type, mapped, axes=-1) + sum_result = ctx.sum(result_type, mapped, axes=-1) + + # we want to set the padding bit if the native type is signed + # and the ciphertext is negative (sign bit set to 1) + if dtype.is_signed: + # select MSBs of all tfhers ciphetexts + index = [slice(0, dim_size) for dim_size in tfhers_int.shape[:-1]] + [ + -1, + ] + msbs = ctx.index( + ctx.tensor(ctx.eint(msg_width + carry_width), tfhers_int.shape[:-1]), + tfhers_int, + index=index, + ) + # construct padding bits based on sign bits (carry would be considered negative) + padding_bit_table = [ + 0, + ] * 2 ** (msg_width - 1) + [ + 2**result_bit_width, + ] * (2 ** (carry_width + msg_width) - 2 ** (msg_width - 1)) + padding_bits_inc = ctx.tlu(result_type, msbs, padding_bit_table) + # set padding bits (where necessary) in the final result + return ctx.add(result_type, sum_result, padding_bits_inc) + + return sum_result def tfhers_from_native(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion: assert len(preds) == 1 diff --git a/frontends/concrete-python/tests/execution/test_tfhers.py b/frontends/concrete-python/tests/execution/test_tfhers.py index ce0eab473c..c747916084 100644 --- a/frontends/concrete-python/tests/execution/test_tfhers.py +++ b/frontends/concrete-python/tests/execution/test_tfhers.py @@ -363,12 +363,11 @@ def lut_add_lut(x, y): TFHERS_UINT_8_3_2_4096, id="x * y", ), - # FIXME: doesn't work when mul bitwidth is smaller than tfhers bitwidth, and result is negative pytest.param( lambda x, y: x * y, { - "x": {"range": [-(2**3), 0], "status": "encrypted"}, - "y": {"range": [-(2**3), 0], "status": "encrypted"}, + "x": {"range": [-(2**3), 2**2], "status": "encrypted"}, + "y": {"range": [-(2**2), 2**3], "status": "encrypted"}, }, TFHERS_INT_8_3_2_4096, id="signed(x) * signed(y)", @@ -663,12 +662,11 @@ def test_tfhers_binary_encrypted_complete_circuit_concrete_keygen( TFHERS_UINT_8_3_2_4096, id="x * clear(y)", ), - # FIXME: doesn't work when mul bitwidth is smaller than tfhers bitwidth, and result is negative pytest.param( lambda x, y: x * y, { - "x": {"range": [-(2**3), -(2**3)], "status": "encrypted"}, - "y": {"range": [2**4, 2**4], "status": "encrypted"}, + "x": {"range": [-(2**3), 2], "status": "encrypted"}, + "y": {"range": [-2, 2**4], "status": "encrypted"}, }, TFHERS_INT_8_3_2_4096, id="signed(x) * signed(y)", From 514fe62364f77f03398b03daea728cfb9fc96aca Mon Sep 17 00:00:00 2001 From: youben11 Date: Thu, 24 Oct 2024 11:43:04 +0100 Subject: [PATCH 4/4] refactor(frontend/compiler): single API for import/export of TFHErs int --- .../concretelang/ClientLib/ClientLib.h | 19 +-- .../lib/Bindings/Python/CompilerAPIModule.cpp | 35 +--- .../Python/concrete/compiler/tfhers_int.py | 70 +------- .../compiler/lib/ClientLib/ClientLib.cpp | 153 +++++++----------- .../concrete/fhe/tfhers/bridge.py | 34 +--- 5 files changed, 79 insertions(+), 232 deletions(-) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/ClientLib.h b/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/ClientLib.h index 99f39a248a..1aacf41efa 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/ClientLib.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/ClientLib.h @@ -34,18 +34,13 @@ using concretelang::values::Value; namespace concretelang { namespace clientlib { -Result -importTfhersFheUint8(llvm::ArrayRef serializedFheUint8, - TfhersFheIntDescription desc, uint32_t encryptionKeyId, - double encryptionVariance); -Result> exportTfhersFheUint8(TransportValue value, - TfhersFheIntDescription info); -Result -importTfhersFheInt8(llvm::ArrayRef serializedFheUint8, - TfhersFheIntDescription desc, uint32_t encryptionKeyId, - double encryptionVariance); -Result> exportTfhersFheInt8(TransportValue value, - TfhersFheIntDescription info); +Result importTfhersInteger(llvm::ArrayRef buffer, + TfhersFheIntDescription integerDesc, + uint32_t encryptionKeyId, + double encryptionVariance); + +Result> +exportTfhersInteger(TransportValue value, TfhersFheIntDescription integerDesc); class ClientCircuit { diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index 9e90e1f2d6..bbd12d4067 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -1881,14 +1881,14 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( "Return the `circuit` ClientCircuit.", arg("circuit")) .doc() = "Client-side / Encryption program"; - m.def("import_tfhers_fheuint8", + m.def("import_tfhers_int", [](const pybind11::bytes &serialized_fheuint, TfhersFheIntDescription info, uint32_t encryptionKeyId, double encryptionVariance) { const std::string &buffer_str = serialized_fheuint; std::vector buffer(buffer_str.begin(), buffer_str.end()); auto arrayRef = llvm::ArrayRef(buffer); - auto valueOrError = ::concretelang::clientlib::importTfhersFheUint8( + auto valueOrError = ::concretelang::clientlib::importTfhersInteger( arrayRef, info, encryptionKeyId, encryptionVariance); if (valueOrError.has_error()) { throw std::runtime_error(valueOrError.error().mesg); @@ -1896,34 +1896,9 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( return TransportValue{valueOrError.value()}; }); - m.def("export_tfhers_fheuint8", - [](TransportValue fheuint, TfhersFheIntDescription info) { - auto result = - ::concretelang::clientlib::exportTfhersFheUint8(fheuint, info); - if (result.has_error()) { - throw std::runtime_error(result.error().mesg); - } - return result.value(); - }); - - m.def("import_tfhers_fheint8", - [](const pybind11::bytes &serialized_fheuint, - TfhersFheIntDescription info, uint32_t encryptionKeyId, - double encryptionVariance) { - const std::string &buffer_str = serialized_fheuint; - std::vector buffer(buffer_str.begin(), buffer_str.end()); - auto arrayRef = llvm::ArrayRef(buffer); - auto valueOrError = ::concretelang::clientlib::importTfhersFheInt8( - arrayRef, info, encryptionKeyId, encryptionVariance); - if (valueOrError.has_error()) { - throw std::runtime_error(valueOrError.error().mesg); - } - return TransportValue{valueOrError.value()}; - }); - - m.def("export_tfhers_fheint8", [](TransportValue fheuint, - TfhersFheIntDescription info) { - auto result = ::concretelang::clientlib::exportTfhersFheInt8(fheuint, info); + m.def("export_tfhers_int", [](TransportValue fheuint, + TfhersFheIntDescription info) { + auto result = ::concretelang::clientlib::exportTfhersInteger(fheuint, info); if (result.has_error()) { throw std::runtime_error(result.error().mesg); } diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/tfhers_int.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/tfhers_int.py index 89cfd38c9b..9491743c2e 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/tfhers_int.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/tfhers_int.py @@ -3,10 +3,8 @@ # pylint: disable=no-name-in-module,import-error, from mlir._mlir_libs._concretelang._compiler import ( - import_tfhers_fheuint8 as _import_tfhers_fheuint8, - export_tfhers_fheuint8 as _export_tfhers_fheuint8, - import_tfhers_fheint8 as _import_tfhers_fheint8, - export_tfhers_fheint8 as _export_tfhers_fheint8, + import_tfhers_int as _import_tfhers_int, + export_tfhers_int as _export_tfhers_int, TfhersFheIntDescription as _TfhersFheIntDescription, TransportValue, ) @@ -184,7 +182,7 @@ class TfhersExporter: """A helper class to import and export TFHErs big integers.""" @staticmethod - def export_fheuint8(value: TransportValue, info: TfhersFheIntDescription) -> bytes: + def export_int(value: TransportValue, info: TfhersFheIntDescription) -> bytes: """Convert Concrete value to TFHErs and serialize it. Args: @@ -195,7 +193,7 @@ def export_fheuint8(value: TransportValue, info: TfhersFheIntDescription) -> byt TypeError: if wrong input types Returns: - bytes: converted and serialized fheuint8 + bytes: converted and serialized TFHErs integer """ if not isinstance(value, TransportValue): raise TypeError(f"value must be of type TransportValue, not {type(value)}") @@ -203,16 +201,16 @@ def export_fheuint8(value: TransportValue, info: TfhersFheIntDescription) -> byt raise TypeError( f"info must be of type TfhersFheIntDescription, not {type(info)}" ) - return bytes(_export_tfhers_fheuint8(value, info.cpp())) + return bytes(_export_tfhers_int(value, info.cpp())) @staticmethod - def import_fheuint8( + def import_int( buffer: bytes, info: TfhersFheIntDescription, keyid: int, variance: float ) -> TransportValue: """Unserialize and convert from TFHErs to Concrete value. Args: - buffer (bytes): serialized fheuint8 + buffer (bytes): serialized TFHErs integer info (TfhersFheIntDescription): description of the TFHErs integer to import keyid (int): id of the key used for encryption variance (float): variance used for encryption @@ -233,56 +231,4 @@ def import_fheuint8( raise TypeError(f"keyid must be of type int, not {type(keyid)}") if not isinstance(variance, float): raise TypeError(f"variance must be of type float, not {type(variance)}") - return _import_tfhers_fheuint8(buffer, info.cpp(), keyid, variance) - - @staticmethod - def export_fheint8(value: TransportValue, info: TfhersFheIntDescription) -> bytes: - """Convert Concrete value to TFHErs and serialize it. - - Args: - value (Value): value to export - info (TfhersFheIntDescription): description of the TFHErs integer to export to - - Raises: - TypeError: if wrong input types - - Returns: - bytes: converted and serialized fheuint8 - """ - if not isinstance(value, TransportValue): - raise TypeError(f"value must be of type Value, not {type(value)}") - if not isinstance(info, TfhersFheIntDescription): - raise TypeError( - f"info must be of type TfhersFheIntDescription, not {type(info)}" - ) - return bytes(_export_tfhers_fheint8(value, info.cpp())) - - @staticmethod - def import_fheint8( - buffer: bytes, info: TfhersFheIntDescription, keyid: int, variance: float - ) -> TransportValue: - """Unserialize and convert from TFHErs to Concrete value. - - Args: - buffer (bytes): serialized fheuint8 - info (TfhersFheIntDescription): description of the TFHErs integer to import - keyid (int): id of the key used for encryption - variance (float): variance used for encryption - - Raises: - TypeError: if wrong input types - - Returns: - Value: unserialized and converted value - """ - if not isinstance(buffer, bytes): - raise TypeError(f"buffer must be of type bytes, not {type(buffer)}") - if not isinstance(info, TfhersFheIntDescription): - raise TypeError( - f"info must be of type TfhersFheIntDescription, not {type(info)}" - ) - if not isinstance(keyid, int): - raise TypeError(f"keyid must be of type int, not {type(keyid)}") - if not isinstance(variance, float): - raise TypeError(f"variance must be of type float, not {type(variance)}") - return _import_tfhers_fheint8(buffer, info.cpp(), keyid, variance) + return _import_tfhers_int(buffer, info.cpp(), keyid, variance) diff --git a/compilers/concrete-compiler/compiler/lib/ClientLib/ClientLib.cpp b/compilers/concrete-compiler/compiler/lib/ClientLib/ClientLib.cpp index 1ed0493a78..3fee8e9586 100644 --- a/compilers/concrete-compiler/compiler/lib/ClientLib/ClientLib.cpp +++ b/compilers/concrete-compiler/compiler/lib/ClientLib/ClientLib.cpp @@ -187,92 +187,34 @@ Result ClientProgram::getClientCircuit(std::string circuitName) { "`"); } -Result -importTfhersFheUint8(llvm::ArrayRef serializedFheUint8, - TfhersFheIntDescription desc, uint32_t encryptionKeyId, - double encryptionVariance) { - if (desc.width != 8 || desc.is_signed == true) { - return StringError( - "trying to import FheUint8 but description doesn't match this type"); - } - - auto dims = std::vector({desc.n_cts, desc.lwe_size}); - auto outputTensor = Tensor::fromDimensions(dims); - auto err = concrete_cpu_tfhers_uint8_to_lwe_array( - serializedFheUint8.data(), serializedFheUint8.size(), - outputTensor.values.data(), desc); - if (err) { - return StringError("couldn't convert fheuint to lwe array: err()") - << err << ")"; - } - - auto value = Value{outputTensor}.intoRawTransportValue(); - auto lwe = value.asBuilder().initTypeInfo().initLweCiphertext(); - lwe.setIntegerPrecision(64); - // dimensions - lwe.initAbstractShape().setDimensions({(uint32_t)desc.n_cts}); - lwe.initConcreteShape().setDimensions( - {(uint32_t)desc.n_cts, (uint32_t)desc.lwe_size}); - // encryption - auto encryption = lwe.initEncryption(); - encryption.setLweDimension((uint32_t)desc.lwe_size - 1); - encryption.initModulus().initMod().initNative(); - encryption.setKeyId(encryptionKeyId); - encryption.setVariance(encryptionVariance); - // Encoding - auto encoding = lwe.initEncoding(); - auto integer = encoding.initInteger(); - integer.setIsSigned(false); - integer.setWidth(std::log2(desc.message_modulus * desc.carry_modulus)); - integer.initMode().initNative(); - - return value; -} - -Result> -exportTfhersFheUint8(TransportValue value, TfhersFheIntDescription desc) { - if (desc.width != 8 || desc.is_signed == true) { - return StringError( - "trying to export FheUint8 but description doesn't match this type"); - } - - auto fheuint = Value::fromRawTransportValue(value); - if (fheuint.isScalar()) { - return StringError("expected a tensor, but value is a scalar"); - } - auto tensorOrError = fheuint.getTensor(); - if (!tensorOrError.has_value()) { - return StringError("couldn't get tensor from value"); - } - const size_t bufferSize = - concrete_cpu_tfhers_fheint_buffer_size_u64(desc.lwe_size, desc.n_cts); - std::vector buffer(bufferSize, 0); - auto flatData = tensorOrError.value().values; - auto size = concrete_cpu_lwe_array_to_tfhers_uint8( - flatData.data(), buffer.data(), buffer.size(), desc); - if (size == 0) { - return StringError("couldn't convert lwe array to fheuint8"); - } - // we truncate to the serialized data - assert(size <= buffer.size()); - buffer.resize(size, 0); - return buffer; -} - -Result -importTfhersFheInt8(llvm::ArrayRef serializedFheUint8, - TfhersFheIntDescription desc, uint32_t encryptionKeyId, - double encryptionVariance) { - if (desc.width != 8 || desc.is_signed == false) { - return StringError( - "trying to import FheInt8 but description doesn't match this type"); +Result importTfhersInteger(llvm::ArrayRef buffer, + TfhersFheIntDescription integerDesc, + uint32_t encryptionKeyId, + double encryptionVariance) { + + // Select conversion function based on integer description + std::function + conversion_func; + if (integerDesc.width == 8) { + if (integerDesc.is_signed) { // fheint8 + conversion_func = concrete_cpu_tfhers_int8_to_lwe_array; + } else { // fheuint8 + conversion_func = concrete_cpu_tfhers_uint8_to_lwe_array; + } + } else { + std::ostringstream stringStream; + stringStream << "importTfhersInteger: no support for " << integerDesc.width + << "bits " << (integerDesc.is_signed ? "signed" : "unsigned") + << " integer"; + std::string errorMsg = stringStream.str(); + return StringError(errorMsg); } - auto dims = std::vector({desc.n_cts, desc.lwe_size}); + auto dims = std::vector({integerDesc.n_cts, integerDesc.lwe_size}); auto outputTensor = Tensor::fromDimensions(dims); - auto err = concrete_cpu_tfhers_int8_to_lwe_array( - serializedFheUint8.data(), serializedFheUint8.size(), - outputTensor.values.data(), desc); + auto err = conversion_func(buffer.data(), buffer.size(), + outputTensor.values.data(), integerDesc); if (err) { return StringError("couldn't convert fheint to lwe array"); } @@ -281,30 +223,47 @@ importTfhersFheInt8(llvm::ArrayRef serializedFheUint8, auto lwe = value.asBuilder().initTypeInfo().initLweCiphertext(); lwe.setIntegerPrecision(64); // dimensions - lwe.initAbstractShape().setDimensions({(uint32_t)desc.n_cts}); + lwe.initAbstractShape().setDimensions({(uint32_t)integerDesc.n_cts}); lwe.initConcreteShape().setDimensions( - {(uint32_t)desc.n_cts, (uint32_t)desc.lwe_size}); + {(uint32_t)integerDesc.n_cts, (uint32_t)integerDesc.lwe_size}); // encryption auto encryption = lwe.initEncryption(); - encryption.setLweDimension((uint32_t)desc.lwe_size - 1); + encryption.setLweDimension((uint32_t)integerDesc.lwe_size - 1); encryption.initModulus().initMod().initNative(); encryption.setKeyId(encryptionKeyId); encryption.setVariance(encryptionVariance); // Encoding auto encoding = lwe.initEncoding(); auto integer = encoding.initInteger(); - integer.setIsSigned(false); - integer.setWidth(std::log2(desc.message_modulus * desc.carry_modulus)); + integer.setIsSigned( + false); // should always be unsigned as its for the radix encoded cts + integer.setWidth( + std::log2(integerDesc.message_modulus * integerDesc.carry_modulus)); integer.initMode().initNative(); return value; } -Result> exportTfhersFheInt8(TransportValue value, - TfhersFheIntDescription desc) { - if (desc.width != 8 || desc.is_signed == false) { - return StringError( - "trying to export FheInt8 but description doesn't match this type"); +Result> +exportTfhersInteger(TransportValue value, TfhersFheIntDescription integerDesc) { + // Select conversion function based on integer description + std::function + conversion_func; + std::function buffer_size_func; + if (integerDesc.width == 8) { + if (integerDesc.is_signed) { // fheint8 + conversion_func = concrete_cpu_lwe_array_to_tfhers_int8; + } else { // fheuint8 + conversion_func = concrete_cpu_lwe_array_to_tfhers_uint8; + } + } else { + std::ostringstream stringStream; + stringStream << "exportTfhersInteger: no support for " << integerDesc.width + << "bits " << (integerDesc.is_signed ? "signed" : "unsigned") + << " integer"; + std::string errorMsg = stringStream.str(); + return StringError(errorMsg); } auto fheuint = Value::fromRawTransportValue(value); @@ -315,12 +274,12 @@ Result> exportTfhersFheInt8(TransportValue value, if (!tensorOrError.has_value()) { return StringError("couldn't get tensor from value"); } - size_t buffer_size = - concrete_cpu_tfhers_fheint_buffer_size_u64(desc.lwe_size, desc.n_cts); + size_t buffer_size = concrete_cpu_tfhers_fheint_buffer_size_u64( + integerDesc.lwe_size, integerDesc.n_cts); std::vector buffer(buffer_size, 0); auto flat_data = tensorOrError.value().values; - auto size = concrete_cpu_lwe_array_to_tfhers_int8( - flat_data.data(), buffer.data(), buffer.size(), desc); + auto size = conversion_func(flat_data.data(), buffer.data(), buffer.size(), + integerDesc); if (size == 0) { return StringError("couldn't convert lwe array to fheint8"); } diff --git a/frontends/concrete-python/concrete/fhe/tfhers/bridge.py b/frontends/concrete-python/concrete/fhe/tfhers/bridge.py index 27a5811b05..de08b36842 100644 --- a/frontends/concrete-python/concrete/fhe/tfhers/bridge.py +++ b/frontends/concrete-python/concrete/fhe/tfhers/bridge.py @@ -118,22 +118,9 @@ def import_value(self, buffer: bytes, input_idx: int) -> Value: raise ValueError(msg) fheint_desc = self._description_from_type(input_type) - - bit_width = input_type.bit_width - signed = input_type.is_signed keyid = self._input_keyid(input_idx) variance = self._input_variance(input_idx) - if bit_width == 8: - if not signed: - return Value(TfhersExporter.import_fheuint8(buffer, fheint_desc, keyid, variance)) - else: - return Value(TfhersExporter.import_fheint8(buffer, fheint_desc, keyid, variance)) - - msg = ( # pragma: no cover - f"importing {'signed' if signed else 'unsigned'} integers of {bit_width}bits is not" - " yet supported" - ) - raise NotImplementedError(msg) # pragma: no cover + return Value(TfhersExporter.import_int(buffer, fheint_desc, keyid, variance)) def export_value(self, value: Value, output_idx: int) -> bytes: """Export a value as a serialized TFHErs integer. @@ -151,24 +138,9 @@ def export_value(self, value: Value, output_idx: int) -> bytes: raise ValueError(msg) fheint_desc = self._description_from_type(output_type) - - bit_width = output_type.bit_width - signed = output_type.is_signed - if bit_width == 8: - if not signed: - return TfhersExporter.export_fheuint8( - value._inner, fheint_desc # pylint: disable=protected-access - ) - else: - return TfhersExporter.export_fheint8( - value._inner, fheint_desc # pylint: disable=protected-access - ) - - msg = ( # pragma: no cover - f"exporting value to {'signed' if signed else 'unsigned'} integers of {bit_width}bits" - " is not yet supported" + return TfhersExporter.export_int( + value._inner, fheint_desc # pylint: disable=protected-access ) - raise NotImplementedError(msg) # pragma: no cover def serialize_input_secret_key(self, input_idx: int) -> bytes: """Serialize secret key used for a specific input.