Skip to content

Commit

Permalink
[ENH] Operator/Executor error handling (chroma-core#1903)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
- Allows errors to be propagated from operators back to the original
caller of the task.
 - New functionality
	 - None

## Test plan
*How are these changes tested?*
Compile checks show that the error propagates as expected. 
- [x] Tests pass locally with `pytest` for python, `yarn test` for js,
`cargo test` for rust

## Documentation Changes
None
  • Loading branch information
HammadB authored Mar 20, 2024
1 parent 976aee3 commit 3985032
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 26 deletions.
1 change: 0 additions & 1 deletion rust/worker/src/errors.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Defines 17 standard error codes based on the error codes defined in the
// gRPC spec. https://grpc.github.io/grpc/core/md_doc_statuscodes.html
// Custom errors can use these codes in order to allow for generic handling

use std::error::Error;

#[derive(PartialEq, Debug)]
Expand Down
13 changes: 9 additions & 4 deletions rust/worker/src/execution/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,13 +210,14 @@ mod tests {
struct MockOperator {}
#[async_trait]
impl Operator<f32, String> for MockOperator {
async fn run(&self, input: &f32) -> String {
type Error = ();
async fn run(&self, input: &f32) -> Result<String, Self::Error> {
// sleep to simulate work
tokio::time::sleep(tokio::time::Duration::from_millis(
MOCK_OPERATOR_SLEEP_DURATION_MS,
))
.await;
input.to_string()
Ok(input.to_string())
}
}

Expand Down Expand Up @@ -244,8 +245,12 @@ mod tests {
}
}
#[async_trait]
impl Handler<String> for MockDispatchUser {
async fn handle(&mut self, message: String, ctx: &ComponentContext<MockDispatchUser>) {
impl Handler<Result<String, ()>> for MockDispatchUser {
async fn handle(
&mut self,
message: Result<String, ()>,
ctx: &ComponentContext<MockDispatchUser>,
) {
self.counter.fetch_add(1, Ordering::SeqCst);
let curr_count = self.counter.load(Ordering::SeqCst);
// Cancel self
Expand Down
21 changes: 13 additions & 8 deletions rust/worker/src/execution/operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,23 @@ where
I: Send + Sync,
O: Send + Sync,
{
async fn run(&self, input: &I) -> O;
type Error;
// It would have been nice to do this with a default trait for result
// but that's not stable in rust yet.
async fn run(&self, input: &I) -> Result<O, Self::Error>;
}

/// A task is a wrapper around an operator and its input.
/// It is a description of a function to be run.
#[derive(Debug)]
struct Task<Input, Output>
struct Task<Input, Output, Error>
where
Input: Send + Sync + Debug,
Output: Send + Sync + Debug,
{
operator: Box<dyn Operator<Input, Output>>,
operator: Box<dyn Operator<Input, Output, Error = Error>>,
input: Input,
reply_channel: Box<dyn Receiver<Output>>,
reply_channel: Box<dyn Receiver<Result<Output, Error>>>,
}

/// A message type used by the dispatcher to send tasks to worker threads.
Expand All @@ -40,8 +43,9 @@ pub(crate) trait TaskWrapper: Send + Debug {
/// erase the I, O types from the Task struct so that tasks can be
/// stored in a homogenous queue regardless of their input and output types.
#[async_trait]
impl<Input, Output> TaskWrapper for Task<Input, Output>
impl<Input, Output, Error> TaskWrapper for Task<Input, Output, Error>
where
Error: Debug,
Input: Send + Sync + Debug,
Output: Send + Sync + Debug,
{
Expand All @@ -53,12 +57,13 @@ where
}

/// Wrap an operator and its input into a task message.
pub(super) fn wrap<Input, Output>(
operator: Box<dyn Operator<Input, Output>>,
pub(super) fn wrap<Input, Output, Error>(
operator: Box<dyn Operator<Input, Output, Error = Error>>,
input: Input,
reply_channel: Box<dyn Receiver<Output>>,
reply_channel: Box<dyn Receiver<Result<Output, Error>>>,
) -> TaskMessage
where
Error: Debug + 'static,
Input: Send + Sync + Debug + 'static,
Output: Send + Sync + Debug + 'static,
{
Expand Down
16 changes: 11 additions & 5 deletions rust/worker/src/execution/operators/pull_log.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
use crate::{execution::operator::Operator, log::log::Log, types::EmbeddingRecord};
use crate::{
execution::operator::Operator,
log::log::{Log, PullLogsError},
types::EmbeddingRecord,
};
use async_trait::async_trait;
use uuid::Uuid;

Expand Down Expand Up @@ -66,9 +70,12 @@ impl PullLogsOutput {
}
}

pub type PullLogsResult = Result<PullLogsOutput, PullLogsError>;

#[async_trait]
impl Operator<PullLogsInput, PullLogsOutput> for PullLogsOperator {
async fn run(&self, input: &PullLogsInput) -> PullLogsOutput {
type Error = PullLogsError;
async fn run(&self, input: &PullLogsInput) -> PullLogsResult {
// We expect the log to be cheaply cloneable, we need to clone it since we need
// a mutable reference to it. Not necessarily the best, but it works for our needs.
let mut client_clone = self.client.clone();
Expand All @@ -79,8 +86,7 @@ impl Operator<PullLogsInput, PullLogsOutput> for PullLogsOperator {
input.batch_size,
None,
)
.await
.unwrap();
PullLogsOutput::new(logs)
.await?;
Ok(PullLogsOutput::new(logs))
}
}
26 changes: 18 additions & 8 deletions rust/worker/src/execution/orchestration/hnsw.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use super::super::operator::{wrap, TaskMessage};
use super::super::operators::pull_log::{PullLogsInput, PullLogsOperator, PullLogsOutput};
use crate::errors::ChromaError;
use crate::execution::operators::pull_log::PullLogsResult;
use crate::log::log::PullLogsError;
use crate::sysdb::sysdb::SysDb;
use crate::system::System;
use crate::types::VectorQueryResult;
Expand Down Expand Up @@ -102,7 +104,7 @@ impl HnswQueryOrchestrator {
}
}

async fn pull_logs(&mut self, self_address: Box<dyn Receiver<PullLogsOutput>>) {
async fn pull_logs(&mut self, self_address: Box<dyn Receiver<PullLogsResult>>) {
self.state = ExecutionState::PullLogs;
let operator = PullLogsOperator::new(self.log.clone());
let collection_id = match self.get_collection_id_for_segment_id(self.segment_id).await {
Expand Down Expand Up @@ -152,28 +154,36 @@ impl Component for HnswQueryOrchestrator {
// ============== Handlers ==============

#[async_trait]
impl Handler<PullLogsOutput> for HnswQueryOrchestrator {
impl Handler<PullLogsResult> for HnswQueryOrchestrator {
async fn handle(
&mut self,
message: PullLogsOutput,
message: PullLogsResult,
ctx: &crate::system::ComponentContext<HnswQueryOrchestrator>,
) {
self.state = ExecutionState::Dedupe;

// TODO: implement the remaining state transitions and operators
// This is an example of the final state transition and result

match self.result_channel.take() {
Some(tx) => {
let _ = tx.send(Ok(vec![vec![VectorQueryResult {
let result_channel = match self.result_channel.take() {
Some(tx) => tx,
None => {
// Log an error
return;
}
};

match message {
Ok(logs) => {
let _ = result_channel.send(Ok(vec![vec![VectorQueryResult {
id: "abc".to_string(),
seq_id: BigInt::from(0),
distance: 0.0,
vector: Some(vec![0.0, 0.0, 0.0]),
}]]));
}
None => {
// Log an error
Err(e) => {
let _ = result_channel.send(Err(Box::new(e)));
}
}
}
Expand Down

0 comments on commit 3985032

Please sign in to comment.