diff --git a/fhevm-engine/coprocessor/src/tests/scheduling_bench.rs b/fhevm-engine/coprocessor/src/tests/scheduling_bench.rs index 976dee1c..0f7564e8 100644 --- a/fhevm-engine/coprocessor/src/tests/scheduling_bench.rs +++ b/fhevm-engine/coprocessor/src/tests/scheduling_bench.rs @@ -202,7 +202,7 @@ async fn schedule_multi_erc20() -> Result<(), Box> { "10" if i % 4 == 1 => (), // select trxa "90" if i % 4 == 2 => (), // bals - trxa "30" if i % 4 == 3 => (), // bald + trxa - s => assert!(false, "unexpected result: {} for output {i}", s), + s => panic!("unexpected result: {} for output {i}", s), } } Ok(()) diff --git a/fhevm-engine/executor/src/server.rs b/fhevm-engine/executor/src/server.rs index 0fcc9c46..b0f710ec 100644 --- a/fhevm-engine/executor/src/server.rs +++ b/fhevm-engine/executor/src/server.rs @@ -11,7 +11,7 @@ use fhevm_engine_common::{ common::FheOperation, keys::{FhevmKeys, SerializedFhevmKeys}, tfhe_ops::{current_ciphertext_version, perform_fhe_operation, try_expand_ciphertext_list}, - types::{get_ct_type, FhevmError, Handle, SupportedFheCiphertexts, HANDLE_LEN, SCALAR_LEN}, + types::{get_ct_type, FhevmError, Handle, SupportedFheCiphertexts, HANDLE_LEN}, }; use sha3::{Digest, Keccak256}; use std::{cell::Cell, collections::HashMap}; @@ -264,10 +264,9 @@ impl FhevmExecutorService { let ct = state.ciphertexts.get(h).ok_or(FhevmError::BadInputs)?; Ok(ct.expanded.clone()) } - Input::Scalar(s) if s.len() == SCALAR_LEN => { + Input::Scalar(s) => { Ok(SupportedFheCiphertexts::Scalar(s.clone())) } - _ => Err(FhevmError::BadInputs.into()), }, None => Err(FhevmError::BadInputs.into()), }) @@ -305,7 +304,7 @@ pub fn build_taskgraph_from_request( let mut produced_handles: HashMap<&Handle, usize> = HashMap::new(); // Add all computations as nodes in the graph. for computation in &req.computations { - let inputs: Result> = computation + let inputs = computation .inputs .iter() .map(|input| match &input.input { @@ -317,29 +316,29 @@ pub fn build_taskgraph_from_request( Ok(DFGTaskInput::Dependence(None)) } } - Input::Scalar(s) if s.len() == SCALAR_LEN => Ok(DFGTaskInput::Value( + Input::Scalar(s) => Ok(DFGTaskInput::Value( SupportedFheCiphertexts::Scalar(s.clone()), )), - _ => Err(FhevmError::BadInputs.into()), }, - None => Err(FhevmError::BadInputs.into()), + None => Err(SyncComputeError::BadInputs), }) - .collect(); - if let Ok(mut inputs) = inputs { - let res_handle = computation - .result_handles - .first() - .filter(|h| h.len() == HANDLE_LEN) - .ok_or(SyncComputeError::BadResultHandles)?; - let n = dfg - .add_node( - res_handle.clone(), - computation.operation, - std::mem::take(&mut inputs), - ) - .or_else(|_| Err(SyncComputeError::ComputationFailed))?; - produced_handles.insert(res_handle, n.index()); - } + .collect::, SyncComputeError>>(); + + let mut inputs = inputs?; + + let res_handle = computation + .result_handles + .first() + .filter(|h| h.len() == HANDLE_LEN) + .ok_or(SyncComputeError::BadResultHandles)?; + let n = dfg + .add_node( + res_handle.clone(), + computation.operation, + std::mem::take(&mut inputs), + ) + .or_else(|_| Err(SyncComputeError::ComputationFailed))?; + produced_handles.insert(res_handle, n.index()); } // Traverse computations and add dependences/edges as required for (index, computation) in req.computations.iter().enumerate() { diff --git a/fhevm-engine/executor/tests/scheduling_mapping.rs b/fhevm-engine/executor/tests/scheduling_mapping.rs index 48f1169b..bbb1ee18 100644 --- a/fhevm-engine/executor/tests/scheduling_mapping.rs +++ b/fhevm-engine/executor/tests/scheduling_mapping.rs @@ -178,6 +178,6 @@ async fn schedule_multi_erc20() { } } } - Resp::Error(e) => assert!(false, "error response: {}", e), + Resp::Error(e) => panic!("error response: {}", e), } } diff --git a/fhevm-engine/executor/tests/scheduling_patterns.rs b/fhevm-engine/executor/tests/scheduling_patterns.rs index 918cc52e..a484859d 100644 --- a/fhevm-engine/executor/tests/scheduling_patterns.rs +++ b/fhevm-engine/executor/tests/scheduling_patterns.rs @@ -245,11 +245,11 @@ async fn schedule_dependent_computations() { ), } } - _ => assert!(false, "unexpected handle 0x{:x}", ct.handle[0]), + _ => panic!("unexpected handle 0x{:x}", ct.handle[0]), } } } - Resp::Error(e) => assert!(false, "error response: {}", e), + Resp::Error(e) => panic!("error response: {}", e), } } @@ -539,11 +539,11 @@ async fn schedule_y_patterns() { ), } } - _ => assert!(false, "unexpected handle 0x{:x}", ct.handle[0]), + _ => panic!("unexpected handle 0x{:x}", ct.handle[0]), } } } - Resp::Error(e) => assert!(false, "error response: {}", e), + Resp::Error(e) => panic!("error response: {}", e), } } @@ -831,10 +831,10 @@ async fn schedule_diamond_reduction_dependence_pattern() { ), } } - _ => assert!(false, "unexpected handle 0x{:x}", ct.handle[0]), + _ => panic!("unexpected handle 0x{:x}", ct.handle[0]), } } } - Resp::Error(e) => assert!(false, "error response: {}", e), + Resp::Error(e) => panic!("error response: {}", e), } } diff --git a/fhevm-engine/executor/tests/sync_compute.rs b/fhevm-engine/executor/tests/sync_compute.rs index ab089ad6..250219fc 100644 --- a/fhevm-engine/executor/tests/sync_compute.rs +++ b/fhevm-engine/executor/tests/sync_compute.rs @@ -50,12 +50,12 @@ async fn get_input_ciphertext() { Resp::ResultCiphertexts(cts) => match (cts.ciphertexts.first(), cts.ciphertexts.len()) { (Some(ct), 1) => { if ct.handle != input_handle || ct.serialization.is_empty() { - assert!(false, "response handle or ciphertext are unexpected"); + panic!("response handle or ciphertext are unexpected"); } } - _ => assert!(false, "no response"), + _ => panic!("no response"), }, - Resp::Error(e) => assert!(false, "error: {}", e), + Resp::Error(e) => panic!("error: {}", e), } } @@ -111,7 +111,7 @@ async fn compute_on_two_serialized_ciphertexts() { Resp::ResultCiphertexts(cts) => match (cts.ciphertexts.first(), cts.ciphertexts.len()) { (Some(ct), 1) => { if ct.handle != vec![0xaa; HANDLE_LEN] { - assert!(false, "response handle is unexpected"); + panic!("response handle is unexpected"); } let ct = SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); match ct @@ -119,12 +119,12 @@ async fn compute_on_two_serialized_ciphertexts() { .as_str() { "21" => (), - s => assert!(false, "unexpected result: {}", s), + s => panic!("unexpected result: {}", s), } } - _ => assert!(false, "unexpected amount of result ciphertexts returned"), + _ => panic!("unexpected amount of result ciphertexts returned"), }, - Resp::Error(e) => assert!(false, "error response: {}", e), + Resp::Error(e) => panic!("error response: {}", e), } } @@ -178,7 +178,7 @@ async fn compute_on_compact_and_serialized_ciphertexts() { Resp::ResultCiphertexts(cts) => match (cts.ciphertexts.first(), cts.ciphertexts.len()) { (Some(ct), 1) => { if ct.handle != vec![0xaa; HANDLE_LEN] { - assert!(false, "response handle is unexpected"); + panic!("response handle is unexpected"); } let ct = SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); match ct @@ -186,12 +186,12 @@ async fn compute_on_compact_and_serialized_ciphertexts() { .as_str() { "21" => (), - s => assert!(false, "unexpected result: {}", s), + s => panic!("unexpected result: {}", s), } } - _ => assert!(false, "unexpected amount of result ciphertexts returned"), + _ => panic!("unexpected amount of result ciphertexts returned"), }, - Resp::Error(e) => assert!(false, "error response: {}", e), + Resp::Error(e) => panic!("error response: {}", e), } } @@ -256,7 +256,7 @@ async fn compute_on_result_ciphertext() { Resp::ResultCiphertexts(cts) => match (cts.ciphertexts.get(1), cts.ciphertexts.len()) { (Some(ct), 2) => { if ct.handle != vec![0xbb; HANDLE_LEN] { - assert!(false, "response handle is unexpected"); + panic!("response handle is unexpected"); } let ct = SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); match ct @@ -264,11 +264,61 @@ async fn compute_on_result_ciphertext() { .as_str() { "32" => (), - s => assert!(false, "unexpected result: {}", s), + s => panic!("unexpected result: {}", s), } } - _ => assert!(false, "unexpected amount of result ciphertexts returned"), + _ => panic!("unexpected amount of result ciphertexts returned"), }, - Resp::Error(e) => assert!(false, "error response: {}", e), + Resp::Error(e) => panic!("error response: {}", e), } } + +#[tokio::test] +async fn trivial_encryption_scalar_less_than_32_bytes() { + let test = get_test().await; + test.keys.set_server_key_for_current_thread(); + let mut client = FhevmExecutorClient::connect(test.server_addr.clone()) + .await + .unwrap(); + // 10 big endian + let mut triv_encrypt_input = vec![0; 31]; + triv_encrypt_input.push(10); + let sync_input1 = SyncInput { + input: Some(Input::Scalar(triv_encrypt_input)), + }; + let sync_input2 = SyncInput { + input: Some(Input::Scalar(vec![3])), + }; + let computation = SyncComputation { + operation: FheOperation::FheTrivialEncrypt.into(), + result_handles: vec![vec![0xaa; HANDLE_LEN]], + inputs: vec![sync_input1, sync_input2], + }; + let req = SyncComputeRequest { + computations: vec![computation], + compact_ciphertext_lists: vec![], + compressed_ciphertexts: vec![], + }; + let response = client.sync_compute(req).await.unwrap(); + let sync_compute_response = response.get_ref(); + let resp = sync_compute_response.resp.clone().unwrap(); + match resp { + Resp::ResultCiphertexts(cts) => match (cts.ciphertexts.first(), cts.ciphertexts.len()) { + (Some(ct), 1) => { + if ct.handle != vec![0xaa; HANDLE_LEN] { + panic!("response handle is unexpected: {:?}", ct.handle); + } + let ct = SupportedFheCiphertexts::decompress(3, &ct.serialization).unwrap(); + match ct + .decrypt(&test.as_ref().keys.client_key.clone().unwrap()) + .as_str() + { + "10" => (), + s => panic!("unexpected result: {}", s), + } + } + _ => panic!("unexpected amount of result ciphertexts returned: {}", cts.ciphertexts.len()), + }, + Resp::Error(e) => panic!("error response: {}", e), + } +} \ No newline at end of file diff --git a/fhevm-engine/fhevm-engine-common/src/types.rs b/fhevm-engine/fhevm-engine-common/src/types.rs index b23f2096..093e2d6d 100644 --- a/fhevm-engine/fhevm-engine-common/src/types.rs +++ b/fhevm-engine/fhevm-engine-common/src/types.rs @@ -744,7 +744,6 @@ impl From for i16 { pub type Handle = Vec; pub const HANDLE_LEN: usize = 32; -pub const SCALAR_LEN: usize = 32; pub fn get_ct_type(handle: &[u8]) -> Result { match handle.len() {