Skip to content

Commit

Permalink
refactor(compiler): get fheint description from type instead of buffer
Browse files Browse the repository at this point in the history
we were doing a deserialization previously to get the fheint
description, but we will now construct it from the type instead. It's
still possible for the user to get the description from the buffer and
use it for import (using the Compiler API).
  • Loading branch information
youben11 committed Sep 3, 2024
1 parent 8168d07 commit feac6e8
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ namespace clientlib {

Result<TransportValue>
importTfhersFheUint8(llvm::ArrayRef<uint8_t> serializedFheUint8,
uint32_t encryptionKeyId, double encryptionVariance);
TfhersFheIntDescription desc, uint32_t encryptionKeyId,
double encryptionVariance);
Result<std::vector<uint8_t>> exportTfhersFheUint8(TransportValue value,
TfhersFheIntDescription info);
Result<TfhersFheIntDescription>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -728,13 +728,14 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
m.def("check_cuda_device_available", &checkCudaDeviceAvailable);

m.def("import_tfhers_fheuint8",
[](const pybind11::bytes &serialized_fheuint, uint32_t encryptionKeyId,
[](const pybind11::bytes &serialized_fheuint,
TfhersFheIntDescription info, uint32_t encryptionKeyId,
double encryptionVariance) {
const std::string &buffer_str = serialized_fheuint;
std::vector<uint8_t> buffer(buffer_str.begin(), buffer_str.end());
auto arrayRef = llvm::ArrayRef<uint8_t>(buffer);
auto valueOrError = ::concretelang::clientlib::importTfhersFheUint8(
arrayRef, encryptionKeyId, encryptionVariance);
arrayRef, info, encryptionKeyId, encryptionVariance);
if (valueOrError.has_error()) {
throw std::runtime_error(valueOrError.error().mesg);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,11 +222,14 @@ def export_fheuint8(value: Value, info: TfhersFheIntDescription) -> bytes:
return bytes(_export_tfhers_fheuint8(value.cpp(), info.cpp()))

@staticmethod
def import_fheuint8(buffer: bytes, keyid: int, variance: float) -> Value:
def import_fheuint8(
buffer: bytes, info: TfhersFheIntDescription, keyid: int, variance: float
) -> Value:
"""Unserialize and convert from TFHErs to Concrete value.
Args:
buffer (bytes): serialized fheuint8
info (TfhersFheIntDescription): description of the TFHErs integer to import
keyid (int): id of the key used for encryption
variance (float): variance used for encryption
Expand All @@ -238,8 +241,12 @@ def import_fheuint8(buffer: bytes, keyid: int, variance: float) -> Value:
"""
if not isinstance(buffer, bytes):
raise TypeError(f"buffer must be of type bytes, not {type(buffer)}")
if not isinstance(info, TfhersFheIntDescription):
raise TypeError(
f"info must be of type TfhersFheIntDescription, not {type(info)}"
)
if not isinstance(keyid, int):
raise TypeError(f"keyid must be of type int, not {type(keyid)}")
if not isinstance(variance, float):
raise TypeError(f"variance must be of type float, not {type(variance)}")
return Value.wrap(_import_tfhers_fheuint8(buffer, keyid, variance))
return Value.wrap(_import_tfhers_fheuint8(buffer, info.cpp(), keyid, variance))
27 changes: 16 additions & 11 deletions compilers/concrete-compiler/compiler/lib/ClientLib/ClientLib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,14 @@ getTfhersFheUint8Description(llvm::ArrayRef<uint8_t> serializedFheUint8) {

Result<TransportValue>
importTfhersFheUint8(llvm::ArrayRef<uint8_t> serializedFheUint8,
uint32_t encryptionKeyId, double encryptionVariance) {
auto fheUintInfoOrError = getTfhersFheUint8Description(serializedFheUint8);
if (fheUintInfoOrError.has_error()) {
return fheUintInfoOrError.error();
TfhersFheIntDescription desc, uint32_t encryptionKeyId,
double encryptionVariance) {
if (desc.width != 8 || desc.is_signed == true) {
return StringError(
"trying to import FheUint8 but description doesn't match this type");
}
auto fheUintDesc = fheUintInfoOrError.value();
auto dims = std::vector({fheUintDesc.n_cts, fheUintDesc.lwe_size});

auto dims = std::vector({desc.n_cts, desc.lwe_size});
auto outputTensor = Tensor<uint64_t>::fromDimensions(dims);
auto err = concrete_cpu_tfhers_uint8_to_lwe_array(serializedFheUint8.data(),
serializedFheUint8.size(),
Expand All @@ -160,28 +161,32 @@ importTfhersFheUint8(llvm::ArrayRef<uint8_t> serializedFheUint8,
auto lwe = value.asBuilder().initTypeInfo().initLweCiphertext();
lwe.setIntegerPrecision(64);
// dimensions
lwe.initAbstractShape().setDimensions({(uint32_t)fheUintDesc.n_cts});
lwe.initAbstractShape().setDimensions({(uint32_t)desc.n_cts});
lwe.initConcreteShape().setDimensions(
{(uint32_t)fheUintDesc.n_cts, (uint32_t)fheUintDesc.lwe_size});
{(uint32_t)desc.n_cts, (uint32_t)desc.lwe_size});
// encryption
auto encryption = lwe.initEncryption();
encryption.setLweDimension((uint32_t)fheUintDesc.lwe_size - 1);
encryption.setLweDimension((uint32_t)desc.lwe_size - 1);
encryption.initModulus().initMod().initNative();
encryption.setKeyId(encryptionKeyId);
encryption.setVariance(encryptionVariance);
// Encoding
auto encoding = lwe.initEncoding();
auto integer = encoding.initInteger();
integer.setIsSigned(false);
integer.setWidth(
std::log2(fheUintDesc.message_modulus * fheUintDesc.carry_modulus));
integer.setWidth(std::log2(desc.message_modulus * desc.carry_modulus));
integer.initMode().initNative();

return value;
}

Result<std::vector<uint8_t>>
exportTfhersFheUint8(TransportValue value, TfhersFheIntDescription desc) {
if (desc.width != 8 || desc.is_signed == true) {
return StringError(
"trying to export FheUint8 but description doesn't match this type");
}

auto fheuint = Value::fromRawTransportValue(value);
if (fheuint.isScalar()) {
return StringError("expected a tensor, but value is a scalar");
Expand Down
63 changes: 38 additions & 25 deletions frontends/concrete-python/concrete/fhe/tfhers/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,37 @@ def _input_variance(self, input_idx: int) -> float:
raise ValueError(msg)
return input_type.params.encryption_variance()

@staticmethod
def _description_from_type(
tfhers_int_type: TFHERSIntegerType,
) -> TfhersFheIntDescription:
"""Construct a TFHErs integer description based on type."""

bit_width = tfhers_int_type.bit_width
signed = tfhers_int_type.is_signed
params = tfhers_int_type.params
message_modulus = 2**tfhers_int_type.msg_width
carry_modulus = 2**tfhers_int_type.carry_width
lwe_size = params.polynomial_size + 1
n_cts = bit_width // tfhers_int_type.msg_width
ks_first = params.encryption_key_choice is EncryptionKeyChoice.BIG
# maximum value using message bits as we don't use carry bits here
degree = message_modulus - 1
# this should imply running a PBS on TFHErs side
noise_level = TfhersFheIntDescription.get_unknown_noise_level()

return TfhersFheIntDescription.new(
bit_width,
signed,
message_modulus,
carry_modulus,
degree,
lwe_size,
n_cts,
noise_level,
ks_first,
)

def import_value(self, buffer: bytes, input_idx: int) -> "fhe.Value":
"""Import a serialized TFHErs integer as a Value.
Expand All @@ -82,13 +113,17 @@ def import_value(self, buffer: bytes, input_idx: int) -> "fhe.Value":
msg = "input at 'input_idx' is not a TFHErs value"
raise ValueError(msg)

fheint_desc = self._description_from_type(input_type)

bit_width = input_type.bit_width
signed = input_type.is_signed
if bit_width == 8:
if not signed:
keyid = self._input_keyid(input_idx)
variance = self._input_variance(input_idx)
return fhe.Value(TfhersExporter.import_fheuint8(buffer, keyid, variance))
return fhe.Value(
TfhersExporter.import_fheuint8(buffer, fheint_desc, keyid, variance)
)

msg = (
f"importing {'signed' if signed else 'unsigned'} integers of {bit_width}bits is not"
Expand All @@ -111,32 +146,10 @@ def export_value(self, value: "fhe.Value", output_idx: int) -> bytes:
msg = "output at 'output_idx' is not a TFHErs value"
raise ValueError(msg)

# construct a TFHErs integer description based on type
fheint_desc = self._description_from_type(output_type)

bit_width = output_type.bit_width
signed = output_type.is_signed
params = output_type.params
message_modulus = 2**output_type.msg_width
carry_modulus = 2**output_type.carry_width
lwe_size = params.polynomial_size + 1
n_cts = bit_width // output_type.msg_width
ks_first = params.encryption_key_choice is EncryptionKeyChoice.BIG
# maximum value using message bits as we don't use carry bits here
degree = message_modulus - 1
# this should imply running a PBS on TFHErs side
noise_level = TfhersFheIntDescription.get_unknown_noise_level()

fheint_desc = TfhersFheIntDescription.new(
bit_width,
signed,
message_modulus,
carry_modulus,
degree,
lwe_size,
n_cts,
noise_level,
ks_first,
)

if bit_width == 8:
if not signed:
return TfhersExporter.export_fheuint8(value.inner, fheint_desc)
Expand Down

0 comments on commit feac6e8

Please sign in to comment.