Skip to content

Commit

Permalink
Merge pull request #1103 from zama-ai/safe_serialization
Browse files Browse the repository at this point in the history
fix(backend-cpu): Use safe serialization for TFHE-rs interoperability
  • Loading branch information
BourgerieQuentin authored Oct 18, 2024
2 parents 7ff7902 + c4f0a43 commit cd8208e
Show file tree
Hide file tree
Showing 13 changed files with 361 additions and 458 deletions.
465 changes: 211 additions & 254 deletions backends/concrete-cpu/implementation/Cargo.lock

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions backends/concrete-cpu/implementation/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ crate-type = ["lib", "staticlib"]


[dependencies]
concrete-csprng = { version = "0.4.1", optional = true, features = [
concrete-csprng = { git = "https://github.com/zama-ai/tfhe-rs.git", rev = "483a4fe", optional = true, features = [
"generator_fallback",
] }
concrete-cpu-noise-model = { path = "../noise-model/" }
Expand All @@ -31,16 +31,16 @@ serde = "~1"
rayon = { version = "1.6", optional = true }
once_cell = { version = "1.16", optional = true }

tfhe = { version = "0.8", features = ["integer"] }
tfhe = { git = "https://github.com/zama-ai/tfhe-rs.git", rev = "483a4fe", features = ["integer"] }

[target.x86_64-unknown-unix-gnu.dependencies]
tfhe = { version = "0.8", features = ["integer", "x86_64-unix"] }
tfhe = { git = "https://github.com/zama-ai/tfhe-rs.git", rev = "483a4fe", features = ["integer", "x86_64-unix"] }

[target.aarch64-unknown-unix-gnu.dependencies]
tfhe = { version = "0.8", features = ["integer", "aarch64-unix"] }
tfhe = { git = "https://github.com/zama-ai/tfhe-rs.git", rev = "483a4fe", features = ["integer", "aarch64-unix"] }

[target.x86_64-pc-windows-gnu.dependencies]
tfhe = { version = "0.8", features = ["integer", "x86_64"] }
tfhe = { git = "https://github.com/zama-ai/tfhe-rs.git", rev = "483a4fe", features = ["integer", "x86_64"] }

[features]
default = ["parallel", "std", "csprng"]
Expand Down
11 changes: 4 additions & 7 deletions backends/concrete-cpu/implementation/include/concrete-cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -370,8 +370,8 @@ void concrete_cpu_keyswitch_lwe_ciphertext_u64(uint64_t *ct_out,
size_t output_dimension);

size_t concrete_cpu_lwe_array_to_tfhers_uint8(const uint64_t *lwe_vec_buffer,
uint8_t *fheuint_buffer,
size_t fheuint_buffer_size,
uint8_t *buffer,
size_t buffer_len,
struct TfhersFheIntDescription fheuint_desc);

size_t concrete_cpu_lwe_ciphertext_size_u64(size_t lwe_dimension);
Expand Down Expand Up @@ -415,11 +415,8 @@ size_t concrete_cpu_serialize_lwe_secret_key_u64(const uint64_t *lwe_sk,

size_t concrete_cpu_tfhers_fheint_buffer_size_u64(size_t lwe_size, size_t n_cts);

struct TfhersFheIntDescription concrete_cpu_tfhers_uint8_description(const uint8_t *serialized_data_ptr,
size_t serialized_data_len);

int64_t concrete_cpu_tfhers_uint8_to_lwe_array(const uint8_t *serialized_data_ptr,
size_t serialized_data_len,
int64_t concrete_cpu_tfhers_uint8_to_lwe_array(const uint8_t *buffer,
size_t buffer_len,
uint64_t *lwe_vec_buffer,
struct TfhersFheIntDescription desc);

Expand Down
67 changes: 58 additions & 9 deletions backends/concrete-cpu/implementation/src/c_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ pub mod wop_pbs;
pub mod wop_pbs_simulation;

mod utils {

use serde::de::DeserializeOwned;
use serde::Serialize;
use tfhe::named::Named;
use tfhe::{Unversionize, Versionize};

#[inline]
pub fn nounwind<R>(f: impl FnOnce() -> R) -> R {
struct AbortOnDrop;
Expand All @@ -33,27 +39,70 @@ mod utils {
let _: libc::size_t = 0_usize;
};

pub unsafe fn serialize<T>(value: &T, out_buffer: *mut u8, out_buffer_len: usize) -> usize
where
T: serde::ser::Serialize,
{
let serialized_size: usize = match bincode::serialized_size(value) {
// Serialize a tfhe-rs versionable value into a buffer, returns 0 if any error
// TODO: Better error management
pub unsafe fn safe_serialize<T: Serialize + Versionize + Named>(
value: &T,
buffer: *mut u8,
buffer_len: usize,
) -> usize {
let writer = core::slice::from_raw_parts_mut(buffer, buffer_len);
let size = match tfhe::safe_serialization::safe_serialized_size(value) {
Ok(size) => {
if size > out_buffer_len as u64 {
if size > buffer_len as u64 {
return 0;
}
size as usize
}
Err(_) => return 0,
Err(_e) => return 0,
};

let write_buff: &mut [u8] = core::slice::from_raw_parts_mut(out_buffer, out_buffer_len);
match bincode::serialize_into(&mut write_buff[..], value) {
match tfhe::safe_serialization::safe_serialize(value, writer, buffer_len as u64) {
Ok(_) => size,
Err(_e) => 0,
}
}

// Deserialize a tfhe-rs versionable value from a buffer, panic if any error
// TODO: Better error management
pub unsafe fn safe_deserialize<T: DeserializeOwned + Unversionize + Named>(
buffer: *const u8,
buffer_len: usize,
) -> T {
let reader = core::slice::from_raw_parts(buffer, buffer_len);
// TODO: Fix approximation when is fixed in TFHE-rs
tfhe::safe_serialization::safe_deserialize(reader, (buffer_len + 1000) as u64).unwrap()
}

// Serialize a tfhe-rs NON-versionable value into a buffer, returns 0 if any error.
// TODO: Remove me when safe_serialization by thfe-rs is implemented for all object.
pub unsafe fn unsafe_serialize<T: Serialize>(
value: &T,
buffer: *mut u8,
buffer_len: usize,
) -> usize {
let serialized_size: usize = match bincode::serialized_size(value) {
Ok(size) if size <= buffer_len as u64 => size as usize,
_ => return 0,
};

let writer: &mut [u8] = core::slice::from_raw_parts_mut(buffer, buffer_len);
match bincode::serialize_into(&mut writer[..], value) {
Ok(_) => serialized_size,
Err(_) => 0,
}
}

// Deserialize a tfhe-rs NON-versionable value into a buffer, panic if any error
// TODO: Remove me when safe_serialization by thfe-rs is implemented for all object.
pub unsafe fn unsafe_deserialize<T: DeserializeOwned>(
buffer: *const u8,
buffer_len: usize,
) -> T {
let reader = core::slice::from_raw_parts(buffer, buffer_len);
bincode::deserialize_from(reader).unwrap()
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
55 changes: 11 additions & 44 deletions backends/concrete-cpu/implementation/src/c_api/fheint.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use super::utils::nounwind;
use core::slice;
use std::io::Cursor;
use tfhe::core_crypto::prelude::*;
use tfhe::integer::ciphertext::Expandable;
use tfhe::integer::IntegerCiphertext;
Expand Down Expand Up @@ -128,49 +127,17 @@ pub fn tfhers_uint8_description(fheuint: FheUint8) -> TfhersFheIntDescription {
}
}

#[no_mangle]
pub unsafe extern "C" fn concrete_cpu_tfhers_uint8_description(
serialized_data_ptr: *const u8,
serialized_data_len: usize,
) -> TfhersFheIntDescription {
nounwind(|| {
// deserialize fheuint8
let mut serialized_data = Cursor::new(slice::from_raw_parts(
serialized_data_ptr,
serialized_data_len,
));
let fheuint: FheUint8 = match bincode::deserialize_from(&mut serialized_data) {
Ok(value) => value,
Err(_) => {
return TfhersFheIntDescription::zero();
}
};
tfhers_uint8_description(fheuint)
})
}

#[no_mangle]
pub unsafe extern "C" fn concrete_cpu_tfhers_uint8_to_lwe_array(
serialized_data_ptr: *const u8,
serialized_data_len: usize,
buffer: *const u8,
buffer_len: usize,
lwe_vec_buffer: *mut u64,
desc: TfhersFheIntDescription,
) -> i64 {
nounwind(|| {
// deserialize fheuint8
let mut serialized_data = Cursor::new(slice::from_raw_parts(
serialized_data_ptr,
serialized_data_len,
));
// TODO: can we have a generic deserialize?
let fheuint: FheUint8 = match bincode::deserialize_from(&mut serialized_data) {
Ok(value) => value,
Err(_) => {
return 1;
}
};
let fheuint: FheUint8 = super::utils::safe_deserialize(buffer, buffer_len);
// TODO - Use conformance check
let fheuint_desc = tfhers_uint8_description(fheuint.clone());

if !fheuint_desc.is_similar(&desc) {
return 1;
}
Expand Down Expand Up @@ -199,19 +166,20 @@ pub extern "C" fn concrete_cpu_tfhers_fheint_buffer_size_u64(
lwe_size: usize,
n_cts: usize,
) -> usize {
// TODO - that is fragile
// all FheUint should have the same size, but we use a big one to be safe
let meta_fheuint = core::mem::size_of::<FheUint128>();
let meta_ct = core::mem::size_of::<Ciphertext>();

// FheUint[metadata, ciphertexts[ciphertext[metadata, lwe_buffer] * n_cts]]
meta_fheuint + (meta_ct + lwe_size * 8/*u64*/) * n_cts
// FheUint[metadata, ciphertexts[ciphertext[metadata, lwe_buffer] * n_cts]] + headers
(meta_fheuint + (meta_ct + lwe_size * 8/*u64*/) * n_cts) + 201
}

#[no_mangle]
pub unsafe extern "C" fn concrete_cpu_lwe_array_to_tfhers_uint8(
lwe_vec_buffer: *const u64,
fheuint_buffer: *mut u8,
fheuint_buffer_size: usize,
buffer: *mut u8,
buffer_len: usize,
fheuint_desc: TfhersFheIntDescription,
) -> usize {
nounwind(|| {
Expand Down Expand Up @@ -240,11 +208,10 @@ pub unsafe extern "C" fn concrete_cpu_lwe_array_to_tfhers_uint8(
}
let fheuint = match FheUint8::from_expanded_blocks(blocks, fheuint_desc.data_kind()) {
Ok(value) => value,
Err(_) => {
Err(_e) => {
return 0;
}
};

super::utils::serialize(&fheuint, fheuint_buffer, fheuint_buffer_size)
super::utils::safe_serialize(&fheuint, buffer, buffer_len)
})
}
10 changes: 4 additions & 6 deletions backends/concrete-cpu/implementation/src/c_api/secret_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ pub unsafe extern "C" fn concrete_cpu_serialize_lwe_secret_key_u64(
concrete_cpu_lwe_secret_key_size_u64(lwe_dimension),
));

super::utils::serialize(&lwe_sk, out_buffer, out_buffer_len)
super::utils::unsafe_serialize(&lwe_sk, out_buffer, out_buffer_len)
}

#[no_mangle]
Expand All @@ -233,8 +233,7 @@ pub unsafe extern "C" fn concrete_cpu_unserialize_lwe_secret_key_u64(
lwe_sk: *mut u64,
lwe_sk_size: usize,
) -> usize {
let serialized_data = slice::from_raw_parts(buffer, buffer_len);
let sk: LweSecretKey<Vec<u64>> = bincode::deserialize_from(serialized_data).unwrap();
let sk: LweSecretKey<Vec<u64>> = super::utils::unsafe_deserialize(buffer, buffer_len);
let container = sk.into_container();
assert!(container.len() <= lwe_sk_size);
let lwe_sk_slice = slice::from_raw_parts_mut(lwe_sk, lwe_sk_size);
Expand All @@ -258,7 +257,7 @@ pub unsafe extern "C" fn concrete_cpu_serialize_glwe_secret_key_u64(
PolynomialSize(polynomial_size),
);

super::utils::serialize(&glwe_sk, out_buffer, out_buffer_len)
super::utils::unsafe_serialize(&glwe_sk, out_buffer, out_buffer_len)
}

#[no_mangle]
Expand All @@ -268,8 +267,7 @@ pub unsafe extern "C" fn concrete_cpu_unserialize_glwe_secret_key_u64(
glwe_sk: *mut u64,
glwe_sk_size: usize,
) -> usize {
let serialized_data = slice::from_raw_parts(buffer, buffer_len);
let sk: GlweSecretKey<Vec<u64>> = bincode::deserialize_from(serialized_data).unwrap();
let sk: GlweSecretKey<Vec<u64>> = super::utils::unsafe_deserialize(buffer, buffer_len);
assert!(sk.glwe_dimension().0 == 1);
let container = sk.into_container();
assert!(container.len() <= glwe_sk_size);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@ importTfhersFheUint8(llvm::ArrayRef<uint8_t> serializedFheUint8,
double encryptionVariance);
Result<std::vector<uint8_t>> exportTfhersFheUint8(TransportValue value,
TfhersFheIntDescription info);
Result<TfhersFheIntDescription>
getTfhersFheUint8Description(llvm::ArrayRef<uint8_t> serializedFheUint8);

class ClientCircuit {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1905,17 +1905,4 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
}
return result.value();
});

m.def("get_tfhers_fheuint8_description",
[](const pybind11::bytes &serialized_fheuint) {
const std::string &buffer_str = serialized_fheuint;
std::vector<uint8_t> buffer(buffer_str.begin(), buffer_str.end());
auto arrayRef = llvm::ArrayRef<uint8_t>(buffer);
auto info =
::concretelang::clientlib::getTfhersFheUint8Description(arrayRef);
if (info.has_error()) {
throw std::runtime_error(info.error().mesg);
}
return info.value();
});
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from mlir._mlir_libs._concretelang._compiler import (
import_tfhers_fheuint8 as _import_tfhers_fheuint8,
export_tfhers_fheuint8 as _export_tfhers_fheuint8,
get_tfhers_fheuint8_description as _get_tfhers_fheuint8_description,
TfhersFheIntDescription as _TfhersFheIntDescription,
TransportValue,
)
Expand Down Expand Up @@ -178,23 +177,6 @@ def __str__(self) -> str:
f"ks_first={self.ks_first}>"
)

@staticmethod
def from_serialized_fheuint8(buffer: bytes) -> "TfhersFheIntDescription":
"""Get the description of a serialized TFHErs fheuint8.
Args:
buffer (bytes): serialized fheuint8
Raises:
TypeError: buffer is not of type bytes
Returns:
TfhersFheIntDescription: description of the serialized fheuint8
"""
if not isinstance(buffer, bytes):
raise TypeError(f"buffer must be of type bytes, not {type(buffer)}")
return TfhersFheIntDescription.wrap(_get_tfhers_fheuint8_description(buffer))


class TfhersExporter:
"""A helper class to import and export TFHErs big integers."""
Expand Down
12 changes: 2 additions & 10 deletions compilers/concrete-compiler/compiler/lib/ClientLib/ClientLib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,15 +187,6 @@ Result<ClientCircuit> ClientProgram::getClientCircuit(std::string circuitName) {
"`");
}

Result<TfhersFheIntDescription>
getTfhersFheUint8Description(llvm::ArrayRef<uint8_t> serializedFheUint8) {
auto fheUintDesc = concrete_cpu_tfhers_uint8_description(
serializedFheUint8.data(), serializedFheUint8.size());
if (fheUintDesc.width == 0)
return StringError("couldn't get fheuint info");
return fheUintDesc;
}

Result<TransportValue>
importTfhersFheUint8(llvm::ArrayRef<uint8_t> serializedFheUint8,
TfhersFheIntDescription desc, uint32_t encryptionKeyId,
Expand All @@ -211,7 +202,8 @@ importTfhersFheUint8(llvm::ArrayRef<uint8_t> serializedFheUint8,
serializedFheUint8.data(), serializedFheUint8.size(),
outputTensor.values.data(), desc);
if (err) {
return StringError("couldn't convert fheuint to lwe array");
return StringError("couldn't convert fheuint to lwe array: err()")
<< err << ")";
}

auto value = Value{outputTensor}.intoRawTransportValue();
Expand Down
Loading

0 comments on commit cd8208e

Please sign in to comment.