diff --git a/backends/concrete-cpu/implementation/include/concrete-cpu.h b/backends/concrete-cpu/implementation/include/concrete-cpu.h index 3c6ced4e8f..4d9cc35d6a 100644 --- a/backends/concrete-cpu/implementation/include/concrete-cpu.h +++ b/backends/concrete-cpu/implementation/include/concrete-cpu.h @@ -369,6 +369,11 @@ void concrete_cpu_keyswitch_lwe_ciphertext_u64(uint64_t *ct_out, size_t input_dimension, size_t output_dimension); +size_t concrete_cpu_lwe_array_to_tfhers_int8(const uint64_t *lwe_vec_buffer, + uint8_t *buffer, + size_t buffer_len, + struct TfhersFheIntDescription fheint_desc); + size_t concrete_cpu_lwe_array_to_tfhers_uint8(const uint64_t *lwe_vec_buffer, uint8_t *buffer, size_t buffer_len, @@ -415,6 +420,11 @@ size_t concrete_cpu_serialize_lwe_secret_key_u64(const uint64_t *lwe_sk, size_t concrete_cpu_tfhers_fheint_buffer_size_u64(size_t lwe_size, size_t n_cts); +int64_t concrete_cpu_tfhers_int8_to_lwe_array(const uint8_t *serialized_data_ptr, + size_t serialized_data_len, + uint64_t *lwe_vec_buffer, + struct TfhersFheIntDescription desc); + int64_t concrete_cpu_tfhers_uint8_to_lwe_array(const uint8_t *buffer, size_t buffer_len, uint64_t *lwe_vec_buffer, diff --git a/backends/concrete-cpu/implementation/src/c_api/fheint.rs b/backends/concrete-cpu/implementation/src/c_api/fheint.rs index 1364b62646..3eecab1e98 100644 --- a/backends/concrete-cpu/implementation/src/c_api/fheint.rs +++ b/backends/concrete-cpu/implementation/src/c_api/fheint.rs @@ -5,7 +5,7 @@ use tfhe::integer::ciphertext::Expandable; use tfhe::integer::IntegerCiphertext; use tfhe::shortint::parameters::{Degree, NoiseLevel}; use tfhe::shortint::{CarryModulus, Ciphertext, MessageModulus}; -use tfhe::{FheUint128, FheUint8}; +use tfhe::{FheInt8, FheUint128, FheUint8}; #[repr(C)] pub struct TfhersFheIntDescription { @@ -127,6 +127,29 @@ pub fn tfhers_uint8_description(fheuint: FheUint8) -> TfhersFheIntDescription { } } +pub fn tfhers_int8_description(fheuint: FheInt8) -> TfhersFheIntDescription { + // get metadata from fheuint's ciphertext + let (radix, _, _) = fheuint.into_raw_parts(); + let blocks = radix.blocks(); + let ct = match blocks.first() { + Some(value) => &value.ct, + None => { + return TfhersFheIntDescription::zero(); + } + }; + TfhersFheIntDescription { + width: 8, + is_signed: true, + lwe_size: ct.lwe_size().0, + n_cts: blocks.len(), + degree: blocks[0].degree.get(), + noise_level: blocks[0].noise_level().get(), + message_modulus: blocks[0].message_modulus.0, + carry_modulus: blocks[0].carry_modulus.0, + ks_first: blocks[0].pbs_order == PBSOrder::KeyswitchBootstrap, + } +} + #[no_mangle] pub unsafe extern "C" fn concrete_cpu_tfhers_uint8_to_lwe_array( buffer: *const u8, @@ -161,6 +184,41 @@ pub unsafe extern "C" fn concrete_cpu_tfhers_uint8_to_lwe_array( }) } +#[no_mangle] +pub unsafe extern "C" fn concrete_cpu_tfhers_int8_to_lwe_array( + serialized_data_ptr: *const u8, + serialized_data_len: usize, + lwe_vec_buffer: *mut u64, + desc: TfhersFheIntDescription, +) -> i64 { + nounwind(|| { + let fheint: FheInt8 = + super::utils::safe_deserialize(serialized_data_ptr, serialized_data_len); + // TODO - Use conformance check + let fheint_desc = tfhers_int8_description(fheint.clone()); + if !fheint_desc.is_similar(&desc) { + return 1; + } + + // collect LWEs from fheuint + let (radix, _, _) = fheint.into_raw_parts(); + let blocks = radix.blocks(); + let first_ct = match blocks.first() { + Some(value) => &value.ct, + None => return 1, + }; + let lwe_size = first_ct.lwe_size().0; + let n_cts = blocks.len(); + // copy LWEs to C buffer. Note that lsb is cts[0] + let lwe_vector: &mut [u64] = slice::from_raw_parts_mut(lwe_vec_buffer, n_cts * lwe_size); + for (i, block) in blocks.iter().enumerate() { + lwe_vector[i * lwe_size..(i + 1) * lwe_size] + .copy_from_slice(block.ct.clone().into_container().as_slice()); + } + 0 + }) +} + #[no_mangle] pub extern "C" fn concrete_cpu_tfhers_fheint_buffer_size_u64( lwe_size: usize, @@ -198,7 +256,7 @@ pub unsafe extern "C" fn concrete_cpu_lwe_array_to_tfhers_uint8( let n_cts = fheuint_desc.n_cts; // construct fheuint from LWEs let lwe_vector: &[u64] = slice::from_raw_parts(lwe_vec_buffer, n_cts * lwe_size); - let mut blocks: Vec = Vec::new(); + let mut blocks: Vec = Vec::with_capacity(n_cts); for i in 0..n_cts { let lwe_ct = LweCiphertext::>::from_container( lwe_vector[i * lwe_size..(i + 1) * lwe_size].to_vec(), @@ -215,3 +273,45 @@ pub unsafe extern "C" fn concrete_cpu_lwe_array_to_tfhers_uint8( super::utils::safe_serialize(&fheuint, buffer, buffer_len) }) } + +#[no_mangle] +pub unsafe extern "C" fn concrete_cpu_lwe_array_to_tfhers_int8( + lwe_vec_buffer: *const u64, + buffer: *mut u8, + buffer_len: usize, + fheint_desc: TfhersFheIntDescription, +) -> usize { + nounwind(|| { + // we want to trigger a PBS on TFHErs side + assert!( + fheint_desc.noise_level == NoiseLevel::UNKNOWN.get(), + "noise_level must be unknown" + ); + // we want to use the max degree as we don't track it on Concrete side + assert!( + fheint_desc.degree == fheint_desc.message_modulus - 1, + "degree must be the max value (msg_modulus - 1)" + ); + + let lwe_size = fheint_desc.lwe_size; + let n_cts = fheint_desc.n_cts; + // construct fheuint from LWEs + let lwe_vector: &[u64] = slice::from_raw_parts(lwe_vec_buffer, n_cts * lwe_size); + let mut blocks: Vec = Vec::with_capacity(n_cts); + for i in 0..n_cts { + let lwe_ct = LweCiphertext::>::from_container( + lwe_vector[i * lwe_size..(i + 1) * lwe_size].to_vec(), + CiphertextModulus::new_native(), + ); + blocks.push(fheint_desc.ct_from_lwe(lwe_ct)); + } + let fheuint = match FheInt8::from_expanded_blocks(blocks, fheint_desc.data_kind()) { + Ok(value) => value, + Err(_) => { + return 0; + } + }; + + super::utils::safe_serialize(&fheuint, buffer, buffer_len) + }) +} diff --git a/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/ClientLib.h b/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/ClientLib.h index bf0d65c430..1aacf41efa 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/ClientLib.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/ClientLib.h @@ -34,12 +34,13 @@ using concretelang::values::Value; namespace concretelang { namespace clientlib { -Result -importTfhersFheUint8(llvm::ArrayRef serializedFheUint8, - TfhersFheIntDescription desc, uint32_t encryptionKeyId, - double encryptionVariance); -Result> exportTfhersFheUint8(TransportValue value, - TfhersFheIntDescription info); +Result importTfhersInteger(llvm::ArrayRef buffer, + TfhersFheIntDescription integerDesc, + uint32_t encryptionKeyId, + double encryptionVariance); + +Result> +exportTfhersInteger(TransportValue value, TfhersFheIntDescription integerDesc); class ClientCircuit { diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index 54f1af4844..bbd12d4067 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -1881,14 +1881,14 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( "Return the `circuit` ClientCircuit.", arg("circuit")) .doc() = "Client-side / Encryption program"; - m.def("import_tfhers_fheuint8", + m.def("import_tfhers_int", [](const pybind11::bytes &serialized_fheuint, TfhersFheIntDescription info, uint32_t encryptionKeyId, double encryptionVariance) { const std::string &buffer_str = serialized_fheuint; std::vector buffer(buffer_str.begin(), buffer_str.end()); auto arrayRef = llvm::ArrayRef(buffer); - auto valueOrError = ::concretelang::clientlib::importTfhersFheUint8( + auto valueOrError = ::concretelang::clientlib::importTfhersInteger( arrayRef, info, encryptionKeyId, encryptionVariance); if (valueOrError.has_error()) { throw std::runtime_error(valueOrError.error().mesg); @@ -1896,13 +1896,12 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( return TransportValue{valueOrError.value()}; }); - m.def("export_tfhers_fheuint8", - [](TransportValue fheuint, TfhersFheIntDescription info) { - auto result = - ::concretelang::clientlib::exportTfhersFheUint8(fheuint, info); - if (result.has_error()) { - throw std::runtime_error(result.error().mesg); - } - return result.value(); - }); + m.def("export_tfhers_int", [](TransportValue fheuint, + TfhersFheIntDescription info) { + auto result = ::concretelang::clientlib::exportTfhersInteger(fheuint, info); + if (result.has_error()) { + throw std::runtime_error(result.error().mesg); + } + return result.value(); + }); } diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/tfhers_int.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/tfhers_int.py index d4006b058d..9491743c2e 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/tfhers_int.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/tfhers_int.py @@ -3,8 +3,8 @@ # pylint: disable=no-name-in-module,import-error, from mlir._mlir_libs._concretelang._compiler import ( - import_tfhers_fheuint8 as _import_tfhers_fheuint8, - export_tfhers_fheuint8 as _export_tfhers_fheuint8, + import_tfhers_int as _import_tfhers_int, + export_tfhers_int as _export_tfhers_int, TfhersFheIntDescription as _TfhersFheIntDescription, TransportValue, ) @@ -182,7 +182,7 @@ class TfhersExporter: """A helper class to import and export TFHErs big integers.""" @staticmethod - def export_fheuint8(value: TransportValue, info: TfhersFheIntDescription) -> bytes: + def export_int(value: TransportValue, info: TfhersFheIntDescription) -> bytes: """Convert Concrete value to TFHErs and serialize it. Args: @@ -193,7 +193,7 @@ def export_fheuint8(value: TransportValue, info: TfhersFheIntDescription) -> byt TypeError: if wrong input types Returns: - bytes: converted and serialized fheuint8 + bytes: converted and serialized TFHErs integer """ if not isinstance(value, TransportValue): raise TypeError(f"value must be of type TransportValue, not {type(value)}") @@ -201,16 +201,16 @@ def export_fheuint8(value: TransportValue, info: TfhersFheIntDescription) -> byt raise TypeError( f"info must be of type TfhersFheIntDescription, not {type(info)}" ) - return bytes(_export_tfhers_fheuint8(value, info.cpp())) + return bytes(_export_tfhers_int(value, info.cpp())) @staticmethod - def import_fheuint8( + def import_int( buffer: bytes, info: TfhersFheIntDescription, keyid: int, variance: float ) -> TransportValue: """Unserialize and convert from TFHErs to Concrete value. Args: - buffer (bytes): serialized fheuint8 + buffer (bytes): serialized TFHErs integer info (TfhersFheIntDescription): description of the TFHErs integer to import keyid (int): id of the key used for encryption variance (float): variance used for encryption @@ -231,4 +231,4 @@ def import_fheuint8( raise TypeError(f"keyid must be of type int, not {type(keyid)}") if not isinstance(variance, float): raise TypeError(f"variance must be of type float, not {type(variance)}") - return _import_tfhers_fheuint8(buffer, info.cpp(), keyid, variance) + return _import_tfhers_int(buffer, info.cpp(), keyid, variance) diff --git a/compilers/concrete-compiler/compiler/lib/ClientLib/ClientLib.cpp b/compilers/concrete-compiler/compiler/lib/ClientLib/ClientLib.cpp index 000ec4dd1f..3fee8e9586 100644 --- a/compilers/concrete-compiler/compiler/lib/ClientLib/ClientLib.cpp +++ b/compilers/concrete-compiler/compiler/lib/ClientLib/ClientLib.cpp @@ -187,53 +187,83 @@ Result ClientProgram::getClientCircuit(std::string circuitName) { "`"); } -Result -importTfhersFheUint8(llvm::ArrayRef serializedFheUint8, - TfhersFheIntDescription desc, uint32_t encryptionKeyId, - double encryptionVariance) { - if (desc.width != 8 || desc.is_signed == true) { - return StringError( - "trying to import FheUint8 but description doesn't match this type"); +Result importTfhersInteger(llvm::ArrayRef buffer, + TfhersFheIntDescription integerDesc, + uint32_t encryptionKeyId, + double encryptionVariance) { + + // Select conversion function based on integer description + std::function + conversion_func; + if (integerDesc.width == 8) { + if (integerDesc.is_signed) { // fheint8 + conversion_func = concrete_cpu_tfhers_int8_to_lwe_array; + } else { // fheuint8 + conversion_func = concrete_cpu_tfhers_uint8_to_lwe_array; + } + } else { + std::ostringstream stringStream; + stringStream << "importTfhersInteger: no support for " << integerDesc.width + << "bits " << (integerDesc.is_signed ? "signed" : "unsigned") + << " integer"; + std::string errorMsg = stringStream.str(); + return StringError(errorMsg); } - auto dims = std::vector({desc.n_cts, desc.lwe_size}); + auto dims = std::vector({integerDesc.n_cts, integerDesc.lwe_size}); auto outputTensor = Tensor::fromDimensions(dims); - auto err = concrete_cpu_tfhers_uint8_to_lwe_array( - serializedFheUint8.data(), serializedFheUint8.size(), - outputTensor.values.data(), desc); + auto err = conversion_func(buffer.data(), buffer.size(), + outputTensor.values.data(), integerDesc); if (err) { - return StringError("couldn't convert fheuint to lwe array: err()") - << err << ")"; + return StringError("couldn't convert fheint to lwe array"); } auto value = Value{outputTensor}.intoRawTransportValue(); auto lwe = value.asBuilder().initTypeInfo().initLweCiphertext(); lwe.setIntegerPrecision(64); // dimensions - lwe.initAbstractShape().setDimensions({(uint32_t)desc.n_cts}); + lwe.initAbstractShape().setDimensions({(uint32_t)integerDesc.n_cts}); lwe.initConcreteShape().setDimensions( - {(uint32_t)desc.n_cts, (uint32_t)desc.lwe_size}); + {(uint32_t)integerDesc.n_cts, (uint32_t)integerDesc.lwe_size}); // encryption auto encryption = lwe.initEncryption(); - encryption.setLweDimension((uint32_t)desc.lwe_size - 1); + encryption.setLweDimension((uint32_t)integerDesc.lwe_size - 1); encryption.initModulus().initMod().initNative(); encryption.setKeyId(encryptionKeyId); encryption.setVariance(encryptionVariance); // Encoding auto encoding = lwe.initEncoding(); auto integer = encoding.initInteger(); - integer.setIsSigned(false); - integer.setWidth(std::log2(desc.message_modulus * desc.carry_modulus)); + integer.setIsSigned( + false); // should always be unsigned as its for the radix encoded cts + integer.setWidth( + std::log2(integerDesc.message_modulus * integerDesc.carry_modulus)); integer.initMode().initNative(); return value; } Result> -exportTfhersFheUint8(TransportValue value, TfhersFheIntDescription desc) { - if (desc.width != 8 || desc.is_signed == true) { - return StringError( - "trying to export FheUint8 but description doesn't match this type"); +exportTfhersInteger(TransportValue value, TfhersFheIntDescription integerDesc) { + // Select conversion function based on integer description + std::function + conversion_func; + std::function buffer_size_func; + if (integerDesc.width == 8) { + if (integerDesc.is_signed) { // fheint8 + conversion_func = concrete_cpu_lwe_array_to_tfhers_int8; + } else { // fheuint8 + conversion_func = concrete_cpu_lwe_array_to_tfhers_uint8; + } + } else { + std::ostringstream stringStream; + stringStream << "exportTfhersInteger: no support for " << integerDesc.width + << "bits " << (integerDesc.is_signed ? "signed" : "unsigned") + << " integer"; + std::string errorMsg = stringStream.str(); + return StringError(errorMsg); } auto fheuint = Value::fromRawTransportValue(value); @@ -244,14 +274,14 @@ exportTfhersFheUint8(TransportValue value, TfhersFheIntDescription desc) { if (!tensorOrError.has_value()) { return StringError("couldn't get tensor from value"); } - const size_t bufferSize = - concrete_cpu_tfhers_fheint_buffer_size_u64(desc.lwe_size, desc.n_cts); - std::vector buffer(bufferSize, 0); - auto flatData = tensorOrError.value().values; - auto size = concrete_cpu_lwe_array_to_tfhers_uint8( - flatData.data(), buffer.data(), buffer.size(), desc); + size_t buffer_size = concrete_cpu_tfhers_fheint_buffer_size_u64( + integerDesc.lwe_size, integerDesc.n_cts); + std::vector buffer(buffer_size, 0); + auto flat_data = tensorOrError.value().values; + auto size = conversion_func(flat_data.data(), buffer.data(), buffer.size(), + integerDesc); if (size == 0) { - return StringError("couldn't convert lwe array to fheuint8"); + return StringError("couldn't convert lwe array to fheint8"); } // we truncate to the serialized data assert(size <= buffer.size()); diff --git a/frontends/concrete-python/concrete/fhe/mlir/converter.py b/frontends/concrete-python/concrete/fhe/mlir/converter.py index ff4c97e223..f5cf6211bb 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/converter.py +++ b/frontends/concrete-python/concrete/fhe/mlir/converter.py @@ -959,8 +959,35 @@ def tfhers_to_native(self, ctx: Context, node: Node, preds: List[Conversion]) -> # sum will remove the last dim which is the dim of ciphertexts result_shape = tfhers_int.shape[:-1] # if result_shape is () then ctx.tensor would return a scalar type - result_type = ctx.tensor(ctx.eint(result_bit_width), result_shape) - return ctx.sum(result_type, mapped, axes=-1) + result_type = ctx.tensor( + ctx.esint(result_bit_width) if dtype.is_signed else ctx.eint(result_bit_width), + result_shape, + ) + sum_result = ctx.sum(result_type, mapped, axes=-1) + + # we want to set the padding bit if the native type is signed + # and the ciphertext is negative (sign bit set to 1) + if dtype.is_signed: + # select MSBs of all tfhers ciphetexts + index = [slice(0, dim_size) for dim_size in tfhers_int.shape[:-1]] + [ + -1, + ] + msbs = ctx.index( + ctx.tensor(ctx.eint(msg_width + carry_width), tfhers_int.shape[:-1]), + tfhers_int, + index=index, + ) + # construct padding bits based on sign bits (carry would be considered negative) + padding_bit_table = [ + 0, + ] * 2 ** (msg_width - 1) + [ + 2**result_bit_width, + ] * (2 ** (carry_width + msg_width) - 2 ** (msg_width - 1)) + padding_bits_inc = ctx.tlu(result_type, msbs, padding_bit_table) + # set padding bits (where necessary) in the final result + return ctx.add(result_type, sum_result, padding_bits_inc) + + return sum_result def tfhers_from_native(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion: assert len(preds) == 1 diff --git a/frontends/concrete-python/concrete/fhe/tfhers/bridge.py b/frontends/concrete-python/concrete/fhe/tfhers/bridge.py index 52ffed0af6..de08b36842 100644 --- a/frontends/concrete-python/concrete/fhe/tfhers/bridge.py +++ b/frontends/concrete-python/concrete/fhe/tfhers/bridge.py @@ -118,20 +118,9 @@ def import_value(self, buffer: bytes, input_idx: int) -> Value: raise ValueError(msg) fheint_desc = self._description_from_type(input_type) - - bit_width = input_type.bit_width - signed = input_type.is_signed - if bit_width == 8: - if not signed: - keyid = self._input_keyid(input_idx) - variance = self._input_variance(input_idx) - return Value(TfhersExporter.import_fheuint8(buffer, fheint_desc, keyid, variance)) - - msg = ( # pragma: no cover - f"importing {'signed' if signed else 'unsigned'} integers of {bit_width}bits is not" - " yet supported" - ) - raise NotImplementedError(msg) # pragma: no cover + keyid = self._input_keyid(input_idx) + variance = self._input_variance(input_idx) + return Value(TfhersExporter.import_int(buffer, fheint_desc, keyid, variance)) def export_value(self, value: Value, output_idx: int) -> bytes: """Export a value as a serialized TFHErs integer. @@ -149,20 +138,9 @@ def export_value(self, value: Value, output_idx: int) -> bytes: raise ValueError(msg) fheint_desc = self._description_from_type(output_type) - - bit_width = output_type.bit_width - signed = output_type.is_signed - if bit_width == 8: - if not signed: - return TfhersExporter.export_fheuint8( - value._inner, fheint_desc # pylint: disable=protected-access - ) - - msg = ( # pragma: no cover - f"exporting value to {'signed' if signed else 'unsigned'} integers of {bit_width}bits" - " is not yet supported" + return TfhersExporter.export_int( + value._inner, fheint_desc # pylint: disable=protected-access ) - raise NotImplementedError(msg) # pragma: no cover def serialize_input_secret_key(self, input_idx: int) -> bytes: """Serialize secret key used for a specific input. diff --git a/frontends/concrete-python/concrete/fhe/tfhers/dtypes.py b/frontends/concrete-python/concrete/fhe/tfhers/dtypes.py index a4fd342e52..36134526a9 100644 --- a/frontends/concrete-python/concrete/fhe/tfhers/dtypes.py +++ b/frontends/concrete-python/concrete/fhe/tfhers/dtypes.py @@ -146,7 +146,10 @@ def encode(self, value: Union[int, np.integer, list, np.ndarray]) -> np.ndarray: bit_width = self.bit_width msg_width = self.msg_width if isinstance(value, (int, np.integer)): - value_bin = bin(value)[2:].zfill(bit_width) + if self.is_signed and value < 0: + value_bin = bin(2**bit_width + value)[2:].zfill(bit_width) + else: + value_bin = bin(value)[2:].zfill(bit_width) # lsb first return np.array( [int(value_bin[i : i + msg_width], 2) for i in range(0, bit_width, msg_width)][::-1] @@ -201,7 +204,10 @@ def decode(self, value: Union[list, np.ndarray]) -> Union[int, np.ndarray]: if len(value.shape) == 1: # lsb first - return sum(int(v) << (i * msg_width) for i, v in enumerate(value)) + decoded = sum(int(v) << (i * msg_width) for i, v in enumerate(value)) + if self.is_signed and decoded >= 2 ** (bit_width - 1): + decoded = decoded - 2**bit_width + return decoded cts = value.reshape((-1, expected_ct_shape)) return np.array([self.decode(ct) for ct in cts]).reshape(value.shape[:-1]) diff --git a/frontends/concrete-python/tests/execution/test_tfhers.py b/frontends/concrete-python/tests/execution/test_tfhers.py index c56f26ee84..c747916084 100644 --- a/frontends/concrete-python/tests/execution/test_tfhers.py +++ b/frontends/concrete-python/tests/execution/test_tfhers.py @@ -280,7 +280,7 @@ def lut_add_lut(x, y): return lut[x + y] -TFHERS_INT_8_3_2_4096 = tfhers.TFHERSIntegerType( +TFHERS_UINT_8_3_2_4096 = tfhers.TFHERSIntegerType( False, bit_width=8, carry_width=3, @@ -297,6 +297,23 @@ def lut_add_lut(x, y): ), ) +TFHERS_INT_8_3_2_4096 = tfhers.TFHERSIntegerType( + True, + bit_width=8, + carry_width=3, + msg_width=2, + params=tfhers.CryptoParams( + lwe_dimension=909, + glwe_dimension=1, + polynomial_size=4096, + pbs_base_log=15, + pbs_level=2, + lwe_noise_distribution=0, + glwe_noise_distribution=2.168404344971009e-19, + encryption_key_choice=tfhers.EncryptionKeyChoice.BIG, + ), +) + @pytest.mark.parametrize( "function, parameters, dtype", @@ -307,34 +324,70 @@ def lut_add_lut(x, y): "x": {"range": [0, 2**7 - 1], "status": "encrypted"}, "y": {"range": [0, 2**7 - 1], "status": "encrypted"}, }, - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x + y", ), + pytest.param( + lambda x, y: x + y, + { + "x": {"range": [-(2**6), -2], "status": "encrypted"}, + "y": {"range": [0, 2**6 - 1], "status": "encrypted"}, + }, + TFHERS_INT_8_3_2_4096, + id="signed(x) + signed(y)", + ), pytest.param( lambda x, y: x - y, { "x": {"range": [2**4, 2**7 - 1], "status": "encrypted"}, "y": {"range": [0, 2**4 - 1], "status": "encrypted"}, }, - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x - y", ), + pytest.param( + lambda x, y: x - y, + { + "x": {"range": [-(2**3), -2], "status": "encrypted"}, + "y": {"range": [-(2**3), -2], "status": "encrypted"}, + }, + TFHERS_INT_8_3_2_4096, + id="signed(x) - signed(y)", + ), pytest.param( lambda x, y: x * y, { "x": {"range": [0, 2**3 - 1], "status": "encrypted"}, "y": {"range": [0, 2**3 - 1], "status": "encrypted"}, }, - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x * y", ), + pytest.param( + lambda x, y: x * y, + { + "x": {"range": [-(2**3), 2**2], "status": "encrypted"}, + "y": {"range": [-(2**2), 2**3], "status": "encrypted"}, + }, + TFHERS_INT_8_3_2_4096, + id="signed(x) * signed(y)", + ), + pytest.param( + lambda x, y: x * y, + { + "x": {"range": [-(2**3), 2**2], "status": "encrypted"}, + "y": {"range": [-(2**2), 2**3], "status": "encrypted"}, + }, + TFHERS_INT_8_3_2_4096, + id="signed(x) * signed(y)", + ), pytest.param( lut_add_lut, { "x": {"range": [0, 2**7 - 1], "status": "encrypted"}, "y": {"range": [0, 2**7 - 1], "status": "encrypted"}, }, - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="lut_add_lut", ), ], @@ -412,32 +465,32 @@ def test_tfhers_binary_encrypted_complete_circuit_concrete_keygen( # encrypt inputs and incremnt them by one in TFHErs assert ( os.system( - f"{tfhers_utils} encrypt-with-key -v 1 -c {ct_one_path} --client-key {client_key_path}" + f"{tfhers_utils} encrypt-with-key {'--signed' if dtype.is_signed else ''} --value=1 -c {ct_one_path} --client-key {client_key_path}" ) == 0 ) sample = [s + 1 for s in sample] assert ( os.system( - f"{tfhers_utils} encrypt-with-key -v {ct1} -c {ct1_path} --client-key {client_key_path}" + f"{tfhers_utils} encrypt-with-key {'--signed' if dtype.is_signed else ''} --value={ct1} -c {ct1_path} --client-key {client_key_path}" ) == 0 ) assert ( os.system( - f"{tfhers_utils} encrypt-with-key -v {ct2} -c {ct2_path} --client-key {client_key_path}" + f"{tfhers_utils} encrypt-with-key {'--signed' if dtype.is_signed else ''} --value={ct2} -c {ct2_path} --client-key {client_key_path}" ) == 0 ) assert ( os.system( - f"{tfhers_utils} add -c {ct1_path} {ct_one_path} -s {server_key_path} -o {ct1_path}" + f"{tfhers_utils} add {'--signed' if dtype.is_signed else ''} -c {ct1_path} {ct_one_path} -s {server_key_path} -o {ct1_path}" ) == 0 ) assert ( os.system( - f"{tfhers_utils} add -c {ct2_path} {ct_one_path} -s {server_key_path} -o {ct2_path}" + f"{tfhers_utils} add {'--signed' if dtype.is_signed else ''} -c {ct2_path} {ct_one_path} -s {server_key_path} -o {ct2_path}" ) == 0 ) @@ -468,7 +521,9 @@ def test_tfhers_binary_encrypted_complete_circuit_concrete_keygen( assert ( os.system( - f"{tfhers_utils} decrypt-with-key" f" -c {ct_out_path} --lwe-sk {key_path} -p {pt_path}" + f"{tfhers_utils} decrypt-with-key" + f"{' --signed ' if dtype.is_signed else ''}" + f" -c {ct_out_path} --lwe-sk {key_path} -p {pt_path}" ) == 0 ) @@ -496,25 +551,61 @@ def test_tfhers_binary_encrypted_complete_circuit_concrete_keygen( "x": {"range": [0, 2**7], "status": "encrypted"}, "y": {"range": [0, 2**7], "status": "encrypted"}, }, - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x + y", ), + pytest.param( + lambda x, y: x + y, + { + "x": {"range": [-(2**6), -2], "status": "encrypted"}, + "y": {"range": [0, 2**6 - 1], "status": "encrypted"}, + }, + TFHERS_INT_8_3_2_4096, + id="signed(-x) + signed(y)", + ), + pytest.param( + lambda x, y: x + y, + { + "x": {"range": [0, 2**6 - 1], "status": "encrypted"}, + "y": {"range": [-(2**6), -2], "status": "encrypted"}, + }, + TFHERS_INT_8_3_2_4096, + id="signed(x) + signed(-y)", + ), pytest.param( lambda x, y: x + y, { "x": {"range": [0, 2**7], "status": "encrypted"}, "y": {"range": [0, 2**7], "status": "clear"}, }, - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x + clear(y)", ), + pytest.param( + lambda x, y: x + y, + { + "x": {"range": [-(2**6), -2], "status": "encrypted"}, + "y": {"range": [0, 2**6 - 1], "status": "clear"}, + }, + TFHERS_INT_8_3_2_4096, + id="signed(-x) + clear(y)", + ), + pytest.param( + lambda x, y: x + y, + { + "x": {"range": [0, 2**6 - 1], "status": "encrypted"}, + "y": {"range": [-(2**6), -2], "status": "clear"}, + }, + TFHERS_INT_8_3_2_4096, + id="signed(x) + clear(-y)", + ), pytest.param( lambda x, y: x + y, { "x": {"range": [2**6, 2**7 - 1], "status": "encrypted"}, "y": {"range": [2**6, 2**7 - 1], "status": "encrypted"}, }, - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x + y big values", ), pytest.param( @@ -523,7 +614,7 @@ def test_tfhers_binary_encrypted_complete_circuit_concrete_keygen( "x": {"range": [2**6, 2**7 - 1], "status": "encrypted"}, "y": {"range": [2**6, 2**7 - 1], "status": "clear"}, }, - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x + clear(y) big values", ), pytest.param( @@ -532,16 +623,25 @@ def test_tfhers_binary_encrypted_complete_circuit_concrete_keygen( "x": {"range": [2**4, 2**8 - 1], "status": "encrypted"}, "y": {"range": [0, 2**4], "status": "encrypted"}, }, - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x - y", ), + pytest.param( + lambda x, y: x - y, + { + "x": {"range": [-(2**3), -2], "status": "encrypted"}, + "y": {"range": [-(2**3), -2], "status": "encrypted"}, + }, + TFHERS_INT_8_3_2_4096, + id="signed(x) - signed(y)", + ), pytest.param( lambda x, y: x - y, { "x": {"range": [2**4, 2**8 - 1], "status": "encrypted"}, "y": {"range": [0, 2**4], "status": "clear"}, }, - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x - clear(y)", ), pytest.param( @@ -550,7 +650,7 @@ def test_tfhers_binary_encrypted_complete_circuit_concrete_keygen( "x": {"range": [0, 2**3], "status": "encrypted"}, "y": {"range": [0, 2**3], "status": "encrypted"}, }, - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x * y", ), pytest.param( @@ -559,16 +659,34 @@ def test_tfhers_binary_encrypted_complete_circuit_concrete_keygen( "x": {"range": [0, 2**3], "status": "encrypted"}, "y": {"range": [0, 2**3], "status": "clear"}, }, - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x * clear(y)", ), + pytest.param( + lambda x, y: x * y, + { + "x": {"range": [-(2**3), 2], "status": "encrypted"}, + "y": {"range": [-2, 2**4], "status": "encrypted"}, + }, + TFHERS_INT_8_3_2_4096, + id="signed(x) * signed(y)", + ), + pytest.param( + lambda x, y: x * y, + { + "x": {"range": [0, 2**3], "status": "encrypted"}, + "y": {"range": [-(2**3), 0], "status": "clear"}, + }, + TFHERS_INT_8_3_2_4096, + id="signed(x) * clear(-y)", + ), pytest.param( lut_add_lut, { "x": {"range": [0, 2**7], "status": "encrypted"}, "y": {"range": [0, 2**7], "status": "encrypted"}, }, - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="lut_add_lut(x , y)", ), ], @@ -633,7 +751,9 @@ def test_tfhers_one_tfhers_one_native_complete_circuit_concrete_keygen( f"{os.path.dirname(os.path.abspath(__file__))}/../tfhers-utils/target/release/tfhers_utils" ) assert ( - os.system(f"{tfhers_utils} encrypt-with-key -v {ct1} -c {ct1_path} --lwe-sk {key_path}") + os.system( + f"{tfhers_utils} encrypt-with-key {'--signed' if dtype.is_signed else ''} --value={ct1} -c {ct1_path} --lwe-sk {key_path}" + ) == 0 ) @@ -660,7 +780,9 @@ def test_tfhers_one_tfhers_one_native_complete_circuit_concrete_keygen( assert ( os.system( - f"{tfhers_utils} decrypt-with-key" f" -c {ct_out_path} --lwe-sk {key_path} -p {pt_path}" + f"{tfhers_utils} decrypt-with-key" + f"{' --signed ' if dtype.is_signed else ''}" + f" -c {ct_out_path} --lwe-sk {key_path} -p {pt_path}" ) == 0 ) @@ -686,7 +808,7 @@ def test_tfhers_one_tfhers_one_native_complete_circuit_concrete_keygen( "y": {"range": [0, 2**6], "status": "encrypted"}, }, [0, 2**6], - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x + y", ), pytest.param( @@ -696,7 +818,7 @@ def test_tfhers_one_tfhers_one_native_complete_circuit_concrete_keygen( "y": {"range": [0, 2**4], "status": "encrypted"}, }, [0, 2**3], - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x - y", ), pytest.param( @@ -706,7 +828,7 @@ def test_tfhers_one_tfhers_one_native_complete_circuit_concrete_keygen( "y": {"range": [0, 2**3], "status": "encrypted"}, }, [0, 2**2], - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x * y", ), pytest.param( @@ -716,7 +838,7 @@ def test_tfhers_one_tfhers_one_native_complete_circuit_concrete_keygen( "y": {"range": [0, 2**6], "status": "encrypted"}, }, [0, 2**6], - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="lut_add_lut", ), ], @@ -804,10 +926,12 @@ def test_tfhers_binary_encrypted_complete_circuit_tfhers_keygen( f"{os.path.dirname(os.path.abspath(__file__))}/../tfhers-utils/target/release/tfhers_utils" ) assert ( - os.system(f"{tfhers_utils} encrypt-with-key -v {ct1} -c {ct1_path} --lwe-sk {sk_path}") == 0 + os.system(f"{tfhers_utils} encrypt-with-key --value={ct1} -c {ct1_path} --lwe-sk {sk_path}") + == 0 ) assert ( - os.system(f"{tfhers_utils} encrypt-with-key -v {ct2} -c {ct2_path} --lwe-sk {sk_path}") == 0 + os.system(f"{tfhers_utils} encrypt-with-key --value={ct2} -c {ct2_path} --lwe-sk {sk_path}") + == 0 ) # import ciphertexts and run @@ -853,7 +977,7 @@ def test_tfhers_binary_encrypted_complete_circuit_tfhers_keygen( random_value = np.random.randint(*tfhers_value_range) assert ( os.system( - f"{tfhers_utils} encrypt-with-key -v {random_value} -c {random_ct_path} --client-key {client_key_path}" + f"{tfhers_utils} encrypt-with-key --value={random_value} -c {random_ct_path} --client-key {client_key_path}" ) == 0 ) @@ -898,7 +1022,7 @@ def test_tfhers_binary_encrypted_complete_circuit_tfhers_keygen( "y": {"range": [0, 2**6], "status": "encrypted"}, }, [0, 2**6], - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x + y", ), pytest.param( @@ -908,7 +1032,7 @@ def test_tfhers_binary_encrypted_complete_circuit_tfhers_keygen( "y": {"range": [0, 2**6], "status": "clear"}, }, [0, 2**6], - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x + clear(y)", ), pytest.param( @@ -918,7 +1042,7 @@ def test_tfhers_binary_encrypted_complete_circuit_tfhers_keygen( "y": {"range": [2**5, 2**6], "status": "encrypted"}, }, [0, 2**6], - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x + y big values", ), pytest.param( @@ -928,7 +1052,7 @@ def test_tfhers_binary_encrypted_complete_circuit_tfhers_keygen( "y": {"range": [2**5, 2**6], "status": "clear"}, }, [0, 2**6], - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x + clear(y) big values", ), pytest.param( @@ -938,7 +1062,7 @@ def test_tfhers_binary_encrypted_complete_circuit_tfhers_keygen( "y": {"range": [0, 2**4], "status": "encrypted"}, }, [0, 2**3], - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x - y", ), pytest.param( @@ -948,7 +1072,7 @@ def test_tfhers_binary_encrypted_complete_circuit_tfhers_keygen( "y": {"range": [0, 2**4], "status": "clear"}, }, [0, 2**3], - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x - clear(y)", ), pytest.param( @@ -958,7 +1082,7 @@ def test_tfhers_binary_encrypted_complete_circuit_tfhers_keygen( "y": {"range": [0, 2**3], "status": "encrypted"}, }, [0, 2**2], - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x * y", ), pytest.param( @@ -968,7 +1092,7 @@ def test_tfhers_binary_encrypted_complete_circuit_tfhers_keygen( "y": {"range": [0, 2**3], "status": "clear"}, }, [0, 2**2], - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="x * clear(y)", ), pytest.param( @@ -978,7 +1102,7 @@ def test_tfhers_binary_encrypted_complete_circuit_tfhers_keygen( "y": {"range": [0, 2**6], "status": "encrypted"}, }, [0, 2**6], - TFHERS_INT_8_3_2_4096, + TFHERS_UINT_8_3_2_4096, id="lut_add_lut(x , y)", ), ], @@ -1064,7 +1188,7 @@ def test_tfhers_one_tfhers_one_native_complete_circuit_tfhers_keygen( assert ( os.system( - f"{tfhers_utils} encrypt-with-key -v {pt1} -c {ct1_path} --client-key {client_key_path}" + f"{tfhers_utils} encrypt-with-key --value={pt1} -c {ct1_path} --client-key {client_key_path}" ) == 0 ) @@ -1109,7 +1233,7 @@ def test_tfhers_one_tfhers_one_native_complete_circuit_tfhers_keygen( random_value = np.random.randint(*tfhers_value_range) assert ( os.system( - f"{tfhers_utils} encrypt-with-key -v {random_value} -c {random_ct_path} --client-key {client_key_path}" + f"{tfhers_utils} encrypt-with-key --value={random_value} -c {random_ct_path} --client-key {client_key_path}" ) == 0 ) @@ -1178,6 +1302,15 @@ def test_tfhers_integer_eq(lhs, rhs, is_equal): [2, 3, 1, 0, 0, 0, 0, 0], ], ), + pytest.param( + tfhers.int8_2_2, + [-128, 0, 127], + [ + [0, 0, 0, 2], + [0, 0, 0, 0], + [3, 3, 3, 1], + ], + ), ], ) def test_tfhers_integer_encode(dtype, value, encoded): @@ -1223,6 +1356,15 @@ def test_tfhers_integer_bad_encode(dtype, value, expected_error, expected_messag ], [10, 20, 30], ), + pytest.param( + tfhers.int8_2_2, + [ + [2, 1, 0, 2], + [0, 3, 1, 0], + [2, 1, 0, 1], + ], + [-122, 28, 70], + ), ], ) def test_tfhers_integer_decode(dtype, encoded, decoded): diff --git a/frontends/concrete-python/tests/tfhers-utils/src/main.rs b/frontends/concrete-python/tests/tfhers-utils/src/main.rs index 5f363f1bd3..07095d9e1c 100644 --- a/frontends/concrete-python/tests/tfhers-utils/src/main.rs +++ b/frontends/concrete-python/tests/tfhers-utils/src/main.rs @@ -4,9 +4,9 @@ use serde::de::DeserializeOwned; use std::fs; use std::path::Path; use tfhe::core_crypto::prelude::LweSecretKey; -use tfhe::{prelude::*}; +use tfhe::prelude::*; use tfhe::shortint::{ClassicPBSParameters, EncryptionKeyChoice}; -use tfhe::{generate_keys, set_server_key, ClientKey, ConfigBuilder, FheUint8, ServerKey}; +use tfhe::{generate_keys, set_server_key, ClientKey, ConfigBuilder, FheInt8, FheUint8, ServerKey}; use serde::Serialize; use tfhe::named::Named; @@ -48,37 +48,61 @@ fn set_server_key_from_file(path: &String) { set_server_key(sk); } -fn encrypt_with_key(value: u8, client_key: ClientKey, path: &String) { +fn encrypt_with_key_u8(value: u8, client_key: ClientKey, ciphertext_path: &String) { let ct = FheUint8::encrypt(value, &client_key); - safe_save(path, &ct) + safe_save(ciphertext_path, &ct) +} + +fn encrypt_with_key_i8(value: i8, client_key: ClientKey, ciphertext_path: &String) { + let ct = FheInt8::encrypt(value, &client_key); + safe_save(ciphertext_path, &ct) } fn decrypt_with_key( client_key: ClientKey, ciphertext_path: &String, plaintext_path: Option<&String>, + signed: bool, ) { - let fheuint: FheUint8 = safe_load(ciphertext_path); - let result: u8 = fheuint.decrypt(&client_key); + let string_result: String; + + if signed { + let fheint: FheInt8 = safe_load(ciphertext_path); + let result: i8 = fheint.decrypt(&client_key); + string_result = result.to_string(); + } else { + let fheuint: FheUint8 = safe_load(ciphertext_path); + let result: u8 = fheuint.decrypt(&client_key); + string_result = result.to_string(); + } if let Some(path) = plaintext_path { let pt_path: &Path = Path::new(path); - fs::write(pt_path, result.to_string()).unwrap(); + fs::write(pt_path, string_result).unwrap(); } else { - println!("result: {}", result); + println!("result: {}", string_result); } } -fn sum(cts_paths: Vec<&String>, out_ct_path: &String) { +fn sum(cts_paths: Vec<&String>, out_ct_path: &String, signed: bool) { if cts_paths.is_empty() { panic!("can't call sum with 0 ciphertexts"); } - let mut acc: FheUint8 = safe_load(cts_paths[0]); - for ct_path in cts_paths[1..].iter() { - let fheuint: FheUint8 = safe_load(ct_path); - acc += fheuint; + if signed { + let mut acc: FheInt8 = safe_load(cts_paths[0]); + for ct_path in cts_paths[1..].iter() { + let fheuint: FheInt8 = safe_load(ct_path); + acc += fheuint; + } + safe_save(out_ct_path, &acc) + } else { + let mut acc: FheUint8 = safe_load(cts_paths[0]); + for ct_path in cts_paths[1..].iter() { + let fheuint: FheUint8 = safe_load(ct_path); + acc += fheuint; + } + safe_save(out_ct_path, &acc) } - safe_save(out_ct_path, &acc) } fn write_keys( @@ -102,11 +126,7 @@ fn write_keys( } } -fn keygen( - client_key_path: &String, - server_key_path: &String, - output_lwe_path: &String, -) { +fn keygen(client_key_path: &String, server_key_path: &String, output_lwe_path: &String) { let config = ConfigBuilder::with_custom_parameters(BLOCK_PARAMS).build(); let (client_key, server_key) = generate_keys(config); @@ -153,6 +173,12 @@ fn main() { .required(true) .num_args(1), ) + .arg( + Arg::new("signed") + .long("signed") + .help("encrypt as a signed integer") + .action(ArgAction::SetTrue), + ) .arg( Arg::new("ciphertext") .short('c') @@ -186,6 +212,12 @@ fn main() { .short_flag('d') .long_flag("decrypt") .about("Decrypt a ciphertext with a given key.") + .arg( + Arg::new("signed") + .long("signed") + .help("decrypt as a signed integer") + .action(ArgAction::SetTrue), + ) .arg( Arg::new("ciphertext") .short('c') @@ -236,6 +268,12 @@ fn main() { .action(ArgAction::Set) .num_args(1), ) + .arg( + Arg::new("signed") + .long("signed") + .help("consider ciphertexts as signed integers") + .action(ArgAction::SetTrue), + ) .arg( Arg::new("ciphertexts") .short('c') @@ -300,8 +338,8 @@ fn main() { match matches.subcommand() { Some(("encrypt-with-key", encrypt_matches)) => { let value_str = encrypt_matches.get_one::("value").unwrap(); - let value: u8 = value_str.parse().unwrap(); let ciphertext_path = encrypt_matches.get_one::("ciphertext").unwrap(); + let signed = encrypt_matches.get_flag("signed"); let client_key: ClientKey; if let Some(lwe_sk_path) = encrypt_matches.get_one::("lwe-sk") { @@ -312,11 +350,18 @@ fn main() { panic!("no key specified"); } - encrypt_with_key(value, client_key, ciphertext_path) + if signed { + let value: i8 = value_str.parse().unwrap(); + encrypt_with_key_i8(value, client_key, ciphertext_path) + } else { + let value: u8 = value_str.parse().unwrap(); + encrypt_with_key_u8(value, client_key, ciphertext_path) + } } Some(("decrypt-with-key", decrypt_mtches)) => { let ciphertext_path = decrypt_mtches.get_one::("ciphertext").unwrap(); let plaintext_path = decrypt_mtches.get_one::("plaintext"); + let signed = decrypt_mtches.get_flag("signed"); let client_key: ClientKey; if let Some(lwe_sk_path) = decrypt_mtches.get_one::("lwe-sk") { @@ -326,16 +371,17 @@ fn main() { } else { panic!("no key specified"); } - decrypt_with_key(client_key, ciphertext_path, plaintext_path) + decrypt_with_key(client_key, ciphertext_path, plaintext_path, signed) } Some(("add", add_mtches)) => { let server_key_path = add_mtches.get_one::("server-key").unwrap(); let cts_path = add_mtches.get_many::("ciphertexts").unwrap(); let output_ct_path = add_mtches.get_one::("output-ciphertext").unwrap(); + let signed = add_mtches.get_flag("signed"); set_server_key_from_file(server_key_path); - sum(cts_path.collect(), output_ct_path) + sum(cts_path.collect(), output_ct_path, signed) } Some(("keygen", keygen_mtches)) => { let client_key_path = keygen_mtches.get_one::("client-key").unwrap();