Skip to content

Commit

Permalink
Merge pull request #146 from zama-ai/davidk/fix-executor-scalar-support
Browse files Browse the repository at this point in the history
fix: adds arbitrary byte size scalar support
  • Loading branch information
david-zk authored Nov 28, 2024
2 parents a630534 + e9b6177 commit 4b7f71c
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 47 deletions.
2 changes: 1 addition & 1 deletion fhevm-engine/coprocessor/src/tests/scheduling_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ async fn schedule_multi_erc20() -> Result<(), Box<dyn std::error::Error>> {
"10" if i % 4 == 1 => (), // select trxa
"90" if i % 4 == 2 => (), // bals - trxa
"30" if i % 4 == 3 => (), // bald + trxa
s => assert!(false, "unexpected result: {} for output {i}", s),
s => panic!("unexpected result: {} for output {i}", s),
}
}
Ok(())
Expand Down
45 changes: 22 additions & 23 deletions fhevm-engine/executor/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use fhevm_engine_common::{
common::FheOperation,
keys::{FhevmKeys, SerializedFhevmKeys},
tfhe_ops::{current_ciphertext_version, perform_fhe_operation, try_expand_ciphertext_list},
types::{get_ct_type, FhevmError, Handle, SupportedFheCiphertexts, HANDLE_LEN, SCALAR_LEN},
types::{get_ct_type, FhevmError, Handle, SupportedFheCiphertexts, HANDLE_LEN},
};
use sha3::{Digest, Keccak256};
use std::{cell::Cell, collections::HashMap};
Expand Down Expand Up @@ -264,10 +264,9 @@ impl FhevmExecutorService {
let ct = state.ciphertexts.get(h).ok_or(FhevmError::BadInputs)?;
Ok(ct.expanded.clone())
}
Input::Scalar(s) if s.len() == SCALAR_LEN => {
Input::Scalar(s) => {
Ok(SupportedFheCiphertexts::Scalar(s.clone()))
}
_ => Err(FhevmError::BadInputs.into()),
},
None => Err(FhevmError::BadInputs.into()),
})
Expand Down Expand Up @@ -305,7 +304,7 @@ pub fn build_taskgraph_from_request(
let mut produced_handles: HashMap<&Handle, usize> = HashMap::new();
// Add all computations as nodes in the graph.
for computation in &req.computations {
let inputs: Result<Vec<DFGTaskInput>> = computation
let inputs = computation
.inputs
.iter()
.map(|input| match &input.input {
Expand All @@ -317,29 +316,29 @@ pub fn build_taskgraph_from_request(
Ok(DFGTaskInput::Dependence(None))
}
}
Input::Scalar(s) if s.len() == SCALAR_LEN => Ok(DFGTaskInput::Value(
Input::Scalar(s) => Ok(DFGTaskInput::Value(
SupportedFheCiphertexts::Scalar(s.clone()),
)),
_ => Err(FhevmError::BadInputs.into()),
},
None => Err(FhevmError::BadInputs.into()),
None => Err(SyncComputeError::BadInputs),
})
.collect();
if let Ok(mut inputs) = inputs {
let res_handle = computation
.result_handles
.first()
.filter(|h| h.len() == HANDLE_LEN)
.ok_or(SyncComputeError::BadResultHandles)?;
let n = dfg
.add_node(
res_handle.clone(),
computation.operation,
std::mem::take(&mut inputs),
)
.or_else(|_| Err(SyncComputeError::ComputationFailed))?;
produced_handles.insert(res_handle, n.index());
}
.collect::<Result<Vec<DFGTaskInput>, SyncComputeError>>();

let mut inputs = inputs?;

let res_handle = computation
.result_handles
.first()
.filter(|h| h.len() == HANDLE_LEN)
.ok_or(SyncComputeError::BadResultHandles)?;
let n = dfg
.add_node(
res_handle.clone(),
computation.operation,
std::mem::take(&mut inputs),
)
.or_else(|_| Err(SyncComputeError::ComputationFailed))?;
produced_handles.insert(res_handle, n.index());
}
// Traverse computations and add dependences/edges as required
for (index, computation) in req.computations.iter().enumerate() {
Expand Down
2 changes: 1 addition & 1 deletion fhevm-engine/executor/tests/scheduling_mapping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,6 @@ async fn schedule_multi_erc20() {
}
}
}
Resp::Error(e) => assert!(false, "error response: {}", e),
Resp::Error(e) => panic!("error response: {}", e),
}
}
12 changes: 6 additions & 6 deletions fhevm-engine/executor/tests/scheduling_patterns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,11 +245,11 @@ async fn schedule_dependent_computations() {
),
}
}
_ => assert!(false, "unexpected handle 0x{:x}", ct.handle[0]),
_ => panic!("unexpected handle 0x{:x}", ct.handle[0]),
}
}
}
Resp::Error(e) => assert!(false, "error response: {}", e),
Resp::Error(e) => panic!("error response: {}", e),
}
}

Expand Down Expand Up @@ -539,11 +539,11 @@ async fn schedule_y_patterns() {
),
}
}
_ => assert!(false, "unexpected handle 0x{:x}", ct.handle[0]),
_ => panic!("unexpected handle 0x{:x}", ct.handle[0]),
}
}
}
Resp::Error(e) => assert!(false, "error response: {}", e),
Resp::Error(e) => panic!("error response: {}", e),
}
}

Expand Down Expand Up @@ -831,10 +831,10 @@ async fn schedule_diamond_reduction_dependence_pattern() {
),
}
}
_ => assert!(false, "unexpected handle 0x{:x}", ct.handle[0]),
_ => panic!("unexpected handle 0x{:x}", ct.handle[0]),
}
}
}
Resp::Error(e) => assert!(false, "error response: {}", e),
Resp::Error(e) => panic!("error response: {}", e),
}
}
80 changes: 65 additions & 15 deletions fhevm-engine/executor/tests/sync_compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,12 @@ async fn get_input_ciphertext() {
Resp::ResultCiphertexts(cts) => match (cts.ciphertexts.first(), cts.ciphertexts.len()) {
(Some(ct), 1) => {
if ct.handle != input_handle || ct.serialization.is_empty() {
assert!(false, "response handle or ciphertext are unexpected");
panic!("response handle or ciphertext are unexpected");
}
}
_ => assert!(false, "no response"),
_ => panic!("no response"),
},
Resp::Error(e) => assert!(false, "error: {}", e),
Resp::Error(e) => panic!("error: {}", e),
}
}

Expand Down Expand Up @@ -111,20 +111,20 @@ async fn compute_on_two_serialized_ciphertexts() {
Resp::ResultCiphertexts(cts) => match (cts.ciphertexts.first(), cts.ciphertexts.len()) {
(Some(ct), 1) => {
if ct.handle != vec![0xaa; HANDLE_LEN] {
assert!(false, "response handle is unexpected");
panic!("response handle is unexpected");
}
let ct = SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap();
match ct
.decrypt(&test.as_ref().keys.client_key.clone().unwrap())
.as_str()
{
"21" => (),
s => assert!(false, "unexpected result: {}", s),
s => panic!("unexpected result: {}", s),
}
}
_ => assert!(false, "unexpected amount of result ciphertexts returned"),
_ => panic!("unexpected amount of result ciphertexts returned"),
},
Resp::Error(e) => assert!(false, "error response: {}", e),
Resp::Error(e) => panic!("error response: {}", e),
}
}

Expand Down Expand Up @@ -178,20 +178,20 @@ async fn compute_on_compact_and_serialized_ciphertexts() {
Resp::ResultCiphertexts(cts) => match (cts.ciphertexts.first(), cts.ciphertexts.len()) {
(Some(ct), 1) => {
if ct.handle != vec![0xaa; HANDLE_LEN] {
assert!(false, "response handle is unexpected");
panic!("response handle is unexpected");
}
let ct = SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap();
match ct
.decrypt(&test.as_ref().keys.client_key.clone().unwrap())
.as_str()
{
"21" => (),
s => assert!(false, "unexpected result: {}", s),
s => panic!("unexpected result: {}", s),
}
}
_ => assert!(false, "unexpected amount of result ciphertexts returned"),
_ => panic!("unexpected amount of result ciphertexts returned"),
},
Resp::Error(e) => assert!(false, "error response: {}", e),
Resp::Error(e) => panic!("error response: {}", e),
}
}

Expand Down Expand Up @@ -256,19 +256,69 @@ async fn compute_on_result_ciphertext() {
Resp::ResultCiphertexts(cts) => match (cts.ciphertexts.get(1), cts.ciphertexts.len()) {
(Some(ct), 2) => {
if ct.handle != vec![0xbb; HANDLE_LEN] {
assert!(false, "response handle is unexpected");
panic!("response handle is unexpected");
}
let ct = SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap();
match ct
.decrypt(&test.as_ref().keys.client_key.clone().unwrap())
.as_str()
{
"32" => (),
s => assert!(false, "unexpected result: {}", s),
s => panic!("unexpected result: {}", s),
}
}
_ => assert!(false, "unexpected amount of result ciphertexts returned"),
_ => panic!("unexpected amount of result ciphertexts returned"),
},
Resp::Error(e) => assert!(false, "error response: {}", e),
Resp::Error(e) => panic!("error response: {}", e),
}
}

#[tokio::test]
async fn trivial_encryption_scalar_less_than_32_bytes() {
let test = get_test().await;
test.keys.set_server_key_for_current_thread();
let mut client = FhevmExecutorClient::connect(test.server_addr.clone())
.await
.unwrap();
// 10 big endian
let mut triv_encrypt_input = vec![0; 31];
triv_encrypt_input.push(10);
let sync_input1 = SyncInput {
input: Some(Input::Scalar(triv_encrypt_input)),
};
let sync_input2 = SyncInput {
input: Some(Input::Scalar(vec![3])),
};
let computation = SyncComputation {
operation: FheOperation::FheTrivialEncrypt.into(),
result_handles: vec![vec![0xaa; HANDLE_LEN]],
inputs: vec![sync_input1, sync_input2],
};
let req = SyncComputeRequest {
computations: vec![computation],
compact_ciphertext_lists: vec![],
compressed_ciphertexts: vec![],
};
let response = client.sync_compute(req).await.unwrap();
let sync_compute_response = response.get_ref();
let resp = sync_compute_response.resp.clone().unwrap();
match resp {
Resp::ResultCiphertexts(cts) => match (cts.ciphertexts.first(), cts.ciphertexts.len()) {
(Some(ct), 1) => {
if ct.handle != vec![0xaa; HANDLE_LEN] {
panic!("response handle is unexpected: {:?}", ct.handle);
}
let ct = SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap();
match ct
.decrypt(&test.as_ref().keys.client_key.clone().unwrap())
.as_str()
{
"10" => (),
s => panic!("unexpected result: {}", s),
}
}
_ => panic!("unexpected amount of result ciphertexts returned: {}", cts.ciphertexts.len()),
},
Resp::Error(e) => panic!("error response: {}", e),
}
}
1 change: 0 additions & 1 deletion fhevm-engine/fhevm-engine-common/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,6 @@ impl From<SupportedFheOperations> for i16 {

pub type Handle = Vec<u8>;
pub const HANDLE_LEN: usize = 32;
pub const SCALAR_LEN: usize = 32;

pub fn get_ct_type(handle: &[u8]) -> Result<i16, FhevmError> {
match handle.len() {
Expand Down

0 comments on commit 4b7f71c

Please sign in to comment.