Skip to content

Commit

Permalink
fix(scheduler): fixup error handling and switch to anyhow::Result
Browse files Browse the repository at this point in the history
  • Loading branch information
antoniupop committed Nov 14, 2024
1 parent 48d58c6 commit f943a8e
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 154 deletions.
51 changes: 13 additions & 38 deletions fhevm-engine/coprocessor/src/tfhe_worker.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
use crate::{db_queries::populate_cache_with_tenant_keys, types::TfheTenantKeys};
use fhevm_engine_common::types::{FhevmError, Handle, SupportedFheCiphertexts};
use fhevm_engine_common::{
tfhe_ops::{current_ciphertext_version, perform_fhe_operation},
types::SupportedFheOperations,
};
use fhevm_engine_common::{tfhe_ops::current_ciphertext_version, types::SupportedFheOperations};
use itertools::Itertools;
use lazy_static::lazy_static;
use opentelemetry::trace::{Span, TraceContextExt, Tracer};
Expand Down Expand Up @@ -102,10 +99,9 @@ async fn tfhe_worker_cycle(
let mut s = tracer.start_with_context("begin_transaction", &loop_ctx);
let mut trx = conn.begin().await?;
s.end();

// This query locks our work items so other worker doesn't select them.
let mut s = tracer.start_with_context("query_work_items", &loop_ctx);
let mut the_work = query!(
let the_work = query!(
"
WITH RECURSIVE dependent_computations(tenant_id, output_handle, dependencies, fhe_operation, is_scalar, produced_handles) AS (
SELECT c.tenant_id, c.output_handle, c.dependencies, c.fhe_operation, c.is_scalar, ARRAY[ROW(c.tenant_id, c.output_handle)]
Expand Down Expand Up @@ -169,7 +165,6 @@ async fn tfhe_worker_cycle(
if the_work.is_empty() {
continue;
}

WORK_ITEMS_FOUND_COUNTER.inc_by(the_work.len() as u64);
info!(target: "tfhe_worker", { count = the_work.len() }, "Processing work items");
// Make sure we process each tenant independently to avoid
Expand All @@ -184,12 +179,12 @@ async fn tfhe_worker_cycle(
let key_cache = tenant_key_cache.read().await;
for (tenant_id, work) in work_by_tenant.iter() {
let _ = tenants_to_query.insert(*tenant_id);
if !key_cache.contains(&tenant_id) {
if !key_cache.contains(tenant_id) {
let _ = keys_to_query.insert(*tenant_id);
}
for w in work.iter() {
for dh in &w.dependencies {
let _ = cts_to_query.insert(&dh);
let _ = cts_to_query.insert(dh);
}
}
}
Expand All @@ -207,7 +202,6 @@ async fn tfhe_worker_cycle(
));
populate_cache_with_tenant_keys(keys_to_query, trx.as_mut(), &tenant_key_cache).await?;
s.end();

let mut s = tracer.start_with_context("query_ciphertext_batch", &loop_ctx);
s.set_attribute(KeyValue::new("cts_to_query", cts_to_query.len() as i64));
// TODO: select all the ciphertexts where they're contained in the tuples
Expand All @@ -224,21 +218,18 @@ async fn tfhe_worker_cycle(
.fetch_all(trx.as_mut())
.await?;
s.end();

// index ciphertexts in hashmap
let mut ciphertext_map: HashMap<(i32, &[u8]), _> =
HashMap::with_capacity(ciphertexts_rows.len());
for row in &ciphertexts_rows {
let _ = ciphertext_map.insert((row.tenant_id, &row.handle), row);
}

// TODO-ap
let mut s = tracer.start_with_context("schedule_fhe_work", &loop_ctx);
s.set_attribute(KeyValue::new("work_items", work_by_tenant.len() as i64));

let mut s_outer = tracer.start_with_context("wait_and_update_fhe_work", &loop_ctx);
// Process tenants in sequence to avoid switching keys during execution
for (tenant_id, work) in work_by_tenant.iter() {
let mut s_schedule = tracer.start_with_context("schedule_fhe_work", &loop_ctx);
s_schedule.set_attribute(KeyValue::new("work_items", work.len() as i64));
s_schedule.set_attribute(KeyValue::new("tenant_id", *tenant_id as i64));
// We need to ensure that no handles are missing from
// either DB inputs or values produced within this batch
// before this batch is scheduled.
Expand Down Expand Up @@ -282,6 +273,7 @@ async fn tfhe_worker_cycle(
let mut producer_indexes: HashMap<&Handle, usize> = HashMap::new();
let mut consumer_indexes: HashMap<usize, usize> = HashMap::new();
'work_items: for (widx, w) in work.iter().enumerate() {
let mut s = tracer.start_with_context("tfhe_computation", &loop_ctx);
let fhe_op: SupportedFheOperations = w
.fhe_operation
.try_into()
Expand All @@ -305,16 +297,11 @@ async fn tfhe_worker_cycle(
} else {
// If this cannot be computed, we need to
// exclude it from the DF graph.
println!("Uncomputable found");
uncomputable.insert(widx, ());
continue 'work_items;
}
}

// copy for setting error in database
let mut s = tracer.start_with_context("tfhe_computation", &loop_ctx);

// TODO-ap
let n = graph.add_node(
w.output_handle.clone(),
w.fhe_operation.into(),
Expand All @@ -328,7 +315,6 @@ async fn tfhe_worker_cycle(
"handle",
format!("0x{}", hex::encode(&w.output_handle)),
));
//s.set_attribute(KeyValue::new("output_type", db_type as i64));
let input_types = input_ciphertexts
.iter()
.map(|i| match i {
Expand Down Expand Up @@ -362,34 +348,23 @@ async fn tfhe_worker_cycle(
}
}
}
s_schedule.end();

// Execute the DFG with the current tenant's keys
let mut s_outer = tracer.start_with_context("wait_and_update_fhe_work", &loop_ctx);
{
let mut rk = tenant_key_cache.write().await;
let keys = rk.get(tenant_id).expect("Can't get tenant key from cache");

// TODO-ap
// Schedule computations in parallel as dependences allow
let mut sched = Scheduler::new(&mut graph.graph, args.coprocessor_fhe_threads);
let now = std::time::SystemTime::now();
sched.schedule(keys.sks.clone()).await.map_err(|_| {
let err: Box<dyn std::error::Error + Send + Sync> =
Box::new(FhevmError::BadInputs);
error!(target: "tfhe_worker",
{ error = err },
"error while processing work item"
)
});
println!(
"GRAPH Execution time (sched): {}",
now.elapsed().unwrap().as_millis()
);
sched.schedule(keys.sks.clone()).await?;
}
// Extract the results from the graph
let res = graph.get_results().unwrap();

// TODO-ap filter out computations that could not complete
for (idx, w) in work.iter().enumerate() {
// Filter out computations that could not complete
if uncomputable.contains_key(&idx) {
continue;
}
Expand Down Expand Up @@ -478,9 +453,9 @@ async fn tfhe_worker_cycle(
}
}
}
s_outer.end();
}
s.end();
s_outer.end();

trx.commit().await?;

Expand Down
11 changes: 11 additions & 0 deletions fhevm-engine/coprocessor/src/types.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use std::sync::Arc;

use fhevm_engine_common::types::FhevmError;
use scheduler::dfg::types::SchedulerError;

#[derive(Debug)]
pub enum CoprocessorError {
DbError(sqlx::Error),
SchedulerError(SchedulerError),
Unauthorized,
FhevmError(FhevmError),
DuplicateOutputHandleInBatch(String),
Expand Down Expand Up @@ -60,6 +62,9 @@ impl std::fmt::Display for CoprocessorError {
Self::DbError(dbe) => {
write!(f, "Coprocessor db error: {:?}", dbe)
}
Self::SchedulerError(se) => {
write!(f, "Coprocessor scheduler error: {:?}", se)
}
Self::Unauthorized => {
write!(f, "API key unknown/invalid/not provided")
}
Expand Down Expand Up @@ -169,6 +174,12 @@ impl From<sqlx::Error> for CoprocessorError {
}
}

impl From<SchedulerError> for CoprocessorError {
fn from(err: SchedulerError) -> Self {
CoprocessorError::SchedulerError(err)
}
}

impl From<CoprocessorError> for tonic::Status {
fn from(err: CoprocessorError) -> Self {
tonic::Status::from_error(Box::new(err))
Expand Down
6 changes: 3 additions & 3 deletions fhevm-engine/executor/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,9 +335,9 @@ pub fn run_computation(
}
}

pub fn build_taskgraph_from_request<'a, 'b>(
dfg: &'a mut DFGraph<'b>,
req: &'b SyncComputeRequest,
pub fn build_taskgraph_from_request(
dfg: &mut DFGraph,
req: &SyncComputeRequest,
state: &ComputationState,
) -> Result<(), SyncComputeError> {
let mut produced_handles: HashMap<&Handle, usize> = HashMap::new();
Expand Down
30 changes: 11 additions & 19 deletions fhevm-engine/scheduler/src/dfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,9 @@ pub mod types;

use crate::dfg::types::*;
use anyhow::Result;
use std::{cell::RefCell, collections::HashMap};
// use executor::server::{
// CompressedCiphertext, ComputationState, Input, SyncComputation, SyncComputeError,
// SyncComputeRequest,
// };
use daggy::{petgraph::graph::node_index, Dag, NodeIndex};
use fhevm_engine_common::types::{
FhevmError, Handle, SupportedFheCiphertexts, HANDLE_LEN, SCALAR_LEN,
};
use tfhe::integer::U256;
use fhevm_engine_common::types::{Handle, SupportedFheCiphertexts};
use std::cell::RefCell;

thread_local! {
pub static THREAD_POOL: RefCell<Option<rayon::ThreadPool>> = const {RefCell::new(None)};
Expand All @@ -31,26 +24,25 @@ impl std::fmt::Debug for OpNode {
f.debug_struct("OpNode")
.field("OP", &self.opcode)
.field(
"Result",
&format_args!("{0:?} (0x{0:X})", &self.result_handle[0]),
"Result handle",
&format_args!("{:02X?}", &self.result_handle),
)
.finish()
}
}

#[derive(Default, Debug)]
pub struct DFGraph<'a> {
pub struct DFGraph {
pub graph: Dag<OpNode, OpEdge>,
produced_handles: HashMap<&'a Handle, NodeIndex>,
}

impl<'a> DFGraph<'a> {
impl DFGraph {
pub fn add_node(
&mut self,
rh: Handle,
opcode: i32,
inputs: Vec<DFGTaskInput>,
) -> Result<NodeIndex, SchedulerError> {
) -> Result<NodeIndex> {
Ok(self.graph.add_node(OpNode {
opcode,
result: None,
Expand All @@ -63,7 +55,7 @@ impl<'a> DFGraph<'a> {
source: usize,
destination: usize,
consumer_input: usize,
) -> Result<(), SchedulerError> {
) -> Result<()> {
let consumer_index = node_index(destination);
self.graph[consumer_index].inputs[consumer_input] = DFGTaskInput::Dependence(Some(source));
let _edge = self
Expand All @@ -73,20 +65,20 @@ impl<'a> DFGraph<'a> {
node_index(destination),
consumer_input as u8,
)
.map_err(|_| SchedulerError::SchedulerError)?;
.map_err(|_| SchedulerError::CyclicDependence)?;
Ok(())
}

pub fn get_results(
&mut self,
) -> Result<Vec<(Handle, (SupportedFheCiphertexts, i16, Vec<u8>))>, SchedulerError> {
) -> Result<Vec<(Handle, (SupportedFheCiphertexts, i16, Vec<u8>))>> {
let mut res = Vec::with_capacity(self.graph.node_count());
for index in 0..self.graph.node_count() {
let node = self.graph.node_weight_mut(NodeIndex::new(index)).unwrap();
if let Some(ct) = &node.result {
res.push((node.result_handle.clone(), ct.clone()));
} else {
return Err(SchedulerError::SchedulerError);
return Err(SchedulerError::DataflowGraphError.into());
}
}
Ok(res)
Expand Down
Loading

0 comments on commit f943a8e

Please sign in to comment.