diff --git a/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/ClientLib.h b/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/ClientLib.h index b59f0d7767..ea622e2ad9 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/ClientLib.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/ClientLib.h @@ -36,7 +36,8 @@ namespace clientlib { Result importTfhersFheUint8(llvm::ArrayRef serializedFheUint8, - uint32_t encryptionKeyId, double encryptionVariance); + TfhersFheIntDescription desc, uint32_t encryptionKeyId, + double encryptionVariance); Result> exportTfhersFheUint8(TransportValue value, TfhersFheIntDescription info); Result diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index e31c466fe7..67f05f749e 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -728,13 +728,14 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( m.def("check_cuda_device_available", &checkCudaDeviceAvailable); m.def("import_tfhers_fheuint8", - [](const pybind11::bytes &serialized_fheuint, uint32_t encryptionKeyId, + [](const pybind11::bytes &serialized_fheuint, + TfhersFheIntDescription info, uint32_t encryptionKeyId, double encryptionVariance) { const std::string &buffer_str = serialized_fheuint; std::vector buffer(buffer_str.begin(), buffer_str.end()); auto arrayRef = llvm::ArrayRef(buffer); auto valueOrError = ::concretelang::clientlib::importTfhersFheUint8( - arrayRef, encryptionKeyId, encryptionVariance); + arrayRef, info, encryptionKeyId, encryptionVariance); if (valueOrError.has_error()) { throw std::runtime_error(valueOrError.error().mesg); } diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/tfhers_int.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/tfhers_int.py index 2f8bbe8e05..68855c9aa3 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/tfhers_int.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/tfhers_int.py @@ -222,11 +222,14 @@ def export_fheuint8(value: Value, info: TfhersFheIntDescription) -> bytes: return bytes(_export_tfhers_fheuint8(value.cpp(), info.cpp())) @staticmethod - def import_fheuint8(buffer: bytes, keyid: int, variance: float) -> Value: + def import_fheuint8( + buffer: bytes, info: TfhersFheIntDescription, keyid: int, variance: float + ) -> Value: """Unserialize and convert from TFHErs to Concrete value. Args: buffer (bytes): serialized fheuint8 + info (TfhersFheIntDescription): description of the TFHErs integer to import keyid (int): id of the key used for encryption variance (float): variance used for encryption @@ -238,8 +241,12 @@ def import_fheuint8(buffer: bytes, keyid: int, variance: float) -> Value: """ if not isinstance(buffer, bytes): raise TypeError(f"buffer must be of type bytes, not {type(buffer)}") + if not isinstance(info, TfhersFheIntDescription): + raise TypeError( + f"info must be of type TfhersFheIntDescription, not {type(info)}" + ) if not isinstance(keyid, int): raise TypeError(f"keyid must be of type int, not {type(keyid)}") if not isinstance(variance, float): raise TypeError(f"variance must be of type float, not {type(variance)}") - return Value.wrap(_import_tfhers_fheuint8(buffer, keyid, variance)) + return Value.wrap(_import_tfhers_fheuint8(buffer, info.cpp(), keyid, variance)) diff --git a/compilers/concrete-compiler/compiler/lib/ClientLib/ClientLib.cpp b/compilers/concrete-compiler/compiler/lib/ClientLib/ClientLib.cpp index c325253662..87ae4a9d53 100644 --- a/compilers/concrete-compiler/compiler/lib/ClientLib/ClientLib.cpp +++ b/compilers/concrete-compiler/compiler/lib/ClientLib/ClientLib.cpp @@ -141,13 +141,14 @@ getTfhersFheUint8Description(llvm::ArrayRef serializedFheUint8) { Result importTfhersFheUint8(llvm::ArrayRef serializedFheUint8, - uint32_t encryptionKeyId, double encryptionVariance) { - auto fheUintInfoOrError = getTfhersFheUint8Description(serializedFheUint8); - if (fheUintInfoOrError.has_error()) { - return fheUintInfoOrError.error(); + TfhersFheIntDescription desc, uint32_t encryptionKeyId, + double encryptionVariance) { + if (desc.width != 8 || desc.is_signed == true) { + return StringError( + "trying to import FheUint8 but description doesn't match this type"); } - auto fheUintDesc = fheUintInfoOrError.value(); - auto dims = std::vector({fheUintDesc.n_cts, fheUintDesc.lwe_size}); + + auto dims = std::vector({desc.n_cts, desc.lwe_size}); auto outputTensor = Tensor::fromDimensions(dims); auto err = concrete_cpu_tfhers_uint8_to_lwe_array(serializedFheUint8.data(), serializedFheUint8.size(), @@ -160,12 +161,12 @@ importTfhersFheUint8(llvm::ArrayRef serializedFheUint8, auto lwe = value.asBuilder().initTypeInfo().initLweCiphertext(); lwe.setIntegerPrecision(64); // dimensions - lwe.initAbstractShape().setDimensions({(uint32_t)fheUintDesc.n_cts}); + lwe.initAbstractShape().setDimensions({(uint32_t)desc.n_cts}); lwe.initConcreteShape().setDimensions( - {(uint32_t)fheUintDesc.n_cts, (uint32_t)fheUintDesc.lwe_size}); + {(uint32_t)desc.n_cts, (uint32_t)desc.lwe_size}); // encryption auto encryption = lwe.initEncryption(); - encryption.setLweDimension((uint32_t)fheUintDesc.lwe_size - 1); + encryption.setLweDimension((uint32_t)desc.lwe_size - 1); encryption.initModulus().initMod().initNative(); encryption.setKeyId(encryptionKeyId); encryption.setVariance(encryptionVariance); @@ -173,8 +174,7 @@ importTfhersFheUint8(llvm::ArrayRef serializedFheUint8, auto encoding = lwe.initEncoding(); auto integer = encoding.initInteger(); integer.setIsSigned(false); - integer.setWidth( - std::log2(fheUintDesc.message_modulus * fheUintDesc.carry_modulus)); + integer.setWidth(std::log2(desc.message_modulus * desc.carry_modulus)); integer.initMode().initNative(); return value; @@ -182,6 +182,11 @@ importTfhersFheUint8(llvm::ArrayRef serializedFheUint8, Result> exportTfhersFheUint8(TransportValue value, TfhersFheIntDescription desc) { + if (desc.width != 8 || desc.is_signed == true) { + return StringError( + "trying to export FheUint8 but description doesn't match this type"); + } + auto fheuint = Value::fromRawTransportValue(value); if (fheuint.isScalar()) { return StringError("expected a tensor, but value is a scalar"); diff --git a/frontends/concrete-python/concrete/fhe/tfhers/context.py b/frontends/concrete-python/concrete/fhe/tfhers/context.py index e4a8c479f5..0f0c3c34e6 100644 --- a/frontends/concrete-python/concrete/fhe/tfhers/context.py +++ b/frontends/concrete-python/concrete/fhe/tfhers/context.py @@ -67,6 +67,37 @@ def _input_variance(self, input_idx: int) -> float: raise ValueError(msg) return input_type.params.encryption_variance() + @staticmethod + def _description_from_type( + tfhers_int_type: TFHERSIntegerType, + ) -> TfhersFheIntDescription: + """Construct a TFHErs integer description based on type.""" + + bit_width = tfhers_int_type.bit_width + signed = tfhers_int_type.is_signed + params = tfhers_int_type.params + message_modulus = 2**tfhers_int_type.msg_width + carry_modulus = 2**tfhers_int_type.carry_width + lwe_size = params.polynomial_size + 1 + n_cts = bit_width // tfhers_int_type.msg_width + ks_first = params.encryption_key_choice is EncryptionKeyChoice.BIG + # maximum value using message bits as we don't use carry bits here + degree = message_modulus - 1 + # this should imply running a PBS on TFHErs side + noise_level = TfhersFheIntDescription.get_unknown_noise_level() + + return TfhersFheIntDescription.new( + bit_width, + signed, + message_modulus, + carry_modulus, + degree, + lwe_size, + n_cts, + noise_level, + ks_first, + ) + def import_value(self, buffer: bytes, input_idx: int) -> "fhe.Value": """Import a serialized TFHErs integer as a Value. @@ -82,13 +113,17 @@ def import_value(self, buffer: bytes, input_idx: int) -> "fhe.Value": msg = "input at 'input_idx' is not a TFHErs value" raise ValueError(msg) + fheint_desc = self._description_from_type(input_type) + bit_width = input_type.bit_width signed = input_type.is_signed if bit_width == 8: if not signed: keyid = self._input_keyid(input_idx) variance = self._input_variance(input_idx) - return fhe.Value(TfhersExporter.import_fheuint8(buffer, keyid, variance)) + return fhe.Value( + TfhersExporter.import_fheuint8(buffer, fheint_desc, keyid, variance) + ) msg = ( f"importing {'signed' if signed else 'unsigned'} integers of {bit_width}bits is not" @@ -111,32 +146,10 @@ def export_value(self, value: "fhe.Value", output_idx: int) -> bytes: msg = "output at 'output_idx' is not a TFHErs value" raise ValueError(msg) - # construct a TFHErs integer description based on type + fheint_desc = self._description_from_type(output_type) + bit_width = output_type.bit_width signed = output_type.is_signed - params = output_type.params - message_modulus = 2**output_type.msg_width - carry_modulus = 2**output_type.carry_width - lwe_size = params.polynomial_size + 1 - n_cts = bit_width // output_type.msg_width - ks_first = params.encryption_key_choice is EncryptionKeyChoice.BIG - # maximum value using message bits as we don't use carry bits here - degree = message_modulus - 1 - # this should imply running a PBS on TFHErs side - noise_level = TfhersFheIntDescription.get_unknown_noise_level() - - fheint_desc = TfhersFheIntDescription.new( - bit_width, - signed, - message_modulus, - carry_modulus, - degree, - lwe_size, - n_cts, - noise_level, - ks_first, - ) - if bit_width == 8: if not signed: return TfhersExporter.export_fheuint8(value.inner, fheint_desc)