From db74ffd73e9ef0923052f3f1a373a58db16ed311 Mon Sep 17 00:00:00 2001 From: David Kazlauskas Date: Thu, 8 Aug 2024 09:54:01 +0300 Subject: [PATCH] Binary operators test harness --- fhevm-engine/Cargo.lock | 47 +++++++ fhevm-engine/Cargo.toml | 2 + fhevm-engine/src/server.rs | 78 +++++++---- fhevm-engine/src/tests/mod.rs | 68 +++++----- fhevm-engine/src/tests/operators.rs | 203 ++++++++++++++++++++++++++++ fhevm-engine/src/tests/utils.rs | 12 +- fhevm-engine/src/tfhe_ops.rs | 172 ++++++++++++++++++++++- fhevm-engine/src/types.rs | 23 +++- proto/coprocessor.proto | 16 ++- 9 files changed, 552 insertions(+), 69 deletions(-) create mode 100644 fhevm-engine/src/tests/operators.rs diff --git a/fhevm-engine/Cargo.lock b/fhevm-engine/Cargo.lock index 2593a93f..ac16df4c 100644 --- a/fhevm-engine/Cargo.lock +++ b/fhevm-engine/Cargo.lock @@ -296,6 +296,19 @@ version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" +[[package]] +name = "bigdecimal" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d712318a27c7150326677b321a5fa91b55f6d9034ffd67f20319e147d40cee" +dependencies = [ + "autocfg", + "libm", + "num-bigint", + "num-integer", + "num-traits", +] + [[package]] name = "bincode" version = "1.3.3" @@ -839,6 +852,7 @@ checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a" name = "fhevm-engine" version = "0.1.0" dependencies = [ + "bigdecimal", "bincode", "clap", "hex", @@ -848,6 +862,7 @@ dependencies = [ "regex", "serde_json", "sqlx", + "strum", "testcontainers", "tfhe", "tokio", @@ -1569,6 +1584,16 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + [[package]] name = "num-bigint-dig" version = "0.8.4" @@ -2661,6 +2686,28 @@ dependencies = [ "syn 2.0.72", ] +[[package]] +name = "strum" +version = "0.26.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" +dependencies = [ + "strum_macros", +] + +[[package]] +name = "strum_macros" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" +dependencies = [ + "heck 0.5.0", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.72", +] + [[package]] name = "subtle" version = "2.6.1" diff --git a/fhevm-engine/Cargo.toml b/fhevm-engine/Cargo.toml index 6e889f71..211aa64d 100644 --- a/fhevm-engine/Cargo.toml +++ b/fhevm-engine/Cargo.toml @@ -22,6 +22,8 @@ clap = { version = "4.5", features = ["derive"] } lru = "0.12.3" bincode = "1.3.3" hex = "0.4" +strum = { version = "0.26", features = ["derive"] } +bigdecimal = "0.4" [dev-dependencies] testcontainers = "0.21" diff --git a/fhevm-engine/src/server.rs b/fhevm-engine/src/server.rs index ac4b06ba..42fce0b1 100644 --- a/fhevm-engine/src/server.rs +++ b/fhevm-engine/src/server.rs @@ -1,12 +1,10 @@ -use coprocessor::DebugDecryptResponse; -use sqlx::query; -use tfhe::prelude::FheTryTrivialEncrypt; -use tfhe::FheUint32; +use coprocessor::{DebugDecryptResponse, DebugDecryptResponseSingle}; +use sqlx::{query, Acquire}; use tonic::transport::Server; use crate::db_queries::{check_if_api_key_is_valid, check_if_ciphertexts_exist_in_db}; use crate::utils::sort_computations_by_dependencies; -use crate::types::{CoprocessorError, SupportedFheCiphertexts}; -use crate::tfhe_ops::{self, check_fhe_operand_types, current_ciphertext_version}; +use crate::types::CoprocessorError; +use crate::tfhe_ops::{self, check_fhe_operand_types, current_ciphertext_version, debug_trivial_encrypt_le_bytes}; use crate::server::coprocessor::GenericResponse; pub mod coprocessor { @@ -63,20 +61,40 @@ impl coprocessor::fhevm_coprocessor_server::FhevmCoprocessor for CoprocessorServ let public_key = public_key.pop().unwrap(); - let value_to_encrypt = req.original_value as u32; - let handle = req.handle.clone(); - let (db_type, db_bytes) = tokio::task::spawn_blocking(move || { + let cloned = req.values.clone(); + let out_cts = tokio::task::spawn_blocking(move || { let server_key: tfhe::ServerKey = bincode::deserialize(&public_key.sks_key).unwrap(); tfhe::set_server_key(server_key); - let encrypted = FheUint32::try_encrypt_trivial(value_to_encrypt).unwrap(); - SupportedFheCiphertexts::FheUint32(encrypted).serialize() + + // single threaded implementation as this is debug function and it is simple to implement + let mut res: Vec<(String, i16, Vec)> = Vec::with_capacity(cloned.len()); + for v in cloned { + let ct = debug_trivial_encrypt_le_bytes(v.output_type as i16, &v.le_value); + let (ct_type, ct_bytes) = ct.serialize(); + res.push(( + v.handle, + ct_type, + ct_bytes + )); + } + + res }).await.unwrap(); - sqlx::query!(" - INSERT INTO ciphertexts(tenant_id, handle, ciphertext, ciphertext_version, ciphertext_type) - VALUES ($1, $2, $3, $4, $5) - ", tenant_id, handle, db_bytes, current_ciphertext_version(), db_type) - .execute(&self.pool).await.map_err(Into::::into)?; + let mut conn = self.pool.acquire().await.map_err(Into::::into)?; + let mut trx = conn.begin().await.map_err(Into::::into)?; + + for (handle, db_type, db_bytes) in out_cts { + sqlx::query!(" + INSERT INTO ciphertexts(tenant_id, handle, ciphertext, ciphertext_version, ciphertext_type) + VALUES ($1, $2, $3, $4, $5) + ", + tenant_id, handle, db_bytes, current_ciphertext_version(), db_type as i16 + ) + .execute(trx.as_mut()).await.map_err(Into::::into)?; + } + + trx.commit().await.map_err(Into::::into)?; return Ok(tonic::Response::new(GenericResponse { response_code: 0 })); } @@ -103,13 +121,13 @@ impl coprocessor::fhevm_coprocessor_server::FhevmCoprocessor for CoprocessorServ assert_eq!(priv_key.len(), 1); - let mut cts = sqlx::query!(" - SELECT ciphertext, ciphertext_type + let cts = sqlx::query!(" + SELECT ciphertext, ciphertext_type, handle FROM ciphertexts WHERE tenant_id = $1 - AND handle = $2 + AND handle = ANY($2::TEXT[]) AND ciphertext_version = $3 - ", tenant_id, &req.handle, current_ciphertext_version()) + ", tenant_id, &req.handles, current_ciphertext_version()) .fetch_all(&self.pool) .await.map_err(Into::::into)?; @@ -117,18 +135,24 @@ impl coprocessor::fhevm_coprocessor_server::FhevmCoprocessor for CoprocessorServ return Err(tonic::Status::not_found("ciphertext not found")); } - assert_eq!(cts.len(), 1); - let priv_key = priv_key.pop().unwrap().cks_key.unwrap(); - let ciphertext = cts.pop().unwrap(); - let value = tokio::task::spawn_blocking(move || { + let values = tokio::task::spawn_blocking(move || { let client_key: tfhe::ClientKey = bincode::deserialize(&priv_key).unwrap(); - let deserialized = tfhe_ops::deserialize_fhe_ciphertext(ciphertext.ciphertext_type, &ciphertext.ciphertext).unwrap(); - deserialized.decrypt(&client_key) + + let mut decrypted: Vec = Vec::with_capacity(cts.len()); + for ct in cts { + let deserialized = tfhe_ops::deserialize_fhe_ciphertext(ct.ciphertext_type, &ct.ciphertext).unwrap(); + decrypted.push(DebugDecryptResponseSingle { + output_type: ct.ciphertext_type as i32, + value: deserialized.decrypt(&client_key), + }); + } + + decrypted }).await.unwrap(); - return Ok(tonic::Response::new(DebugDecryptResponse { value })); + return Ok(tonic::Response::new(DebugDecryptResponse { values })); } async fn upload_ciphertexts( diff --git a/fhevm-engine/src/tests/mod.rs b/fhevm-engine/src/tests/mod.rs index 1e47100d..f8044d96 100644 --- a/fhevm-engine/src/tests/mod.rs +++ b/fhevm-engine/src/tests/mod.rs @@ -1,8 +1,12 @@ +use std::str::FromStr; + use tonic::metadata::MetadataValue; +use utils::default_api_key; use crate::server::coprocessor::fhevm_coprocessor_client::FhevmCoprocessorClient; -use crate::server::coprocessor::{AsyncComputeRequest, FheOperation, DebugEncryptRequest, DebugDecryptRequest, AsyncComputation}; +use crate::server::coprocessor::{AsyncComputation, AsyncComputeRequest, DebugDecryptRequest, DebugEncryptRequest, DebugEncryptRequestSingle, FheOperation}; mod utils; +mod operators; #[tokio::test] async fn test_smoke() -> Result<(), Box> { @@ -10,26 +14,26 @@ async fn test_smoke() -> Result<(), Box> { let mut client = FhevmCoprocessorClient::connect(app.app_url().to_string()).await?; - let api_key = "Bearer a1503fb6-d79b-4e9e-826d-44cf262f3e05"; + let api_key_header = format!("Bearer {}", default_api_key()); + let ct_type = 4; // i32 - // ciphertext A + // encrypt two ciphertexts { let mut encrypt_request = tonic::Request::new(DebugEncryptRequest { - handle: "0x0abc".to_string(), - original_value: 123, - }); - encrypt_request.metadata_mut().append("authorization", MetadataValue::from_static(api_key)); - let resp = client.debug_encrypt_ciphertext(encrypt_request).await?; - println!("encryption request: {:?}", resp); - } - - // ciphertext B - { - let mut encrypt_request = tonic::Request::new(DebugEncryptRequest { - handle: "0x0abd".to_string(), - original_value: 124, + values: vec![ + DebugEncryptRequestSingle { + handle: "0x0abc".to_string(), + le_value: vec![123], + output_type: ct_type, + }, + DebugEncryptRequestSingle { + handle: "0x0abd".to_string(), + le_value: vec![124], + output_type: ct_type, + }, + ], }); - encrypt_request.metadata_mut().append("authorization", MetadataValue::from_static(api_key)); + encrypt_request.metadata_mut().append("authorization", MetadataValue::from_str(&api_key_header).unwrap()); let resp = client.debug_encrypt_ciphertext(encrypt_request).await?; println!("encryption request: {:?}", resp); } @@ -58,7 +62,7 @@ async fn test_smoke() -> Result<(), Box> { }, ] }); - compute_request.metadata_mut().append("authorization", MetadataValue::from_static(api_key)); + compute_request.metadata_mut().append("authorization", MetadataValue::from_str(&api_key_header).unwrap()); let resp = client.async_compute(compute_request).await?; println!("compute request: {:?}", resp); } @@ -66,26 +70,24 @@ async fn test_smoke() -> Result<(), Box> { println!("sleeping for computation to complete..."); tokio::time::sleep(tokio::time::Duration::from_secs(3)).await; - // decrypt first - { - let mut decrypt_request = tonic::Request::new(DebugDecryptRequest { - handle: "0x0abe".to_string() - }); - decrypt_request.metadata_mut().append("authorization", MetadataValue::from_static(api_key)); - let resp = client.debug_decrypt_ciphertext(decrypt_request).await?; - println!("decrypt request: {:?}", resp); - assert_eq!(resp.get_ref().value, "247"); - } - - // decrypt second + // decrypt values { let mut decrypt_request = tonic::Request::new(DebugDecryptRequest { - handle: "0x0abf".to_string() + handles: vec![ + "0x0abe".to_string(), + "0x0abf".to_string(), + ], }); - decrypt_request.metadata_mut().append("authorization", MetadataValue::from_static(api_key)); + decrypt_request.metadata_mut().append("authorization", MetadataValue::from_str(&api_key_header).unwrap()); let resp = client.debug_decrypt_ciphertext(decrypt_request).await?; println!("decrypt request: {:?}", resp); - assert_eq!(resp.get_ref().value, "263"); + assert_eq!(resp.get_ref().values.len(), 2); + // first value + assert_eq!(resp.get_ref().values[0].value, "247"); + assert_eq!(resp.get_ref().values[0].output_type, ct_type); + // second value + assert_eq!(resp.get_ref().values[1].value, "263"); + assert_eq!(resp.get_ref().values[1].output_type, ct_type); } Ok(()) diff --git a/fhevm-engine/src/tests/operators.rs b/fhevm-engine/src/tests/operators.rs new file mode 100644 index 00000000..73aa9d1d --- /dev/null +++ b/fhevm-engine/src/tests/operators.rs @@ -0,0 +1,203 @@ +use bigdecimal::num_bigint::BigInt; +use strum::IntoEnumIterator; +use tonic::metadata::MetadataValue; +use std::str::FromStr; +use crate::{tests::utils::{setup_test_app, default_api_key}, tfhe_ops::{does_fhe_operation_support_both_encrypted_operands, does_fhe_operation_support_scalar}, types::{FheOperationType, SupportedFheOperations}}; +use crate::server::coprocessor::fhevm_coprocessor_client::FhevmCoprocessorClient; +use crate::server::coprocessor::{AsyncComputation, AsyncComputeRequest, DebugDecryptRequest, DebugEncryptRequest, DebugEncryptRequestSingle}; + + +struct BinaryOperatorTestCase { + bits: i32, + operand: i32, + operand_types: i32, + lhs: BigInt, + rhs: BigInt, + expected_output: BigInt, + is_scalar: bool, +} + +struct UnaryOperatorTestCase { + bits: i32, + inp: BigInt, +} + +fn supported_bits() -> &'static [i32] { + &[ + 8, + 16, + 32, + ] +} + +fn supported_bits_to_bit_type_in_db(inp: i32) -> i32 { + match inp { + 8 => 2, + 16 => 3, + 32 => 4, + other => panic!("unknown supported bits: {other}") + } +} + +#[tokio::test] +async fn test_fhe_binary_operands() -> Result<(), Box> { + let ops = generate_binary_test_cases(); + let app = setup_test_app().await?; + let mut client = FhevmCoprocessorClient::connect(app.app_url().to_string()).await?; + // needed for polling status + let pool = sqlx::postgres::PgPoolOptions::new() + .max_connections(2) + .connect(app.db_url()) + .await?; + + let mut handle_counter = 0; + let mut next_handle = || { + let out = handle_counter; + handle_counter += 1; + format!("{:#08x}", out) + }; + + let api_key_header = format!("bearer {}", default_api_key()); + + let mut output_handles = Vec::with_capacity(ops.len()); + for op in &ops { + let lhs_handle = next_handle(); + let rhs_handle = if op.is_scalar { + let (_, bytes) = op.rhs.to_bytes_be(); + format!("0x{}", hex::encode(bytes)) + } else { next_handle() }; + let output_handle = next_handle(); + output_handles.push(output_handle.clone()); + + let (_, lhs_bytes) = op.lhs.to_bytes_le(); + + println!("Encrypting inputs for binary test bits:{} op:{} is_scalar:{} lhs:{} rhs:{}", + op.bits, op.operand, op.is_scalar, op.lhs.to_string(), op.rhs.to_string()); + let mut enc_request_payload = vec![ + DebugEncryptRequestSingle { + handle: lhs_handle.clone(), + le_value: lhs_bytes, + output_type: op.operand_types, + }, + ]; + if !op.is_scalar { + let (_, rhs_bytes) = op.rhs.to_bytes_le(); + enc_request_payload.push(DebugEncryptRequestSingle { + handle: rhs_handle.clone(), + le_value: rhs_bytes, + output_type: op.operand_types, + }); + } + let mut encrypt_request = tonic::Request::new(DebugEncryptRequest { + values: enc_request_payload, + }); + encrypt_request.metadata_mut().append("authorization", MetadataValue::from_str(&api_key_header).unwrap()); + let _resp = client.debug_encrypt_ciphertext(encrypt_request).await?; + + println!("rhs handle:{}", rhs_handle); + println!("Scheduling computation for binary test bits:{} op:{} is_scalar:{} lhs:{} rhs:{} output:{}", + op.bits, op.operand, op.is_scalar, op.lhs.to_string(), op.rhs.to_string(), op.expected_output.to_string()); + let mut compute_request = tonic::Request::new(AsyncComputeRequest { + computations: vec![ + AsyncComputation { + operation: op.operand, + is_scalar: op.is_scalar, + output_handle: output_handle, + input_handles: vec![ + lhs_handle.clone(), + rhs_handle.clone(), + ] + }, + ] + }); + compute_request.metadata_mut().append("authorization", MetadataValue::from_str(&api_key_header).unwrap()); + let _resp = client.async_compute(compute_request).await?; + } + + println!("Computations scheduled, waiting upon completion..."); + + loop { + tokio::time::sleep(tokio::time::Duration::from_secs(3)).await; + let count = + sqlx::query!("SELECT count(*) FROM computations WHERE NOT is_completed AND NOT is_error") + .fetch_one(&pool) + .await?; + let current_count = count.count.unwrap(); + if current_count == 0 { + println!("All computations completed"); + break; + } else { + println!("{current_count} computations remaining, waiting..."); + } + } + + let mut decrypt_request = tonic::Request::new(DebugDecryptRequest { + handles: output_handles.clone(), + }); + decrypt_request.metadata_mut().append("authorization", MetadataValue::from_str(&api_key_header).unwrap()); + let resp = client.debug_decrypt_ciphertext(decrypt_request).await?; + + assert_eq!(resp.get_ref().values.len(), output_handles.len(), "Outputs length doesn't match"); + for (idx, op) in ops.iter().enumerate() { + let decr_response = &resp.get_ref().values[idx]; + println!("Checking computation for binary test bits:{} op:{} is_scalar:{} lhs:{} rhs:{} output:{}", + op.bits, op.operand, op.is_scalar, op.lhs.to_string(), op.rhs.to_string(), op.expected_output.to_string()); + assert_eq!(decr_response.output_type, op.operand_types, "operand types not equal"); + assert_eq!(decr_response.value, op.expected_output.to_string(), "operand output values not equal"); + } + + Ok(()) +} + +fn generate_binary_test_cases() -> Vec { + let mut cases = Vec::new(); + let mut push_case = |bits: i32, is_scalar: bool, shift_by: i32, op: SupportedFheOperations| { + let mut lhs = BigInt::from(12); + let mut rhs = BigInt::from(7); + lhs <<= shift_by; + rhs <<= shift_by; + let expected_output = compute_expected_binary_output(&lhs, &rhs, op); + let operand = op as i32; + cases.push(BinaryOperatorTestCase { + bits, + operand, + operand_types: supported_bits_to_bit_type_in_db(bits), + lhs, + rhs, + expected_output, + is_scalar, + }); + }; + + for bits in supported_bits() { + let bits = *bits; + let mut shift_by = bits - 8; + for op in SupportedFheOperations::iter() { + if op == SupportedFheOperations::FheMul { + // don't go out of bit bounds when multiplying two numbers, so we shift by less + shift_by /= 2; + } + if op.op_type() == FheOperationType::Binary { + if does_fhe_operation_support_both_encrypted_operands(&op) { + push_case(bits, false, shift_by, op); + } + + if does_fhe_operation_support_scalar(&op) { + push_case(bits, true, shift_by, op); + } + } + } + } + + cases +} + +fn compute_expected_binary_output(lhs: &BigInt, rhs: &BigInt, op: SupportedFheOperations) -> BigInt { + match op { + SupportedFheOperations::FheAdd => lhs + rhs, + SupportedFheOperations::FheSub => lhs - rhs, + SupportedFheOperations::FheMul => lhs * rhs, + SupportedFheOperations::FheDiv => lhs / rhs, + other => panic!("unsupported binary operation: {:?}", other), + } +} \ No newline at end of file diff --git a/fhevm-engine/src/tests/utils.rs b/fhevm-engine/src/tests/utils.rs index faf838f5..1fe51d7f 100644 --- a/fhevm-engine/src/tests/utils.rs +++ b/fhevm-engine/src/tests/utils.rs @@ -8,6 +8,7 @@ pub struct TestInstance { // send message to this on destruction to stop the app app_close_channel: tokio::sync::watch::Sender, app_url: String, + db_url: String, } impl Drop for TestInstance { @@ -21,6 +22,14 @@ impl TestInstance { pub fn app_url(&self) -> &str { self.app_url.as_str() } + + pub fn db_url(&self) -> &str { + self.db_url.as_str() + } +} + +pub fn default_api_key() -> &'static str { + "a1503fb6-d79b-4e9e-826d-44cf262f3e05" } pub async fn setup_test_app() -> Result> { @@ -77,7 +86,7 @@ pub async fn setup_test_app() -> Result tokio_threads: 2, pg_pool_max_connections: 2, server_addr: format!("127.0.0.1:{app_port}"), - database_url: Some(db_url), + database_url: Some(db_url.clone()), }; std::thread::spawn(move || { @@ -91,5 +100,6 @@ pub async fn setup_test_app() -> Result _container: container, app_close_channel, app_url: format!("http://127.0.0.1:{app_port}"), + db_url, }) } \ No newline at end of file diff --git a/fhevm-engine/src/tfhe_ops.rs b/fhevm-engine/src/tfhe_ops.rs index 8261fb56..3310f013 100644 --- a/fhevm-engine/src/tfhe_ops.rs +++ b/fhevm-engine/src/tfhe_ops.rs @@ -1,3 +1,5 @@ +use tfhe::{prelude::FheTryTrivialEncrypt, FheBool, FheUint16, FheUint32, FheUint8}; + use crate::types::{CoprocessorError, FheOperationType, SupportedFheCiphertexts, SupportedFheOperations}; pub fn current_ciphertext_version() -> i16 { @@ -15,9 +17,24 @@ pub fn perform_fhe_operation(fhe_operation: i16, input_operands: &[SupportedFheC (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::FheUint8(b)) => { Ok(SupportedFheCiphertexts::FheUint8(a + b)) } + (SupportedFheCiphertexts::FheUint16(a), SupportedFheCiphertexts::FheUint16(b)) => { + Ok(SupportedFheCiphertexts::FheUint16(a + b)) + } (SupportedFheCiphertexts::FheUint32(a), SupportedFheCiphertexts::FheUint32(b)) => { Ok(SupportedFheCiphertexts::FheUint32(a + b)) } + (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::Scalar(b)) => { + // TODO: figure out type to add correctly 256 bit operands from handles + let (l, h) = b.to_low_high_u128(); + assert_eq!(h, 0, "Not supported yet"); + Ok(SupportedFheCiphertexts::FheUint8(a + (l as u8))) + } + (SupportedFheCiphertexts::FheUint16(a), SupportedFheCiphertexts::Scalar(b)) => { + // TODO: figure out type to add correctly 256 bit operands from handles + let (l, h) = b.to_low_high_u128(); + assert_eq!(h, 0, "Not supported yet"); + Ok(SupportedFheCiphertexts::FheUint16(a + (l as u16))) + } (SupportedFheCiphertexts::FheUint32(a), SupportedFheCiphertexts::Scalar(b)) => { // TODO: figure out type to add correctly 256 bit operands from handles let (l, h) = b.to_low_high_u128(); @@ -29,18 +46,162 @@ pub fn perform_fhe_operation(fhe_operation: i16, input_operands: &[SupportedFheC } } } - SupportedFheOperations::FheSub => todo!(), + SupportedFheOperations::FheSub => { + assert_eq!(input_operands.len(), 2); + + match (&input_operands[0], &input_operands[1]) { + (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::FheUint8(b)) => { + Ok(SupportedFheCiphertexts::FheUint8(a - b)) + } + (SupportedFheCiphertexts::FheUint16(a), SupportedFheCiphertexts::FheUint16(b)) => { + Ok(SupportedFheCiphertexts::FheUint16(a - b)) + } + (SupportedFheCiphertexts::FheUint32(a), SupportedFheCiphertexts::FheUint32(b)) => { + Ok(SupportedFheCiphertexts::FheUint32(a - b)) + } + (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::Scalar(b)) => { + // TODO: figure out type to add correctly 256 bit operands from handles + let (l, h) = b.to_low_high_u128(); + assert_eq!(h, 0, "Not supported yet"); + Ok(SupportedFheCiphertexts::FheUint8(a - (l as u8))) + } + (SupportedFheCiphertexts::FheUint16(a), SupportedFheCiphertexts::Scalar(b)) => { + // TODO: figure out type to add correctly 256 bit operands from handles + let (l, h) = b.to_low_high_u128(); + assert_eq!(h, 0, "Not supported yet"); + Ok(SupportedFheCiphertexts::FheUint16(a - (l as u16))) + } + (SupportedFheCiphertexts::FheUint32(a), SupportedFheCiphertexts::Scalar(b)) => { + // TODO: figure out type to add correctly 256 bit operands from handles + let (l, h) = b.to_low_high_u128(); + assert_eq!(h, 0, "Not supported yet"); + Ok(SupportedFheCiphertexts::FheUint32(a - (l as u32))) + } + _ => { + panic!("Unsupported fhe types"); + } + } + }, + SupportedFheOperations::FheMul => { + assert_eq!(input_operands.len(), 2); + + match (&input_operands[0], &input_operands[1]) { + (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::FheUint8(b)) => { + Ok(SupportedFheCiphertexts::FheUint8(a * b)) + } + (SupportedFheCiphertexts::FheUint16(a), SupportedFheCiphertexts::FheUint16(b)) => { + Ok(SupportedFheCiphertexts::FheUint16(a * b)) + } + (SupportedFheCiphertexts::FheUint32(a), SupportedFheCiphertexts::FheUint32(b)) => { + Ok(SupportedFheCiphertexts::FheUint32(a * b)) + } + (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::Scalar(b)) => { + // TODO: figure out type to add correctly 256 bit operands from handles + let (l, h) = b.to_low_high_u128(); + assert_eq!(h, 0, "Not supported yet"); + Ok(SupportedFheCiphertexts::FheUint8(a * (l as u8))) + } + (SupportedFheCiphertexts::FheUint16(a), SupportedFheCiphertexts::Scalar(b)) => { + // TODO: figure out type to add correctly 256 bit operands from handles + let (l, h) = b.to_low_high_u128(); + assert_eq!(h, 0, "Not supported yet"); + Ok(SupportedFheCiphertexts::FheUint16(a * (l as u16))) + } + (SupportedFheCiphertexts::FheUint32(a), SupportedFheCiphertexts::Scalar(b)) => { + // TODO: figure out type to add correctly 256 bit operands from handles + let (l, h) = b.to_low_high_u128(); + assert_eq!(h, 0, "Not supported yet"); + Ok(SupportedFheCiphertexts::FheUint32(a * (l as u32))) + } + _ => { + panic!("Unsupported fhe types"); + } + } + }, + SupportedFheOperations::FheDiv => { + assert_eq!(input_operands.len(), 2); + + match (&input_operands[0], &input_operands[1]) { + (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::FheUint8(b)) => { + Ok(SupportedFheCiphertexts::FheUint8(a / b)) + } + (SupportedFheCiphertexts::FheUint16(a), SupportedFheCiphertexts::FheUint16(b)) => { + Ok(SupportedFheCiphertexts::FheUint16(a / b)) + } + (SupportedFheCiphertexts::FheUint32(a), SupportedFheCiphertexts::FheUint32(b)) => { + Ok(SupportedFheCiphertexts::FheUint32(a / b)) + } + (SupportedFheCiphertexts::FheUint8(a), SupportedFheCiphertexts::Scalar(b)) => { + // TODO: figure out type to add correctly 256 bit operands from handles + let (l, h) = b.to_low_high_u128(); + assert_eq!(h, 0, "Not supported yet"); + Ok(SupportedFheCiphertexts::FheUint8(a / (l as u8))) + } + (SupportedFheCiphertexts::FheUint16(a), SupportedFheCiphertexts::Scalar(b)) => { + // TODO: figure out type to add correctly 256 bit operands from handles + let (l, h) = b.to_low_high_u128(); + assert_eq!(h, 0, "Not supported yet"); + Ok(SupportedFheCiphertexts::FheUint16(a / (l as u16))) + } + (SupportedFheCiphertexts::FheUint32(a), SupportedFheCiphertexts::Scalar(b)) => { + // TODO: figure out type to add correctly 256 bit operands from handles + let (l, h) = b.to_low_high_u128(); + assert_eq!(h, 0, "Not supported yet"); + Ok(SupportedFheCiphertexts::FheUint32(a / (l as u32))) + } + _ => { + panic!("Unsupported fhe types"); + } + } + }, SupportedFheOperations::FheNot => todo!(), SupportedFheOperations::FheIfThenElse => todo!(), } } +/// Function assumes encryption key already set +pub fn debug_trivial_encrypt_le_bytes(output_type: i16, input_bytes: &[u8]) -> SupportedFheCiphertexts { + match output_type { + 1 => { + SupportedFheCiphertexts::FheBool(FheBool::try_encrypt_trivial(input_bytes[0] > 0).unwrap()) + } + 2 => { + SupportedFheCiphertexts::FheUint8(FheUint8::try_encrypt_trivial(input_bytes[0]).unwrap()) + } + 3 => { + let mut padded: [u8; 2] = [0; 2]; + let len = padded.len().min(input_bytes.len()); + padded[0..len].copy_from_slice(&input_bytes[0..len]); + let res = u16::from_le_bytes(padded); + SupportedFheCiphertexts::FheUint16(FheUint16::try_encrypt_trivial(res).unwrap()) + } + 4 => { + let mut padded: [u8; 4] = [0; 4]; + let len = padded.len().min(input_bytes.len()); + padded[0..len].copy_from_slice(&input_bytes[0..len]); + let res: u32 = u32::from_le_bytes(padded); + SupportedFheCiphertexts::FheUint32(FheUint32::try_encrypt_trivial(res).unwrap()) + } + other => { + panic!("Unknown input type for trivial encryption: {other}") + } + } +} + pub fn deserialize_fhe_ciphertext(input_type: i16, input_bytes: &[u8]) -> Result> { match input_type { 1 => { let v: tfhe::FheBool = bincode::deserialize(input_bytes)?; Ok(SupportedFheCiphertexts::FheBool(v)) } + 2 => { + let v: tfhe::FheUint8 = bincode::deserialize(input_bytes)?; + Ok(SupportedFheCiphertexts::FheUint8(v)) + } + 3 => { + let v: tfhe::FheUint16 = bincode::deserialize(input_bytes)?; + Ok(SupportedFheCiphertexts::FheUint16(v)) + } 4 => { let v: tfhe::FheUint32 = bincode::deserialize(input_bytes)?; Ok(SupportedFheCiphertexts::FheUint32(v)) @@ -124,6 +285,15 @@ pub fn check_fhe_operand_types(fhe_operation: i32, input_types: &[i16], is_scala } } +// add operations here that don't support both encrypted operands +#[cfg(test)] +pub fn does_fhe_operation_support_both_encrypted_operands(op: &SupportedFheOperations) -> bool { + match op { + SupportedFheOperations::FheDiv => false, + _ => true + } +} + pub fn does_fhe_operation_support_scalar(op: &SupportedFheOperations) -> bool { match op.op_type() { FheOperationType::Binary => { diff --git a/fhevm-engine/src/types.rs b/fhevm-engine/src/types.rs index 2501e99c..a8916c0e 100644 --- a/fhevm-engine/src/types.rs +++ b/fhevm-engine/src/types.rs @@ -39,6 +39,12 @@ pub enum CoprocessorError { fhe_operation_name: String, operand_types: Vec, }, + // TODO: implement scalar division by zero error + // FheOperationScalarDivisionByZero { + // lhs_handle: String, + // fhe_operation: i32, + // fhe_operation_name: String, + // }, } impl std::fmt::Display for CoprocessorError { @@ -128,15 +134,18 @@ pub enum SupportedFheCiphertexts { Scalar(U256), } -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, strum::EnumIter)] #[repr(i8)] pub enum SupportedFheOperations { FheAdd = 0, FheSub = 1, - FheNot = 2, - FheIfThenElse = 3, + FheMul = 2, + FheDiv = 3, + FheNot = 4, + FheIfThenElse = 5, } +#[derive(PartialEq, Eq)] pub enum FheOperationType { Binary, Unary, @@ -173,7 +182,11 @@ impl SupportedFheCiphertexts { impl SupportedFheOperations { pub fn op_type(&self) -> FheOperationType { match self { - SupportedFheOperations::FheAdd | SupportedFheOperations::FheSub => FheOperationType::Binary, + SupportedFheOperations::FheAdd | + SupportedFheOperations::FheSub | + SupportedFheOperations::FheMul | + SupportedFheOperations::FheDiv + => FheOperationType::Binary, SupportedFheOperations::FheNot => FheOperationType::Unary, SupportedFheOperations::FheIfThenElse => FheOperationType::Other, } @@ -187,6 +200,8 @@ impl TryFrom for SupportedFheOperations { let res = match value { 0 => Ok(SupportedFheOperations::FheAdd), 1 => Ok(SupportedFheOperations::FheSub), + 2 => Ok(SupportedFheOperations::FheMul), + 3 => Ok(SupportedFheOperations::FheDiv), _ => Err(CoprocessorError::UnknownFheOperation(value as i32)) }; diff --git a/proto/coprocessor.proto b/proto/coprocessor.proto index 6cb10095..362cef79 100644 --- a/proto/coprocessor.proto +++ b/proto/coprocessor.proto @@ -17,16 +17,26 @@ service FhevmCoprocessor { } message DebugEncryptRequest { + repeated DebugEncryptRequestSingle values = 1; +} + +message DebugEncryptRequestSingle { string handle = 1; - int32 original_value = 2; + bytes le_value = 2; + int32 output_type = 3; } message DebugDecryptRequest { - string handle = 1; + repeated string handles = 1; } message DebugDecryptResponse { - string value = 1; + repeated DebugDecryptResponseSingle values = 1; +} + +message DebugDecryptResponseSingle { + int32 output_type = 1; + string value = 2; } enum FheOperation {