Skip to content

Commit

Permalink
refactor(frontend/compiler): single API for import/export of TFHErs int
Browse files Browse the repository at this point in the history
  • Loading branch information
youben11 committed Oct 25, 2024
1 parent b0e7c08 commit 514fe62
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 232 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,13 @@ using concretelang::values::Value;
namespace concretelang {
namespace clientlib {

Result<TransportValue>
importTfhersFheUint8(llvm::ArrayRef<uint8_t> serializedFheUint8,
TfhersFheIntDescription desc, uint32_t encryptionKeyId,
double encryptionVariance);
Result<std::vector<uint8_t>> exportTfhersFheUint8(TransportValue value,
TfhersFheIntDescription info);
Result<TransportValue>
importTfhersFheInt8(llvm::ArrayRef<uint8_t> serializedFheUint8,
TfhersFheIntDescription desc, uint32_t encryptionKeyId,
double encryptionVariance);
Result<std::vector<uint8_t>> exportTfhersFheInt8(TransportValue value,
TfhersFheIntDescription info);
Result<TransportValue> importTfhersInteger(llvm::ArrayRef<uint8_t> buffer,
TfhersFheIntDescription integerDesc,
uint32_t encryptionKeyId,
double encryptionVariance);

Result<std::vector<uint8_t>>
exportTfhersInteger(TransportValue value, TfhersFheIntDescription integerDesc);

class ClientCircuit {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1881,49 +1881,24 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
"Return the `circuit` ClientCircuit.", arg("circuit"))
.doc() = "Client-side / Encryption program";

m.def("import_tfhers_fheuint8",
m.def("import_tfhers_int",
[](const pybind11::bytes &serialized_fheuint,
TfhersFheIntDescription info, uint32_t encryptionKeyId,
double encryptionVariance) {
const std::string &buffer_str = serialized_fheuint;
std::vector<uint8_t> buffer(buffer_str.begin(), buffer_str.end());
auto arrayRef = llvm::ArrayRef<uint8_t>(buffer);
auto valueOrError = ::concretelang::clientlib::importTfhersFheUint8(
auto valueOrError = ::concretelang::clientlib::importTfhersInteger(
arrayRef, info, encryptionKeyId, encryptionVariance);
if (valueOrError.has_error()) {
throw std::runtime_error(valueOrError.error().mesg);
}
return TransportValue{valueOrError.value()};
});

m.def("export_tfhers_fheuint8",
[](TransportValue fheuint, TfhersFheIntDescription info) {
auto result =
::concretelang::clientlib::exportTfhersFheUint8(fheuint, info);
if (result.has_error()) {
throw std::runtime_error(result.error().mesg);
}
return result.value();
});

m.def("import_tfhers_fheint8",
[](const pybind11::bytes &serialized_fheuint,
TfhersFheIntDescription info, uint32_t encryptionKeyId,
double encryptionVariance) {
const std::string &buffer_str = serialized_fheuint;
std::vector<uint8_t> buffer(buffer_str.begin(), buffer_str.end());
auto arrayRef = llvm::ArrayRef<uint8_t>(buffer);
auto valueOrError = ::concretelang::clientlib::importTfhersFheInt8(
arrayRef, info, encryptionKeyId, encryptionVariance);
if (valueOrError.has_error()) {
throw std::runtime_error(valueOrError.error().mesg);
}
return TransportValue{valueOrError.value()};
});

m.def("export_tfhers_fheint8", [](TransportValue fheuint,
TfhersFheIntDescription info) {
auto result = ::concretelang::clientlib::exportTfhersFheInt8(fheuint, info);
m.def("export_tfhers_int", [](TransportValue fheuint,
TfhersFheIntDescription info) {
auto result = ::concretelang::clientlib::exportTfhersInteger(fheuint, info);
if (result.has_error()) {
throw std::runtime_error(result.error().mesg);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@
# pylint: disable=no-name-in-module,import-error,

from mlir._mlir_libs._concretelang._compiler import (
import_tfhers_fheuint8 as _import_tfhers_fheuint8,
export_tfhers_fheuint8 as _export_tfhers_fheuint8,
import_tfhers_fheint8 as _import_tfhers_fheint8,
export_tfhers_fheint8 as _export_tfhers_fheint8,
import_tfhers_int as _import_tfhers_int,
export_tfhers_int as _export_tfhers_int,
TfhersFheIntDescription as _TfhersFheIntDescription,
TransportValue,
)
Expand Down Expand Up @@ -184,7 +182,7 @@ class TfhersExporter:
"""A helper class to import and export TFHErs big integers."""

@staticmethod
def export_fheuint8(value: TransportValue, info: TfhersFheIntDescription) -> bytes:
def export_int(value: TransportValue, info: TfhersFheIntDescription) -> bytes:
"""Convert Concrete value to TFHErs and serialize it.
Args:
Expand All @@ -195,24 +193,24 @@ def export_fheuint8(value: TransportValue, info: TfhersFheIntDescription) -> byt
TypeError: if wrong input types
Returns:
bytes: converted and serialized fheuint8
bytes: converted and serialized TFHErs integer
"""
if not isinstance(value, TransportValue):
raise TypeError(f"value must be of type TransportValue, not {type(value)}")
if not isinstance(info, TfhersFheIntDescription):
raise TypeError(
f"info must be of type TfhersFheIntDescription, not {type(info)}"
)
return bytes(_export_tfhers_fheuint8(value, info.cpp()))
return bytes(_export_tfhers_int(value, info.cpp()))

@staticmethod
def import_fheuint8(
def import_int(
buffer: bytes, info: TfhersFheIntDescription, keyid: int, variance: float
) -> TransportValue:
"""Unserialize and convert from TFHErs to Concrete value.
Args:
buffer (bytes): serialized fheuint8
buffer (bytes): serialized TFHErs integer
info (TfhersFheIntDescription): description of the TFHErs integer to import
keyid (int): id of the key used for encryption
variance (float): variance used for encryption
Expand All @@ -233,56 +231,4 @@ def import_fheuint8(
raise TypeError(f"keyid must be of type int, not {type(keyid)}")
if not isinstance(variance, float):
raise TypeError(f"variance must be of type float, not {type(variance)}")
return _import_tfhers_fheuint8(buffer, info.cpp(), keyid, variance)

@staticmethod
def export_fheint8(value: TransportValue, info: TfhersFheIntDescription) -> bytes:
"""Convert Concrete value to TFHErs and serialize it.
Args:
value (Value): value to export
info (TfhersFheIntDescription): description of the TFHErs integer to export to
Raises:
TypeError: if wrong input types
Returns:
bytes: converted and serialized fheuint8
"""
if not isinstance(value, TransportValue):
raise TypeError(f"value must be of type Value, not {type(value)}")
if not isinstance(info, TfhersFheIntDescription):
raise TypeError(
f"info must be of type TfhersFheIntDescription, not {type(info)}"
)
return bytes(_export_tfhers_fheint8(value, info.cpp()))

@staticmethod
def import_fheint8(
buffer: bytes, info: TfhersFheIntDescription, keyid: int, variance: float
) -> TransportValue:
"""Unserialize and convert from TFHErs to Concrete value.
Args:
buffer (bytes): serialized fheuint8
info (TfhersFheIntDescription): description of the TFHErs integer to import
keyid (int): id of the key used for encryption
variance (float): variance used for encryption
Raises:
TypeError: if wrong input types
Returns:
Value: unserialized and converted value
"""
if not isinstance(buffer, bytes):
raise TypeError(f"buffer must be of type bytes, not {type(buffer)}")
if not isinstance(info, TfhersFheIntDescription):
raise TypeError(
f"info must be of type TfhersFheIntDescription, not {type(info)}"
)
if not isinstance(keyid, int):
raise TypeError(f"keyid must be of type int, not {type(keyid)}")
if not isinstance(variance, float):
raise TypeError(f"variance must be of type float, not {type(variance)}")
return _import_tfhers_fheint8(buffer, info.cpp(), keyid, variance)
return _import_tfhers_int(buffer, info.cpp(), keyid, variance)
153 changes: 56 additions & 97 deletions compilers/concrete-compiler/compiler/lib/ClientLib/ClientLib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,92 +187,34 @@ Result<ClientCircuit> ClientProgram::getClientCircuit(std::string circuitName) {
"`");
}

Result<TransportValue>
importTfhersFheUint8(llvm::ArrayRef<uint8_t> serializedFheUint8,
TfhersFheIntDescription desc, uint32_t encryptionKeyId,
double encryptionVariance) {
if (desc.width != 8 || desc.is_signed == true) {
return StringError(
"trying to import FheUint8 but description doesn't match this type");
}

auto dims = std::vector({desc.n_cts, desc.lwe_size});
auto outputTensor = Tensor<uint64_t>::fromDimensions(dims);
auto err = concrete_cpu_tfhers_uint8_to_lwe_array(
serializedFheUint8.data(), serializedFheUint8.size(),
outputTensor.values.data(), desc);
if (err) {
return StringError("couldn't convert fheuint to lwe array: err()")
<< err << ")";
}

auto value = Value{outputTensor}.intoRawTransportValue();
auto lwe = value.asBuilder().initTypeInfo().initLweCiphertext();
lwe.setIntegerPrecision(64);
// dimensions
lwe.initAbstractShape().setDimensions({(uint32_t)desc.n_cts});
lwe.initConcreteShape().setDimensions(
{(uint32_t)desc.n_cts, (uint32_t)desc.lwe_size});
// encryption
auto encryption = lwe.initEncryption();
encryption.setLweDimension((uint32_t)desc.lwe_size - 1);
encryption.initModulus().initMod().initNative();
encryption.setKeyId(encryptionKeyId);
encryption.setVariance(encryptionVariance);
// Encoding
auto encoding = lwe.initEncoding();
auto integer = encoding.initInteger();
integer.setIsSigned(false);
integer.setWidth(std::log2(desc.message_modulus * desc.carry_modulus));
integer.initMode().initNative();

return value;
}

Result<std::vector<uint8_t>>
exportTfhersFheUint8(TransportValue value, TfhersFheIntDescription desc) {
if (desc.width != 8 || desc.is_signed == true) {
return StringError(
"trying to export FheUint8 but description doesn't match this type");
}

auto fheuint = Value::fromRawTransportValue(value);
if (fheuint.isScalar()) {
return StringError("expected a tensor, but value is a scalar");
}
auto tensorOrError = fheuint.getTensor<uint64_t>();
if (!tensorOrError.has_value()) {
return StringError("couldn't get tensor from value");
}
const size_t bufferSize =
concrete_cpu_tfhers_fheint_buffer_size_u64(desc.lwe_size, desc.n_cts);
std::vector<uint8_t> buffer(bufferSize, 0);
auto flatData = tensorOrError.value().values;
auto size = concrete_cpu_lwe_array_to_tfhers_uint8(
flatData.data(), buffer.data(), buffer.size(), desc);
if (size == 0) {
return StringError("couldn't convert lwe array to fheuint8");
}
// we truncate to the serialized data
assert(size <= buffer.size());
buffer.resize(size, 0);
return buffer;
}

Result<TransportValue>
importTfhersFheInt8(llvm::ArrayRef<uint8_t> serializedFheUint8,
TfhersFheIntDescription desc, uint32_t encryptionKeyId,
double encryptionVariance) {
if (desc.width != 8 || desc.is_signed == false) {
return StringError(
"trying to import FheInt8 but description doesn't match this type");
Result<TransportValue> importTfhersInteger(llvm::ArrayRef<uint8_t> buffer,
TfhersFheIntDescription integerDesc,
uint32_t encryptionKeyId,
double encryptionVariance) {

// Select conversion function based on integer description
std::function<int64_t(const uint8_t *, size_t, uint64_t *,
TfhersFheIntDescription)>
conversion_func;
if (integerDesc.width == 8) {
if (integerDesc.is_signed) { // fheint8
conversion_func = concrete_cpu_tfhers_int8_to_lwe_array;
} else { // fheuint8
conversion_func = concrete_cpu_tfhers_uint8_to_lwe_array;
}
} else {
std::ostringstream stringStream;
stringStream << "importTfhersInteger: no support for " << integerDesc.width
<< "bits " << (integerDesc.is_signed ? "signed" : "unsigned")
<< " integer";
std::string errorMsg = stringStream.str();
return StringError(errorMsg);
}

auto dims = std::vector({desc.n_cts, desc.lwe_size});
auto dims = std::vector({integerDesc.n_cts, integerDesc.lwe_size});
auto outputTensor = Tensor<uint64_t>::fromDimensions(dims);
auto err = concrete_cpu_tfhers_int8_to_lwe_array(
serializedFheUint8.data(), serializedFheUint8.size(),
outputTensor.values.data(), desc);
auto err = conversion_func(buffer.data(), buffer.size(),
outputTensor.values.data(), integerDesc);
if (err) {
return StringError("couldn't convert fheint to lwe array");
}
Expand All @@ -281,30 +223,47 @@ importTfhersFheInt8(llvm::ArrayRef<uint8_t> serializedFheUint8,
auto lwe = value.asBuilder().initTypeInfo().initLweCiphertext();
lwe.setIntegerPrecision(64);
// dimensions
lwe.initAbstractShape().setDimensions({(uint32_t)desc.n_cts});
lwe.initAbstractShape().setDimensions({(uint32_t)integerDesc.n_cts});
lwe.initConcreteShape().setDimensions(
{(uint32_t)desc.n_cts, (uint32_t)desc.lwe_size});
{(uint32_t)integerDesc.n_cts, (uint32_t)integerDesc.lwe_size});
// encryption
auto encryption = lwe.initEncryption();
encryption.setLweDimension((uint32_t)desc.lwe_size - 1);
encryption.setLweDimension((uint32_t)integerDesc.lwe_size - 1);
encryption.initModulus().initMod().initNative();
encryption.setKeyId(encryptionKeyId);
encryption.setVariance(encryptionVariance);
// Encoding
auto encoding = lwe.initEncoding();
auto integer = encoding.initInteger();
integer.setIsSigned(false);
integer.setWidth(std::log2(desc.message_modulus * desc.carry_modulus));
integer.setIsSigned(
false); // should always be unsigned as its for the radix encoded cts
integer.setWidth(
std::log2(integerDesc.message_modulus * integerDesc.carry_modulus));
integer.initMode().initNative();

return value;
}

Result<std::vector<uint8_t>> exportTfhersFheInt8(TransportValue value,
TfhersFheIntDescription desc) {
if (desc.width != 8 || desc.is_signed == false) {
return StringError(
"trying to export FheInt8 but description doesn't match this type");
Result<std::vector<uint8_t>>
exportTfhersInteger(TransportValue value, TfhersFheIntDescription integerDesc) {
// Select conversion function based on integer description
std::function<size_t(const uint64_t *, uint8_t *, size_t,
TfhersFheIntDescription)>
conversion_func;
std::function<size_t(size_t, size_t)> buffer_size_func;
if (integerDesc.width == 8) {
if (integerDesc.is_signed) { // fheint8
conversion_func = concrete_cpu_lwe_array_to_tfhers_int8;
} else { // fheuint8
conversion_func = concrete_cpu_lwe_array_to_tfhers_uint8;
}
} else {
std::ostringstream stringStream;
stringStream << "exportTfhersInteger: no support for " << integerDesc.width
<< "bits " << (integerDesc.is_signed ? "signed" : "unsigned")
<< " integer";
std::string errorMsg = stringStream.str();
return StringError(errorMsg);
}

auto fheuint = Value::fromRawTransportValue(value);
Expand All @@ -315,12 +274,12 @@ Result<std::vector<uint8_t>> exportTfhersFheInt8(TransportValue value,
if (!tensorOrError.has_value()) {
return StringError("couldn't get tensor from value");
}
size_t buffer_size =
concrete_cpu_tfhers_fheint_buffer_size_u64(desc.lwe_size, desc.n_cts);
size_t buffer_size = concrete_cpu_tfhers_fheint_buffer_size_u64(
integerDesc.lwe_size, integerDesc.n_cts);
std::vector<uint8_t> buffer(buffer_size, 0);
auto flat_data = tensorOrError.value().values;
auto size = concrete_cpu_lwe_array_to_tfhers_int8(
flat_data.data(), buffer.data(), buffer.size(), desc);
auto size = conversion_func(flat_data.data(), buffer.data(), buffer.size(),
integerDesc);
if (size == 0) {
return StringError("couldn't convert lwe array to fheint8");
}
Expand Down
Loading

0 comments on commit 514fe62

Please sign in to comment.