Skip to content

Commit

Permalink
Run rustfmt
Browse files Browse the repository at this point in the history
  • Loading branch information
david-zk committed Aug 16, 2024
1 parent ff8ce85 commit bf8a4a1
Show file tree
Hide file tree
Showing 11 changed files with 787 additions and 503 deletions.
2 changes: 1 addition & 1 deletion fhevm-engine/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,4 @@ pub struct Args {

pub fn parse_args() -> Args {
Args::parse()
}
}
55 changes: 35 additions & 20 deletions fhevm-engine/src/db_queries.rs
Original file line number Diff line number Diff line change
@@ -1,35 +1,39 @@
use std::collections::{BTreeSet, HashMap};
use std::str::FromStr;

use sqlx::{query, Postgres};
use crate::types::CoprocessorError;
use sqlx::{query, Postgres};

/// Returns tenant id upon valid authorization request
pub async fn check_if_api_key_is_valid<T>(req: &tonic::Request<T>, pool: &sqlx::Pool<Postgres>) -> Result<i32, CoprocessorError> {
pub async fn check_if_api_key_is_valid<T>(
req: &tonic::Request<T>,
pool: &sqlx::Pool<Postgres>,
) -> Result<i32, CoprocessorError> {
match req.metadata().get("authorization") {
Some(auth) => {
let auth_header = String::from_utf8(auth.as_bytes().to_owned())
.map_err(|_| CoprocessorError::Unauthorized)?
.to_lowercase();

let prefix = "bearer ";
if !auth_header.starts_with(prefix) {
return Err(CoprocessorError::Unauthorized);
}

let tail = &auth_header[prefix.len()..];
let api_key = tail.trim();
let api_key =
match sqlx::types::Uuid::from_str(api_key) {
Ok(uuid) => uuid,
Err(_) => return Err(CoprocessorError::Unauthorized),
};
let api_key = match sqlx::types::Uuid::from_str(api_key) {
Ok(uuid) => uuid,
Err(_) => return Err(CoprocessorError::Unauthorized),
};

let tenant =
query!("SELECT tenant_id FROM tenants WHERE tenant_api_key = $1", api_key)
.fetch_all(pool)
.await
.map_err(Into::<CoprocessorError>::into)?;
let tenant = query!(
"SELECT tenant_id FROM tenants WHERE tenant_api_key = $1",
api_key
)
.fetch_all(pool)
.await
.map_err(Into::<CoprocessorError>::into)?;

if tenant.is_empty() {
return Err(CoprocessorError::Unauthorized);
Expand All @@ -44,7 +48,11 @@ pub async fn check_if_api_key_is_valid<T>(req: &tonic::Request<T>, pool: &sqlx::
}

/// Returns ciphertext types
pub async fn check_if_ciphertexts_exist_in_db(mut cts: BTreeSet<Vec<u8>>, tenant_id: i32, pool: &sqlx::Pool<Postgres>) -> Result<HashMap<Vec<u8>, i16>, CoprocessorError> {
pub async fn check_if_ciphertexts_exist_in_db(
mut cts: BTreeSet<Vec<u8>>,
tenant_id: i32,
pool: &sqlx::Pool<Postgres>,
) -> Result<HashMap<Vec<u8>, i16>, CoprocessorError> {
let handles_to_check_in_db_vec = cts.iter().cloned().collect::<Vec<_>>();
let ciphertexts = query!(
"
Expand All @@ -55,19 +63,26 @@ pub async fn check_if_ciphertexts_exist_in_db(mut cts: BTreeSet<Vec<u8>>, tenant
",
&handles_to_check_in_db_vec,
tenant_id,
).fetch_all(pool).await.map_err(Into::<CoprocessorError>::into)?;
)
.fetch_all(pool)
.await
.map_err(Into::<CoprocessorError>::into)?;

let mut result = HashMap::with_capacity(cts.len());
for ct in ciphertexts {
assert!(cts.remove(&ct.handle), "any ciphertext selected must exist");
assert!(result.insert(ct.handle.clone(), ct.ciphertext_type).is_none());
assert!(result
.insert(ct.handle.clone(), ct.ciphertext_type)
.is_none());
}

if !cts.is_empty() {
return Err(CoprocessorError::UnexistingInputCiphertextsFound(cts.into_iter().map(|i| {
format!("0x{}", hex::encode(i))
}).collect()));
return Err(CoprocessorError::UnexistingInputCiphertextsFound(
cts.into_iter()
.map(|i| format!("0x{}", hex::encode(i)))
.collect(),
));
}

Ok(result)
}
}
28 changes: 18 additions & 10 deletions fhevm-engine/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
use tokio::task::JoinSet;

mod server;
mod db_queries;
mod cli;
mod types;
mod utils;
mod tfhe_worker;
mod tfhe_ops;
mod db_queries;
mod server;
#[cfg(test)]
mod tests;
mod tfhe_ops;
mod tfhe_worker;
mod types;
mod utils;

fn main() {
let args = crate::cli::parse_args();
assert!(args.work_items_batch_size < args.tenant_key_cache_size, "Work items batch size must be less than tenant key cache size");
assert!(
args.work_items_batch_size < args.tenant_key_cache_size,
"Work items batch size must be less than tenant key cache size"
);

if args.generate_fhe_keys {
generate_fhe_keys();
Expand All @@ -22,7 +25,10 @@ fn main() {
}

// separate function for testing
pub fn start_runtime(args: crate::cli::Args, close_recv: Option<tokio::sync::watch::Receiver<bool>>) {
pub fn start_runtime(
args: crate::cli::Args,
close_recv: Option<tokio::sync::watch::Receiver<bool>>,
) {
tokio::runtime::Builder::new_multi_thread()
.worker_threads(args.tokio_threads)
// not using tokio main to specify max blocking threads
Expand Down Expand Up @@ -50,7 +56,9 @@ pub fn start_runtime(args: crate::cli::Args, close_recv: Option<tokio::sync::wat
})
}

async fn async_main(args: crate::cli::Args) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
async fn async_main(
args: crate::cli::Args,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let mut set = JoinSet::new();
if args.run_server {
println!("Initializing api server");
Expand Down Expand Up @@ -89,4 +97,4 @@ fn generate_fhe_keys() {
std::fs::write(format!("{output_dir}/pks"), compact_key).unwrap();
println!("Creating file {output_dir}/sks");
std::fs::write(format!("{output_dir}/sks"), server_key).unwrap();
}
}
Loading

0 comments on commit bf8a4a1

Please sign in to comment.