diff --git a/rust/worker/src/execution/operators/pull_log.rs b/rust/worker/src/execution/operators/pull_log.rs
index 7fb150fd34c..b172ebfe606 100644
--- a/rust/worker/src/execution/operators/pull_log.rs
+++ b/rust/worker/src/execution/operators/pull_log.rs
@@ -26,11 +26,15 @@ impl PullLogsOperator {
 /// * `collection_id` - The collection id to read logs from.
 /// * `offset` - The offset to start reading logs from.
 /// * `batch_size` - The number of log entries to read.
+/// * `num_records` - The maximum number of records to read.
+/// * `end_timestamp` - The end timestamp to read logs until.
 #[derive(Debug)]
 pub struct PullLogsInput {
     collection_id: Uuid,
     offset: i64,
     batch_size: i32,
+    num_records: Option<i32>,
+    end_timestamp: Option<i64>,
 }
 
 impl PullLogsInput {
@@ -39,11 +43,21 @@ impl PullLogsInput {
     /// * `collection_id` - The collection id to read logs from.
     /// * `offset` - The offset to start reading logs from.
     /// * `batch_size` - The number of log entries to read.
-    pub fn new(collection_id: Uuid, offset: i64, batch_size: i32) -> Self {
+    /// * `num_records` - The maximum number of records to read.
+    /// * `end_timestamp` - The end timestamp to read logs until.
+    pub fn new(
+        collection_id: Uuid,
+        offset: i64,
+        batch_size: i32,
+        num_records: Option<i32>,
+        end_timestamp: Option<i64>,
+    ) -> Self {
         PullLogsInput {
             collection_id,
             offset,
             batch_size,
+            num_records,
+            end_timestamp,
         }
     }
 }
@@ -75,18 +89,155 @@ pub type PullLogsResult = Result<PullLogsOutput, PullLogsError>;
 #[async_trait]
 impl Operator<PullLogsInput, PullLogsOutput> for PullLogsOperator {
     type Error = PullLogsError;
+
     async fn run(&self, input: &PullLogsInput) -> PullLogsResult {
         // We expect the log to be cheaply cloneable, we need to clone it since we need
         // a mutable reference to it. Not necessarily the best, but it works for our needs.
         let mut client_clone = self.client.clone();
-        let logs = client_clone
-            .read(
-                input.collection_id.to_string(),
-                input.offset,
-                input.batch_size,
-                None,
-            )
-            .await?;
-        Ok(PullLogsOutput::new(logs))
+        let batch_size = input.batch_size;
+        let mut num_records_read = 0;
+        let mut offset = input.offset;
+        let mut result = Vec::new();
+        loop {
+            let logs = client_clone
+                .read(
+                    input.collection_id.to_string(),
+                    offset,
+                    batch_size,
+                    input.end_timestamp,
+                )
+                .await;
+
+            let mut logs = match logs {
+                Ok(logs) => logs,
+                Err(e) => {
+                    return Err(e);
+                }
+            };
+
+            if logs.is_empty() {
+                break;
+            }
+
+            num_records_read += logs.len();
+            offset += batch_size as i64;
+            result.append(&mut logs);
+
+            if input.num_records.is_some()
+                && num_records_read >= input.num_records.unwrap() as usize
+            {
+                break;
+            }
+        }
+        if input.num_records.is_some() && result.len() > input.num_records.unwrap() as usize {
+            result.truncate(input.num_records.unwrap() as usize);
+        }
+        Ok(PullLogsOutput::new(result))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::log::log::InMemoryLog;
+    use crate::log::log::LogRecord;
+    use crate::types::EmbeddingRecord;
+    use crate::types::Operation;
+    use num_bigint::BigInt;
+    use std::str::FromStr;
+    use uuid::Uuid;
+
+    #[tokio::test]
+    async fn test_pull_logs() {
+        let mut log = Box::new(InMemoryLog::new());
+
+        let collection_uuid_1 = Uuid::from_str("00000000-0000-0000-0000-000000000001").unwrap();
+        let collection_id_1 = collection_uuid_1.to_string();
+        log.add_log(
+            collection_id_1.clone(),
+            Box::new(LogRecord {
+                collection_id: collection_id_1.clone(),
+                log_id: 1,
+                log_id_ts: 1,
+                record: Box::new(EmbeddingRecord {
+                    id: "embedding_id_1".to_string(),
+                    seq_id: BigInt::from(1),
+                    embedding: None,
+                    encoding: None,
+                    metadata: None,
+                    operation: Operation::Add,
+                    collection_id: collection_uuid_1,
+                }),
+            }),
+        );
+        log.add_log(
+            collection_id_1.clone(),
+            Box::new(LogRecord {
+                collection_id: collection_id_1.clone(),
+                log_id: 2,
+                log_id_ts: 2,
+                record: Box::new(EmbeddingRecord {
+                    id: "embedding_id_2".to_string(),
+                    seq_id: BigInt::from(2),
+                    embedding: None,
+                    encoding: None,
+                    metadata: None,
+                    operation: Operation::Add,
+                    collection_id: collection_uuid_1,
+                }),
+            }),
+        );
+
+        let operator = PullLogsOperator::new(log);
+
+        // Pull all logs from collection 1
+        let input = PullLogsInput::new(collection_uuid_1, 0, 1, None, None);
+        let output = operator.run(&input).await.unwrap();
+        assert_eq!(output.logs().len(), 2);
+
+        // Pull all logs from collection 1 with a large batch size
+        let input = PullLogsInput::new(collection_uuid_1, 0, 100, None, None);
+        let output = operator.run(&input).await.unwrap();
+        assert_eq!(output.logs().len(), 2);
+
+        // Pull logs from collection 1 with a limit
+        let input = PullLogsInput::new(collection_uuid_1, 0, 1, Some(1), None);
+        let output = operator.run(&input).await.unwrap();
+        assert_eq!(output.logs().len(), 1);
+
+        // Pull logs from collection 1 with an end timestamp that excludes the second record
+        let input = PullLogsInput::new(collection_uuid_1, 0, 1, None, Some(1));
+        let output = operator.run(&input).await.unwrap();
+        assert_eq!(output.logs().len(), 1);
+
+        // Pull logs from collection 1 with an end timestamp that includes both records
+        let input = PullLogsInput::new(collection_uuid_1, 0, 1, None, Some(2));
+        let output = operator.run(&input).await.unwrap();
+        assert_eq!(output.logs().len(), 2);
+
+        // Pull logs from collection 1 with an end timestamp and a limit
+        let input = PullLogsInput::new(collection_uuid_1, 0, 1, Some(1), Some(2));
+        let output = operator.run(&input).await.unwrap();
+        assert_eq!(output.logs().len(), 1);
+
+        // Pull logs from collection 1 with a limit and a large batch size
+        let input = PullLogsInput::new(collection_uuid_1, 0, 100, Some(1), None);
+        let output = operator.run(&input).await.unwrap();
+        assert_eq!(output.logs().len(), 1);
+
+        // Pull logs from collection 1 with a large batch size and an end timestamp that excludes the second record
+        let input = PullLogsInput::new(collection_uuid_1, 0, 100, None, Some(1));
+        let output = operator.run(&input).await.unwrap();
+        assert_eq!(output.logs().len(), 1);
+
+        // Pull logs from collection 1 with a large batch size and an end timestamp that includes both records
+        let input = PullLogsInput::new(collection_uuid_1, 0, 100, None, Some(2));
+        let output = operator.run(&input).await.unwrap();
+        assert_eq!(output.logs().len(), 2);
+
+        // Pull logs from collection 1 with an end timestamp and a limit and a large batch size
+        let input = PullLogsInput::new(collection_uuid_1, 0, 100, Some(1), Some(2));
+        let output = operator.run(&input).await.unwrap();
+        assert_eq!(output.logs().len(), 1);
     }
 }
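For readers skimming the new run() loop above, here is a minimal, self-contained sketch of the paging behavior it implements (batched reads, an optional num_records cap, and a final truncate). The read_batch helper below stands in for the log client's read call; all names and the plain i64 timestamps are illustrative only, not the crate's API.

// Standalone sketch of the paging strategy in the new PullLogsOperator::run.
fn read_batch(log: &[i64], offset: i64, batch_size: i32, end_timestamp: Option<i64>) -> Vec<i64> {
    log.iter()
        .skip(offset as usize)
        .take(batch_size as usize)
        .copied()
        .filter(|ts| end_timestamp.map_or(true, |end| *ts <= end))
        .collect()
}

fn pull_logs(
    log: &[i64],
    batch_size: i32,
    num_records: Option<i32>,
    end_timestamp: Option<i64>,
) -> Vec<i64> {
    let mut offset = 0i64;
    let mut result = Vec::new();
    loop {
        let mut batch = read_batch(log, offset, batch_size, end_timestamp);
        if batch.is_empty() {
            break; // the log (or the portion at or before end_timestamp) is exhausted
        }
        offset += batch_size as i64;
        result.append(&mut batch);
        // Stop paging once the optional record cap has been reached.
        if num_records.map_or(false, |n| result.len() >= n as usize) {
            break;
        }
    }
    // The last batch may overshoot the cap, so trim the tail.
    if let Some(n) = num_records {
        result.truncate(n as usize);
    }
    result
}

fn main() {
    let log = vec![1, 2, 3, 4, 5]; // log_id_ts values
    assert_eq!(pull_logs(&log, 2, Some(3), None), vec![1, 2, 3]);
    assert_eq!(pull_logs(&log, 2, None, Some(3)), vec![1, 2, 3]);
}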
diff --git a/rust/worker/src/execution/orchestration/hnsw.rs b/rust/worker/src/execution/orchestration/hnsw.rs
index 35c4134c940..f583506ad1f 100644
--- a/rust/worker/src/execution/orchestration/hnsw.rs
+++ b/rust/worker/src/execution/orchestration/hnsw.rs
@@ -2,7 +2,6 @@ use super::super::operator::{wrap, TaskMessage};
 use super::super::operators::pull_log::{PullLogsInput, PullLogsOperator, PullLogsOutput};
 use crate::errors::ChromaError;
 use crate::execution::operators::pull_log::PullLogsResult;
-use crate::log::log::PullLogsError;
 use crate::sysdb::sysdb::SysDb;
 use crate::system::System;
 use crate::types::VectorQueryResult;
@@ -13,6 +12,8 @@ use crate::{
 use async_trait::async_trait;
 use num_bigint::BigInt;
 use std::fmt::Debug;
+use std::fmt::Formatter;
+use std::time::{SystemTime, UNIX_EPOCH};
 use uuid::Uuid;
 
 /** The state of the orchestrator.
@@ -114,7 +115,16 @@ impl HnswQueryOrchestrator {
                 return;
             }
         };
-        let input = PullLogsInput::new(collection_id, 0, 100);
+        let end_timestamp = SystemTime::now().duration_since(UNIX_EPOCH);
+        let end_timestamp = match end_timestamp {
+            // TODO: change protobuf definition to use u64 instead of i64
+            Ok(end_timestamp) => end_timestamp.as_secs() as i64,
+            Err(e) => {
+                // Log an error and reply + return
+                return;
+            }
+        };
+        let input = PullLogsInput::new(collection_id, 0, 100, None, Some(end_timestamp));
         let task = wrap(operator, input, self_address);
         match self.dispatcher.send(task).await {
             Ok(_) => (),
@@ -175,6 +185,8 @@ impl Handler<PullLogsResult> for HnswQueryOrchestrator {
         match message {
             Ok(logs) => {
+                // TODO: remove this after debugging
+                println!("Received logs: {:?}", logs);
                 let _ = result_channel.send(Ok(vec![vec![VectorQueryResult {
                     id: "abc".to_string(),
                     seq_id: BigInt::from(0),
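The orchestrator change above derives the end_timestamp it passes to PullLogsInput from the system clock. Below is a small, standalone sketch of that conversion, assuming the same i64 epoch-seconds representation the protobuf currently uses; the helper name is hypothetical, since the orchestrator inlines this logic and simply returns early on error.

// Standalone sketch of the end_timestamp handling added above.
use std::time::{SystemTime, UNIX_EPOCH};

fn current_epoch_seconds() -> Option<i64> {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        // The log service protobuf currently models timestamps as i64,
        // hence the cast (see the TODO in the hunk above).
        Ok(duration) => Some(duration.as_secs() as i64),
        // The clock reading predates UNIX_EPOCH; callers treat this as an
        // error and bail out rather than pulling logs with a bogus bound.
        Err(_) => None,
    }
}

fn main() {
    if let Some(end_timestamp) = current_epoch_seconds() {
        // e.g. PullLogsInput::new(collection_id, 0, 100, None, Some(end_timestamp))
        println!("pulling logs up to {}", end_timestamp);
    }
}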
diff --git a/rust/worker/src/log/log.rs b/rust/worker/src/log/log.rs
index 9ece2deb205..56a7da319e2 100644
--- a/rust/worker/src/log/log.rs
+++ b/rust/worker/src/log/log.rs
@@ -269,10 +269,18 @@ impl Log for InMemoryLog {
         batch_size: i32,
         end_timestamp: Option<i64>,
     ) -> Result<Vec<Box<EmbeddingRecord>>, PullLogsError> {
-        let logs = self.logs.get(&collection_id).unwrap();
+        let end_timestamp = match end_timestamp {
+            Some(end_timestamp) => end_timestamp,
+            None => i64::MAX,
+        };
+
+        let logs = match self.logs.get(&collection_id) {
+            Some(logs) => logs,
+            None => return Ok(Vec::new()),
+        };
         let mut result = Vec::new();
         for i in offset..(offset + batch_size as i64) {
-            if i < logs.len() as i64 {
+            if i < logs.len() as i64 && logs[i as usize].log_id_ts <= end_timestamp {
                 result.push(logs[i as usize].record.clone());
             }
         }
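The last hunk changes InMemoryLog::read in two ways: an unknown collection now yields an empty result instead of panicking on unwrap(), and a missing end_timestamp is treated as "no upper bound" via i64::MAX. A simplified, self-contained sketch of those semantics follows, with plain i64 timestamps standing in for the crate's LogRecord type and all names illustrative only.

// Standalone sketch of the InMemoryLog::read semantics in the last hunk.
use std::collections::HashMap;

fn read_page(
    logs: &HashMap<String, Vec<i64>>, // collection_id -> log_id_ts values
    collection_id: &str,
    offset: i64,
    batch_size: i32,
    end_timestamp: Option<i64>,
) -> Vec<i64> {
    // A missing end_timestamp means "include everything".
    let end_timestamp = end_timestamp.unwrap_or(i64::MAX);
    // An unknown collection is an empty page, not a panic.
    let logs = match logs.get(collection_id) {
        Some(logs) => logs,
        None => return Vec::new(),
    };
    let mut result = Vec::new();
    for i in offset..(offset + batch_size as i64) {
        if i < logs.len() as i64 && logs[i as usize] <= end_timestamp {
            result.push(logs[i as usize]);
        }
    }
    result
}

fn main() {
    let mut logs = HashMap::new();
    logs.insert("c1".to_string(), vec![1, 2, 3]);
    assert_eq!(read_page(&logs, "c1", 0, 10, Some(2)), vec![1, 2]);
    assert!(read_page(&logs, "missing", 0, 10, None).is_empty());
}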