diff --git a/rust/worker/Cargo.toml b/rust/worker/Cargo.toml index 080e17031ff..65ab2a68ff0 100644 --- a/rust/worker/Cargo.toml +++ b/rust/worker/Cargo.toml @@ -4,8 +4,8 @@ version = "0.1.0" edition = "2021" [[bin]] -name = "worker" -path = "src/bin/worker.rs" +name = "query_service" +path = "src/bin/query_service.rs" [dependencies] tonic = "0.10" @@ -46,4 +46,4 @@ proptest-state-machine = "0.1.0" [build-dependencies] tonic-build = "0.10" -cc = "1.0" \ No newline at end of file +cc = "1.0" diff --git a/rust/worker/chroma_config.yaml b/rust/worker/chroma_config.yaml index fa6d41fa069..151d2128779 100644 --- a/rust/worker/chroma_config.yaml +++ b/rust/worker/chroma_config.yaml @@ -33,3 +33,7 @@ worker: Grpc: host: "logservice.chroma" port: 50052 + dispatcher: + num_worker_threads: 4 + dispatcher_queue_size: 100 + worker_queue_size: 100 diff --git a/rust/worker/src/bin/query_service.rs b/rust/worker/src/bin/query_service.rs new file mode 100644 index 00000000000..f3cfa4c8282 --- /dev/null +++ b/rust/worker/src/bin/query_service.rs @@ -0,0 +1,6 @@ +use worker::query_service_entrypoint; + +#[tokio::main] +async fn main() { + query_service_entrypoint().await; +} diff --git a/rust/worker/src/bin/worker.rs b/rust/worker/src/bin/worker.rs deleted file mode 100644 index 16428d244ff..00000000000 --- a/rust/worker/src/bin/worker.rs +++ /dev/null @@ -1,6 +0,0 @@ -use worker::worker_entrypoint; - -#[tokio::main] -async fn main() { - worker_entrypoint().await; -} diff --git a/rust/worker/src/config.rs b/rust/worker/src/config.rs index 309155bfb9d..637c5cac98a 100644 --- a/rust/worker/src/config.rs +++ b/rust/worker/src/config.rs @@ -110,6 +110,7 @@ pub(crate) struct WorkerConfig { pub(crate) segment_manager: crate::segment::config::SegmentManagerConfig, pub(crate) storage: crate::storage::config::StorageConfig, pub(crate) log: crate::log::config::LogConfig, + pub(crate) dispatcher: crate::execution::config::DispatcherConfig, } /// # Description @@ -165,6 +166,10 @@ mod tests { Grpc: host: "localhost" port: 50052 + dispatcher: + num_worker_threads: 4 + dispatcher_queue_size: 100 + worker_queue_size: 100 "#, ); let config = RootConfig::load(); @@ -213,6 +218,10 @@ mod tests { Grpc: host: "localhost" port: 50052 + dispatcher: + num_worker_threads: 4 + dispatcher_queue_size: 100 + worker_queue_size: 100 "#, ); @@ -277,6 +286,10 @@ mod tests { Grpc: host: "localhost" port: 50052 + dispatcher: + num_worker_threads: 4 + dispatcher_queue_size: 100 + worker_queue_size: 100 "#, ); let config = RootConfig::load(); @@ -321,6 +334,10 @@ mod tests { Grpc: host: "localhost" port: 50052 + dispatcher: + num_worker_threads: 4 + dispatcher_queue_size: 100 + worker_queue_size: 100 "#, ); let config = RootConfig::load(); diff --git a/rust/worker/src/errors.rs b/rust/worker/src/errors.rs index 18365cb789f..968dbaeeb17 100644 --- a/rust/worker/src/errors.rs +++ b/rust/worker/src/errors.rs @@ -42,6 +42,6 @@ pub(crate) enum ErrorCodes { DataLoss = 15, } -pub(crate) trait ChromaError: Error { +pub(crate) trait ChromaError: Error + Send { fn code(&self) -> ErrorCodes; } diff --git a/rust/worker/src/execution/config.rs b/rust/worker/src/execution/config.rs new file mode 100644 index 00000000000..d8550dc41bc --- /dev/null +++ b/rust/worker/src/execution/config.rs @@ -0,0 +1,8 @@ +use serde::Deserialize; + +#[derive(Deserialize)] +pub(crate) struct DispatcherConfig { + pub(crate) num_worker_threads: usize, + pub(crate) dispatcher_queue_size: usize, + pub(crate) worker_queue_size: usize, +} diff --git a/rust/worker/src/execution/dispatcher.rs b/rust/worker/src/execution/dispatcher.rs index 1fe94b255c1..8a25e0b26fe 100644 --- a/rust/worker/src/execution/dispatcher.rs +++ b/rust/worker/src/execution/dispatcher.rs @@ -1,5 +1,9 @@ use super::{operator::TaskMessage, worker_thread::WorkerThread}; -use crate::system::{Component, ComponentContext, Handler, Receiver, System}; +use crate::{ + config::{Configurable, WorkerConfig}, + errors::ChromaError, + system::{Component, ComponentContext, Handler, Receiver, System}, +}; use async_trait::async_trait; use std::fmt::Debug; @@ -46,21 +50,27 @@ use std::fmt::Debug; coarser work-stealing, and other optimizations. */ #[derive(Debug)] -struct Dispatcher { +pub(crate) struct Dispatcher { task_queue: Vec, waiters: Vec, n_worker_threads: usize, + queue_size: usize, + worker_queue_size: usize, } impl Dispatcher { /// Create a new dispatcher /// # Parameters /// - n_worker_threads: The number of worker threads to use - pub fn new(n_worker_threads: usize) -> Self { + /// - queue_size: The size of the components message queue + /// - worker_queue_size: The size of the worker components queue + pub fn new(n_worker_threads: usize, queue_size: usize, worker_queue_size: usize) -> Self { Dispatcher { task_queue: Vec::new(), waiters: Vec::new(), n_worker_threads, + queue_size, + worker_queue_size, } } @@ -74,7 +84,7 @@ impl Dispatcher { self_receiver: Box>, ) { for _ in 0..self.n_worker_threads { - let worker = WorkerThread::new(self_receiver.clone()); + let worker = WorkerThread::new(self_receiver.clone(), self.worker_queue_size); system.start_component(worker); } } @@ -118,6 +128,17 @@ impl Dispatcher { } } +#[async_trait] +impl Configurable for Dispatcher { + async fn try_from_config(worker_config: &WorkerConfig) -> Result> { + Ok(Dispatcher::new( + worker_config.dispatcher.num_worker_threads, + worker_config.dispatcher.dispatcher_queue_size, + worker_config.dispatcher.worker_queue_size, + )) + } +} + /// A message that a worker thread sends to the dispatcher to request a task /// # Members /// - reply_to: The receiver to send the task to, this is the worker thread @@ -141,7 +162,7 @@ impl TaskRequestMessage { #[async_trait] impl Component for Dispatcher { fn queue_size(&self) -> usize { - 1000 // TODO: make configurable + self.queue_size } async fn on_start(&mut self, ctx: &ComponentContext) { @@ -166,19 +187,15 @@ impl Handler for Dispatcher { #[cfg(test)] mod tests { - use std::{ - env::current_dir, - sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, - }, - }; - use super::*; use crate::{ execution::operator::{wrap, Operator}, system::System, }; + use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }; // Create a component that will schedule DISPATCH_COUNT invocations of the MockOperator // on an interval of DISPATCH_FREQUENCY_MS. @@ -249,7 +266,7 @@ mod tests { #[tokio::test] async fn test_dispatcher() { let mut system = System::new(); - let dispatcher = Dispatcher::new(THREAD_COUNT); + let dispatcher = Dispatcher::new(THREAD_COUNT, 1000, 1000); let dispatcher_handle = system.start_component(dispatcher); let counter = Arc::new(AtomicUsize::new(0)); let dispatch_user = MockDispatchUser { diff --git a/rust/worker/src/execution/mod.rs b/rust/worker/src/execution/mod.rs index 3bd82a311ab..0000e23f3a3 100644 --- a/rust/worker/src/execution/mod.rs +++ b/rust/worker/src/execution/mod.rs @@ -1,5 +1,6 @@ -mod dispatcher; -mod operator; +pub(crate) mod config; +pub(crate) mod dispatcher; +pub(crate) mod operator; mod operators; -mod orchestration; +pub(crate) mod orchestration; mod worker_thread; diff --git a/rust/worker/src/execution/operator.rs b/rust/worker/src/execution/operator.rs index 10f7321684f..85baa7d8c7d 100644 --- a/rust/worker/src/execution/operator.rs +++ b/rust/worker/src/execution/operator.rs @@ -27,12 +27,12 @@ where } /// A message type used by the dispatcher to send tasks to worker threads. -pub(super) type TaskMessage = Box; +pub(crate) type TaskMessage = Box; /// A task wrapper is a trait that can be used to run a task. We use it to /// erase the I, O types from the Task struct so that tasks. #[async_trait] -pub(super) trait TaskWrapper: Send + Debug { +pub(crate) trait TaskWrapper: Send + Debug { async fn run(&self); } diff --git a/rust/worker/src/execution/orchestration/hnsw.rs b/rust/worker/src/execution/orchestration/hnsw.rs index 7dd2aebb7bb..d25862bc9e1 100644 --- a/rust/worker/src/execution/orchestration/hnsw.rs +++ b/rust/worker/src/execution/orchestration/hnsw.rs @@ -1,12 +1,16 @@ use super::super::operator::{wrap, TaskMessage}; use super::super::operators::pull_log::{PullLogsInput, PullLogsOperator, PullLogsOutput}; +use crate::errors::ChromaError; use crate::sysdb::sysdb::SysDb; +use crate::system::System; +use crate::types::VectorQueryResult; use crate::{ log::log::Log, system::{Component, Handler, Receiver}, }; use async_trait::async_trait; -use std::fmt::{self, Debug, Formatter}; +use num_bigint::BigInt; +use std::fmt::Debug; use uuid::Uuid; /** The state of the orchestrator. @@ -35,8 +39,10 @@ enum ExecutionState { } #[derive(Debug)] -struct HnswQueryOrchestrator { +pub(crate) struct HnswQueryOrchestrator { state: ExecutionState, + // Component Execution + system: System, // Query state query_vectors: Vec>, k: i32, @@ -46,10 +52,15 @@ struct HnswQueryOrchestrator { log: Box, sysdb: Box, dispatcher: Box>, + // Result channel + result_channel: Option< + tokio::sync::oneshot::Sender>, Box>>, + >, } impl HnswQueryOrchestrator { - pub fn new( + pub(crate) fn new( + system: System, query_vectors: Vec>, k: i32, include_embeddings: bool, @@ -60,6 +71,7 @@ impl HnswQueryOrchestrator { ) -> Self { HnswQueryOrchestrator { state: ExecutionState::Pending, + system, query_vectors, k, include_embeddings, @@ -67,6 +79,7 @@ impl HnswQueryOrchestrator { log, sysdb, dispatcher, + result_channel: None, } } @@ -108,6 +121,19 @@ impl HnswQueryOrchestrator { } } } + + /// Run the orchestrator and return the result. + /// # Note + /// Use this over spawning the component directly. This method will start the component and + /// wait for it to finish before returning the result. + pub(crate) async fn run(mut self) -> Result>, Box> { + let (tx, rx) = tokio::sync::oneshot::channel(); + self.result_channel = Some(tx); + let mut handle = self.system.clone().start_component(self); + let result = rx.await; + handle.stop(); + result.unwrap() + } } // ============== Component Implementation ============== @@ -133,6 +159,22 @@ impl Handler for HnswQueryOrchestrator { ctx: &crate::system::ComponentContext, ) { self.state = ExecutionState::Dedupe; + // TODO: implement the remaining state transitions and operators + // This is an example of the final state transition and result + + match self.result_channel.take() { + Some(tx) => { + let _ = tx.send(Ok(vec![vec![VectorQueryResult { + id: "abc".to_string(), + seq_id: BigInt::from(0), + distance: 0.0, + vector: Some(vec![0.0, 0.0, 0.0]), + }]])); + } + None => { + // Log an error + } + } } } diff --git a/rust/worker/src/execution/orchestration/mod.rs b/rust/worker/src/execution/orchestration/mod.rs index e0c45e2e87c..902c3eaf84d 100644 --- a/rust/worker/src/execution/orchestration/mod.rs +++ b/rust/worker/src/execution/orchestration/mod.rs @@ -1 +1,3 @@ mod hnsw; + +pub(crate) use hnsw::*; diff --git a/rust/worker/src/execution/worker_thread.rs b/rust/worker/src/execution/worker_thread.rs index 7a5c0fcbe92..d651a725d34 100644 --- a/rust/worker/src/execution/worker_thread.rs +++ b/rust/worker/src/execution/worker_thread.rs @@ -9,11 +9,18 @@ use std::fmt::{Debug, Formatter, Result}; /// - The actor loop will block until work is available pub(super) struct WorkerThread { dispatcher: Box>, + queue_size: usize, } impl WorkerThread { - pub(super) fn new(dispatcher: Box>) -> Self { - WorkerThread { dispatcher } + pub(super) fn new( + dispatcher: Box>, + queue_size: usize, + ) -> WorkerThread { + WorkerThread { + dispatcher, + queue_size, + } } } @@ -26,7 +33,7 @@ impl Debug for WorkerThread { #[async_trait] impl Component for WorkerThread { fn queue_size(&self) -> usize { - 1000 // TODO: make configurable + self.queue_size } fn runtime() -> ComponentRuntime { diff --git a/rust/worker/src/lib.rs b/rust/worker/src/lib.rs index 0af68c5e06c..1cb31b3f3b2 100644 --- a/rust/worker/src/lib.rs +++ b/rust/worker/src/lib.rs @@ -15,15 +15,50 @@ mod sysdb; mod system; mod types; +use crate::sysdb::sysdb::SysDb; use config::Configurable; use memberlist::MemberlistProvider; -use crate::sysdb::sysdb::SysDb; - mod chroma_proto { tonic::include_proto!("chroma"); } +pub async fn query_service_entrypoint() { + let config = config::RootConfig::load(); + let system: system::System = system::System::new(); + let segment_manager = match segment::SegmentManager::try_from_config(&config.worker).await { + Ok(segment_manager) => segment_manager, + Err(err) => { + println!("Failed to create segment manager component: {:?}", err); + return; + } + }; + let dispatcher = match execution::dispatcher::Dispatcher::try_from_config(&config.worker).await + { + Ok(dispatcher) => dispatcher, + Err(err) => { + println!("Failed to create dispatcher component: {:?}", err); + return; + } + }; + let mut dispatcher_handle = system.start_component(dispatcher); + let mut worker_server = match server::WorkerServer::try_from_config(&config.worker).await { + Ok(worker_server) => worker_server, + Err(err) => { + println!("Failed to create worker server component: {:?}", err); + return; + } + }; + worker_server.set_segment_manager(segment_manager.clone()); + worker_server.set_dispatcher(dispatcher_handle.receiver()); + + let server_join_handle = tokio::spawn(async move { + crate::server::WorkerServer::run(worker_server).await; + }); + + let _ = tokio::join!(server_join_handle, dispatcher_handle.join()); +} + pub async fn worker_entrypoint() { let config = config::RootConfig::load(); // Create all the core components and start them @@ -103,5 +138,6 @@ pub async fn worker_entrypoint() { ingest_handle.join(), memberlist_handle.join(), scheduler_handler.join(), + server_join_handle, ); } diff --git a/rust/worker/src/log/mod.rs b/rust/worker/src/log/mod.rs index c7873c00ce9..cd769734c48 100644 --- a/rust/worker/src/log/mod.rs +++ b/rust/worker/src/log/mod.rs @@ -1,2 +1,17 @@ pub(crate) mod config; pub(crate) mod log; + +use crate::{ + config::{Configurable, WorkerConfig}, + errors::ChromaError, +}; + +pub(crate) async fn from_config( + config: &WorkerConfig, +) -> Result, Box> { + match &config.log { + crate::log::config::LogConfig::Grpc(_) => { + Ok(Box::new(log::GrpcLog::try_from_config(config).await?)) + } + } +} diff --git a/rust/worker/src/server.rs b/rust/worker/src/server.rs index 1ecc6ba2e70..205a51b6a97 100644 --- a/rust/worker/src/server.rs +++ b/rust/worker/src/server.rs @@ -1,28 +1,53 @@ -use std::f32::consts::E; - use crate::chroma_proto; use crate::chroma_proto::{ GetVectorsRequest, GetVectorsResponse, QueryVectorsRequest, QueryVectorsResponse, }; use crate::config::{Configurable, WorkerConfig}; use crate::errors::ChromaError; +use crate::execution::operator::TaskMessage; +use crate::execution::orchestration::HnswQueryOrchestrator; +use crate::log::log::Log; use crate::segment::SegmentManager; +use crate::sysdb::sysdb::SysDb; +use crate::system::{Receiver, System}; use crate::types::ScalarEncoding; use async_trait::async_trait; -use kube::core::request; use tonic::{transport::Server, Request, Response, Status}; use uuid::Uuid; pub struct WorkerServer { + // System + system: Option, + // Component dependencies segment_manager: Option, + dispatcher: Option>>, + // Service dependencies + log: Box, + sysdb: Box, port: u16, } #[async_trait] impl Configurable for WorkerServer { async fn try_from_config(config: &WorkerConfig) -> Result> { + let sysdb = match crate::sysdb::from_config(&config).await { + Ok(sysdb) => sysdb, + Err(err) => { + return Err(err); + } + }; + let log = match crate::log::from_config(&config).await { + Ok(log) => log, + Err(err) => { + return Err(err); + } + }; Ok(WorkerServer { segment_manager: None, + dispatcher: None, + system: None, + sysdb, + log, port: config.my_port, }) } @@ -46,6 +71,14 @@ impl WorkerServer { pub(crate) fn set_segment_manager(&mut self, segment_manager: SegmentManager) { self.segment_manager = Some(segment_manager); } + + pub(crate) fn set_dispatcher(&mut self, dispatcher: Box>) { + self.dispatcher = Some(dispatcher); + } + + pub(crate) fn set_system(&mut self, system: System) { + self.system = Some(system); + } } #[tonic::async_trait] @@ -126,6 +159,8 @@ impl chroma_proto::vector_reader_server::VectorReader for WorkerServer { }; let mut proto_results_for_all = Vec::new(); + + let mut query_vectors = Vec::new(); for proto_query_vector in request.vectors { let (query_vector, encoding) = match proto_query_vector.try_into() { Ok((vector, encoding)) => (vector, encoding), @@ -133,31 +168,58 @@ impl chroma_proto::vector_reader_server::VectorReader for WorkerServer { return Err(Status::internal(format!("Error converting vector: {}", e))); } }; + query_vectors.push(query_vector); + } + + let dispatcher = match self.dispatcher { + Some(ref dispatcher) => dispatcher, + None => { + return Err(Status::internal("No dispatcher found")); + } + }; - let results = match segment_manager - .query_vector( - &segment_uuid, - &query_vector, - request.k as usize, + let result = match self.system { + Some(ref system) => { + let orchestrator = HnswQueryOrchestrator::new( + // TODO: Should not have to clone query vectors here + system.clone(), + query_vectors.clone(), + request.k, request.include_embeddings, - ) - .await - { - Ok(results) => results, - Err(e) => { - return Err(Status::internal(format!("Error querying segment: {}", e))); - } - }; + segment_uuid, + self.log.clone(), + self.sysdb.clone(), + dispatcher.clone(), + ); + orchestrator.run().await + } + None => { + return Err(Status::internal("No system found")); + } + }; + let result = match result { + Ok(result) => result, + Err(e) => { + return Err(Status::internal(format!( + "Error running orchestrator: {}", + e + ))); + } + }; + + for result_set in result { let mut proto_results = Vec::new(); - for query_result in results { + for query_result in result_set { let proto_result = chroma_proto::VectorQueryResult { id: query_result.id, seq_id: query_result.seq_id.to_bytes_le().1, distance: query_result.distance, vector: match query_result.vector { Some(vector) => { - match (vector, ScalarEncoding::FLOAT32, query_vector.len()).try_into() { + match (vector, ScalarEncoding::FLOAT32, query_vectors[0].len()) + .try_into() + { Ok(proto_vector) => Some(proto_vector), Err(e) => { return Err(Status::internal(format!( @@ -172,11 +234,9 @@ impl chroma_proto::vector_reader_server::VectorReader for WorkerServer { }; proto_results.push(proto_result); } - - let vector_query_results = chroma_proto::VectorQueryResults { + proto_results_for_all.push(chroma_proto::VectorQueryResults { results: proto_results, - }; - proto_results_for_all.push(vector_query_results); + }); } let resp = chroma_proto::QueryVectorsResponse { diff --git a/rust/worker/src/sysdb/mod.rs b/rust/worker/src/sysdb/mod.rs index 1db5510f893..770fa5cc208 100644 --- a/rust/worker/src/sysdb/mod.rs +++ b/rust/worker/src/sysdb/mod.rs @@ -1,2 +1,17 @@ pub(crate) mod config; pub(crate) mod sysdb; + +use crate::{ + config::{Configurable, WorkerConfig}, + errors::ChromaError, +}; + +pub(crate) async fn from_config( + config: &WorkerConfig, +) -> Result, Box> { + match &config.sysdb { + crate::sysdb::config::SysDbConfig::Grpc(_) => { + Ok(Box::new(sysdb::GrpcSysDb::try_from_config(config).await?)) + } + } +} diff --git a/rust/worker/src/sysdb/sysdb.rs b/rust/worker/src/sysdb/sysdb.rs index 450761e8896..990268e66ec 100644 --- a/rust/worker/src/sysdb/sysdb.rs +++ b/rust/worker/src/sysdb/sysdb.rs @@ -1,3 +1,4 @@ +use super::config::SysDbConfig; use crate::chroma_proto; use crate::config::{Configurable, WorkerConfig}; use crate::types::{CollectionConversionError, SegmentConversionError}; @@ -11,8 +12,6 @@ use std::fmt::Debug; use thiserror::Error; use uuid::Uuid; -use super::config::SysDbConfig; - const DEFAULT_DATBASE: &str = "default_database"; const DEFAULT_TENANT: &str = "default_tenant"; diff --git a/rust/worker/src/system/scheduler.rs b/rust/worker/src/system/scheduler.rs index f69d8fa3450..aab82e5d098 100644 --- a/rust/worker/src/system/scheduler.rs +++ b/rust/worker/src/system/scheduler.rs @@ -1,6 +1,5 @@ use parking_lot::RwLock; use std::fmt::Debug; -use std::num; use std::sync::Arc; use std::time::Duration; use tokio::select; @@ -10,12 +9,13 @@ use super::{ executor::ComponentExecutor, sender::Sender, system::System, Receiver, ReceiverImpl, Wrapper, }; +#[derive(Debug)] pub(crate) struct SchedulerTaskHandle { join_handle: Option>, cancel: tokio_util::sync::CancellationToken, } -#[derive(Clone)] +#[derive(Clone, Debug)] pub(crate) struct Scheduler { handles: Arc>>, } diff --git a/rust/worker/src/system/system.rs b/rust/worker/src/system/system.rs index 42c7c565046..0d9f4738625 100644 --- a/rust/worker/src/system/system.rs +++ b/rust/worker/src/system/system.rs @@ -10,11 +10,12 @@ use std::sync::Arc; use tokio::runtime::Builder; use tokio::{pin, select}; -#[derive(Clone)] +#[derive(Clone, Debug)] pub(crate) struct System { inner: Arc, } +#[derive(Debug)] struct Inner { scheduler: Scheduler, } @@ -28,7 +29,7 @@ impl System { } } - pub(crate) fn start_component(&mut self, component: C) -> ComponentHandle + pub(crate) fn start_component(&self, component: C) -> ComponentHandle where C: Component + Send + 'static, {