Skip to content

Commit

Permalink
Merge pull request #1112 from zama-ai/feat/signed-tfhers
Browse files Browse the repository at this point in the history
feat(frontend/compiler): support TFHErs fheint8
  • Loading branch information
youben11 authored Oct 25, 2024
2 parents 120827d + 514fe62 commit bd36a47
Show file tree
Hide file tree
Showing 11 changed files with 489 additions and 150 deletions.
10 changes: 10 additions & 0 deletions backends/concrete-cpu/implementation/include/concrete-cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,11 @@ void concrete_cpu_keyswitch_lwe_ciphertext_u64(uint64_t *ct_out,
size_t input_dimension,
size_t output_dimension);

size_t concrete_cpu_lwe_array_to_tfhers_int8(const uint64_t *lwe_vec_buffer,
uint8_t *buffer,
size_t buffer_len,
struct TfhersFheIntDescription fheint_desc);

size_t concrete_cpu_lwe_array_to_tfhers_uint8(const uint64_t *lwe_vec_buffer,
uint8_t *buffer,
size_t buffer_len,
Expand Down Expand Up @@ -415,6 +420,11 @@ size_t concrete_cpu_serialize_lwe_secret_key_u64(const uint64_t *lwe_sk,

size_t concrete_cpu_tfhers_fheint_buffer_size_u64(size_t lwe_size, size_t n_cts);

int64_t concrete_cpu_tfhers_int8_to_lwe_array(const uint8_t *serialized_data_ptr,
size_t serialized_data_len,
uint64_t *lwe_vec_buffer,
struct TfhersFheIntDescription desc);

int64_t concrete_cpu_tfhers_uint8_to_lwe_array(const uint8_t *buffer,
size_t buffer_len,
uint64_t *lwe_vec_buffer,
Expand Down
104 changes: 102 additions & 2 deletions backends/concrete-cpu/implementation/src/c_api/fheint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use tfhe::integer::ciphertext::Expandable;
use tfhe::integer::IntegerCiphertext;
use tfhe::shortint::parameters::{Degree, NoiseLevel};
use tfhe::shortint::{CarryModulus, Ciphertext, MessageModulus};
use tfhe::{FheUint128, FheUint8};
use tfhe::{FheInt8, FheUint128, FheUint8};

#[repr(C)]
pub struct TfhersFheIntDescription {
Expand Down Expand Up @@ -127,6 +127,29 @@ pub fn tfhers_uint8_description(fheuint: FheUint8) -> TfhersFheIntDescription {
}
}

pub fn tfhers_int8_description(fheuint: FheInt8) -> TfhersFheIntDescription {
// get metadata from fheuint's ciphertext
let (radix, _, _) = fheuint.into_raw_parts();
let blocks = radix.blocks();
let ct = match blocks.first() {
Some(value) => &value.ct,
None => {
return TfhersFheIntDescription::zero();
}
};
TfhersFheIntDescription {
width: 8,
is_signed: true,
lwe_size: ct.lwe_size().0,
n_cts: blocks.len(),
degree: blocks[0].degree.get(),
noise_level: blocks[0].noise_level().get(),
message_modulus: blocks[0].message_modulus.0,
carry_modulus: blocks[0].carry_modulus.0,
ks_first: blocks[0].pbs_order == PBSOrder::KeyswitchBootstrap,
}
}

#[no_mangle]
pub unsafe extern "C" fn concrete_cpu_tfhers_uint8_to_lwe_array(
buffer: *const u8,
Expand Down Expand Up @@ -161,6 +184,41 @@ pub unsafe extern "C" fn concrete_cpu_tfhers_uint8_to_lwe_array(
})
}

#[no_mangle]
pub unsafe extern "C" fn concrete_cpu_tfhers_int8_to_lwe_array(
serialized_data_ptr: *const u8,
serialized_data_len: usize,
lwe_vec_buffer: *mut u64,
desc: TfhersFheIntDescription,
) -> i64 {
nounwind(|| {
let fheint: FheInt8 =
super::utils::safe_deserialize(serialized_data_ptr, serialized_data_len);
// TODO - Use conformance check
let fheint_desc = tfhers_int8_description(fheint.clone());
if !fheint_desc.is_similar(&desc) {
return 1;
}

// collect LWEs from fheuint
let (radix, _, _) = fheint.into_raw_parts();
let blocks = radix.blocks();
let first_ct = match blocks.first() {
Some(value) => &value.ct,
None => return 1,
};
let lwe_size = first_ct.lwe_size().0;
let n_cts = blocks.len();
// copy LWEs to C buffer. Note that lsb is cts[0]
let lwe_vector: &mut [u64] = slice::from_raw_parts_mut(lwe_vec_buffer, n_cts * lwe_size);
for (i, block) in blocks.iter().enumerate() {
lwe_vector[i * lwe_size..(i + 1) * lwe_size]
.copy_from_slice(block.ct.clone().into_container().as_slice());
}
0
})
}

#[no_mangle]
pub extern "C" fn concrete_cpu_tfhers_fheint_buffer_size_u64(
lwe_size: usize,
Expand Down Expand Up @@ -198,7 +256,7 @@ pub unsafe extern "C" fn concrete_cpu_lwe_array_to_tfhers_uint8(
let n_cts = fheuint_desc.n_cts;
// construct fheuint from LWEs
let lwe_vector: &[u64] = slice::from_raw_parts(lwe_vec_buffer, n_cts * lwe_size);
let mut blocks: Vec<Ciphertext> = Vec::new();
let mut blocks: Vec<Ciphertext> = Vec::with_capacity(n_cts);
for i in 0..n_cts {
let lwe_ct = LweCiphertext::<Vec<u64>>::from_container(
lwe_vector[i * lwe_size..(i + 1) * lwe_size].to_vec(),
Expand All @@ -215,3 +273,45 @@ pub unsafe extern "C" fn concrete_cpu_lwe_array_to_tfhers_uint8(
super::utils::safe_serialize(&fheuint, buffer, buffer_len)
})
}

#[no_mangle]
pub unsafe extern "C" fn concrete_cpu_lwe_array_to_tfhers_int8(
lwe_vec_buffer: *const u64,
buffer: *mut u8,
buffer_len: usize,
fheint_desc: TfhersFheIntDescription,
) -> usize {
nounwind(|| {
// we want to trigger a PBS on TFHErs side
assert!(
fheint_desc.noise_level == NoiseLevel::UNKNOWN.get(),
"noise_level must be unknown"
);
// we want to use the max degree as we don't track it on Concrete side
assert!(
fheint_desc.degree == fheint_desc.message_modulus - 1,
"degree must be the max value (msg_modulus - 1)"
);

let lwe_size = fheint_desc.lwe_size;
let n_cts = fheint_desc.n_cts;
// construct fheuint from LWEs
let lwe_vector: &[u64] = slice::from_raw_parts(lwe_vec_buffer, n_cts * lwe_size);
let mut blocks: Vec<Ciphertext> = Vec::with_capacity(n_cts);
for i in 0..n_cts {
let lwe_ct = LweCiphertext::<Vec<u64>>::from_container(
lwe_vector[i * lwe_size..(i + 1) * lwe_size].to_vec(),
CiphertextModulus::new_native(),
);
blocks.push(fheint_desc.ct_from_lwe(lwe_ct));
}
let fheuint = match FheInt8::from_expanded_blocks(blocks, fheint_desc.data_kind()) {
Ok(value) => value,
Err(_) => {
return 0;
}
};

super::utils::safe_serialize(&fheuint, buffer, buffer_len)
})
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,13 @@ using concretelang::values::Value;
namespace concretelang {
namespace clientlib {

Result<TransportValue>
importTfhersFheUint8(llvm::ArrayRef<uint8_t> serializedFheUint8,
TfhersFheIntDescription desc, uint32_t encryptionKeyId,
double encryptionVariance);
Result<std::vector<uint8_t>> exportTfhersFheUint8(TransportValue value,
TfhersFheIntDescription info);
Result<TransportValue> importTfhersInteger(llvm::ArrayRef<uint8_t> buffer,
TfhersFheIntDescription integerDesc,
uint32_t encryptionKeyId,
double encryptionVariance);

Result<std::vector<uint8_t>>
exportTfhersInteger(TransportValue value, TfhersFheIntDescription integerDesc);

class ClientCircuit {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1881,28 +1881,27 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
"Return the `circuit` ClientCircuit.", arg("circuit"))
.doc() = "Client-side / Encryption program";

m.def("import_tfhers_fheuint8",
m.def("import_tfhers_int",
[](const pybind11::bytes &serialized_fheuint,
TfhersFheIntDescription info, uint32_t encryptionKeyId,
double encryptionVariance) {
const std::string &buffer_str = serialized_fheuint;
std::vector<uint8_t> buffer(buffer_str.begin(), buffer_str.end());
auto arrayRef = llvm::ArrayRef<uint8_t>(buffer);
auto valueOrError = ::concretelang::clientlib::importTfhersFheUint8(
auto valueOrError = ::concretelang::clientlib::importTfhersInteger(
arrayRef, info, encryptionKeyId, encryptionVariance);
if (valueOrError.has_error()) {
throw std::runtime_error(valueOrError.error().mesg);
}
return TransportValue{valueOrError.value()};
});

m.def("export_tfhers_fheuint8",
[](TransportValue fheuint, TfhersFheIntDescription info) {
auto result =
::concretelang::clientlib::exportTfhersFheUint8(fheuint, info);
if (result.has_error()) {
throw std::runtime_error(result.error().mesg);
}
return result.value();
});
m.def("export_tfhers_int", [](TransportValue fheuint,
TfhersFheIntDescription info) {
auto result = ::concretelang::clientlib::exportTfhersInteger(fheuint, info);
if (result.has_error()) {
throw std::runtime_error(result.error().mesg);
}
return result.value();
});
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
# pylint: disable=no-name-in-module,import-error,

from mlir._mlir_libs._concretelang._compiler import (
import_tfhers_fheuint8 as _import_tfhers_fheuint8,
export_tfhers_fheuint8 as _export_tfhers_fheuint8,
import_tfhers_int as _import_tfhers_int,
export_tfhers_int as _export_tfhers_int,
TfhersFheIntDescription as _TfhersFheIntDescription,
TransportValue,
)
Expand Down Expand Up @@ -182,7 +182,7 @@ class TfhersExporter:
"""A helper class to import and export TFHErs big integers."""

@staticmethod
def export_fheuint8(value: TransportValue, info: TfhersFheIntDescription) -> bytes:
def export_int(value: TransportValue, info: TfhersFheIntDescription) -> bytes:
"""Convert Concrete value to TFHErs and serialize it.
Args:
Expand All @@ -193,24 +193,24 @@ def export_fheuint8(value: TransportValue, info: TfhersFheIntDescription) -> byt
TypeError: if wrong input types
Returns:
bytes: converted and serialized fheuint8
bytes: converted and serialized TFHErs integer
"""
if not isinstance(value, TransportValue):
raise TypeError(f"value must be of type TransportValue, not {type(value)}")
if not isinstance(info, TfhersFheIntDescription):
raise TypeError(
f"info must be of type TfhersFheIntDescription, not {type(info)}"
)
return bytes(_export_tfhers_fheuint8(value, info.cpp()))
return bytes(_export_tfhers_int(value, info.cpp()))

@staticmethod
def import_fheuint8(
def import_int(
buffer: bytes, info: TfhersFheIntDescription, keyid: int, variance: float
) -> TransportValue:
"""Unserialize and convert from TFHErs to Concrete value.
Args:
buffer (bytes): serialized fheuint8
buffer (bytes): serialized TFHErs integer
info (TfhersFheIntDescription): description of the TFHErs integer to import
keyid (int): id of the key used for encryption
variance (float): variance used for encryption
Expand All @@ -231,4 +231,4 @@ def import_fheuint8(
raise TypeError(f"keyid must be of type int, not {type(keyid)}")
if not isinstance(variance, float):
raise TypeError(f"variance must be of type float, not {type(variance)}")
return _import_tfhers_fheuint8(buffer, info.cpp(), keyid, variance)
return _import_tfhers_int(buffer, info.cpp(), keyid, variance)
Loading

0 comments on commit bd36a47

Please sign in to comment.