Skip to content

Commit

Permalink
[ENH] Implement the PullLog operator (chroma-core#1906)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
	 - `InMemoryLog::read` no longer panics when a collection has no logs; it returns an empty result instead.
 - New functionality
	 - This PR implements the PullLog operator

## Test plan
*How are these changes tested?*

- [ ] Tests pass locally with `pytest` for python, `yarn test` for js,
`cargo test` for rust

## Documentation Changes
*Are all docstrings for user-facing APIs updated if required? Do we need
to make documentation changes in the [docs
repository](https://github.com/chroma-core/docs)?*
  • Loading branch information
Ishiihara authored Mar 21, 2024
1 parent 3985032 commit e6f3aec
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 14 deletions.
171 changes: 161 additions & 10 deletions rust/worker/src/execution/operators/pull_log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,15 @@ impl PullLogsOperator {
/// * `collection_id` - The collection id to read logs from.
/// * `offset` - The offset to start reading logs from.
/// * `batch_size` - The number of log entries to read.
/// * `num_records` - The maximum number of records to read.
/// * `end_timestamp` - The end timestamp to read logs until.
#[derive(Debug)]
pub struct PullLogsInput {
    // The collection whose log entries are read.
    collection_id: Uuid,
    // Log offset to start reading from.
    offset: i64,
    // Number of log entries requested per read call.
    batch_size: i32,
    // Optional cap on the total number of records returned.
    num_records: Option<i32>,
    // Optional inclusive upper bound on log timestamps (InMemoryLog::read keeps
    // records with log_id_ts <= end_timestamp).
    end_timestamp: Option<i64>,
}

impl PullLogsInput {
Expand All @@ -39,11 +43,21 @@ impl PullLogsInput {
/// * `collection_id` - The collection id to read logs from.
/// * `offset` - The offset to start reading logs from.
/// * `batch_size` - The number of log entries to read.
pub fn new(collection_id: Uuid, offset: i64, batch_size: i32) -> Self {
/// * `num_records` - The maximum number of records to read.
/// * `end_timestamp` - The end timestamp to read logs until.
pub fn new(
collection_id: Uuid,
offset: i64,
batch_size: i32,
num_records: Option<i32>,
end_timestamp: Option<i64>,
) -> Self {
PullLogsInput {
collection_id,
offset,
batch_size,
num_records,
end_timestamp,
}
}
}
Expand Down Expand Up @@ -75,18 +89,155 @@ pub type PullLogsResult = Result<PullLogsOutput, PullLogsError>;
#[async_trait]
impl Operator<PullLogsInput, PullLogsOutput> for PullLogsOperator {
    type Error = PullLogsError;

    /// Pull log records for the input collection in batches of
    /// `input.batch_size`, until either the log is exhausted, or
    /// `input.num_records` (when set) records have been collected.
    /// When `input.end_timestamp` is set, only records at or before that
    /// timestamp are returned by the underlying log service.
    async fn run(&self, input: &PullLogsInput) -> PullLogsResult {
        // We expect the log client to be cheaply cloneable; we clone it since
        // `read` needs a mutable receiver. Not necessarily the best, but it
        // works for our needs.
        let mut client = self.client.clone();
        let batch_size = input.batch_size;
        let mut offset = input.offset;
        let mut result = Vec::new();
        loop {
            // Pull one batch, bounded by the optional end timestamp.
            let mut logs = client
                .read(
                    input.collection_id.to_string(),
                    offset,
                    batch_size,
                    input.end_timestamp,
                )
                .await?;

            // An empty batch means the log is exhausted.
            if logs.is_empty() {
                break;
            }

            offset += batch_size as i64;
            result.append(&mut logs);

            // Stop once at least `num_records` entries have been collected.
            // (`result.len()` is exactly the number of records read so far,
            // so no separate counter is needed.)
            if let Some(num_records) = input.num_records {
                if result.len() >= num_records as usize {
                    break;
                }
            }
        }
        // The final batch may overshoot the requested limit; trim the excess.
        // `truncate` is a no-op when the result is already short enough.
        if let Some(num_records) = input.num_records {
            result.truncate(num_records as usize);
        }
        Ok(PullLogsOutput::new(result))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::log::log::InMemoryLog;
    use crate::log::log::LogRecord;
    use crate::types::EmbeddingRecord;
    use crate::types::Operation;
    use num_bigint::BigInt;
    use std::str::FromStr;
    use uuid::Uuid;

    #[tokio::test]
    async fn test_pull_logs() {
        let mut log = Box::new(InMemoryLog::new());

        let collection_uuid_1 = Uuid::from_str("00000000-0000-0000-0000-000000000001").unwrap();
        let collection_id_1 = collection_uuid_1.to_string();

        // Seed the in-memory log with two records whose log ids and
        // timestamps are 1 and 2 respectively.
        for i in 1..=2 {
            log.add_log(
                collection_id_1.clone(),
                Box::new(LogRecord {
                    collection_id: collection_id_1.clone(),
                    log_id: i,
                    log_id_ts: i,
                    record: Box::new(EmbeddingRecord {
                        id: format!("embedding_id_{}", i),
                        seq_id: BigInt::from(i),
                        embedding: None,
                        encoding: None,
                        metadata: None,
                        operation: Operation::Add,
                        collection_id: collection_uuid_1,
                    }),
                }),
            );
        }

        let operator = PullLogsOperator::new(log);

        // Each case: (batch_size, num_records, end_timestamp, expected count).
        let cases: Vec<(i32, Option<i32>, Option<i64>, usize)> = vec![
            (1, None, None, 2),         // pull everything in small batches
            (100, None, None, 2),       // pull everything in one large batch
            (1, Some(1), None, 1),      // record limit
            (1, None, Some(1), 1),      // timestamp cutoff before record 2
            (1, None, Some(2), 2),      // timestamp cutoff covering both records
            (1, Some(1), Some(2), 1),   // limit combined with cutoff
            (100, Some(1), None, 1),    // limit with a large batch size
            (100, None, Some(1), 1),    // cutoff with a large batch size
            (100, None, Some(2), 2),    // full cutoff with a large batch size
            (100, Some(1), Some(2), 1), // limit + cutoff with a large batch size
        ];
        for (batch_size, num_records, end_timestamp, expected) in cases {
            let input =
                PullLogsInput::new(collection_uuid_1, 0, batch_size, num_records, end_timestamp);
            let output = operator.run(&input).await.unwrap();
            assert_eq!(output.logs().len(), expected);
        }
    }
}
16 changes: 14 additions & 2 deletions rust/worker/src/execution/orchestration/hnsw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use super::super::operator::{wrap, TaskMessage};
use super::super::operators::pull_log::{PullLogsInput, PullLogsOperator, PullLogsOutput};
use crate::errors::ChromaError;
use crate::execution::operators::pull_log::PullLogsResult;
use crate::log::log::PullLogsError;
use crate::sysdb::sysdb::SysDb;
use crate::system::System;
use crate::types::VectorQueryResult;
Expand All @@ -13,6 +12,8 @@ use crate::{
use async_trait::async_trait;
use num_bigint::BigInt;
use std::fmt::Debug;
use std::fmt::Formatter;
use std::time::{SystemTime, UNIX_EPOCH};
use uuid::Uuid;

/** The state of the orchestrator.
Expand Down Expand Up @@ -114,7 +115,16 @@ impl HnswQueryOrchestrator {
return;
}
};
let input = PullLogsInput::new(collection_id, 0, 100);
let end_timestamp = SystemTime::now().duration_since(UNIX_EPOCH);
let end_timestamp = match end_timestamp {
// TODO: change protobuf definition to use u64 instead of i64
Ok(end_timestamp) => end_timestamp.as_secs() as i64,
Err(e) => {
// Log an error and reply + return
return;
}
};
let input = PullLogsInput::new(collection_id, 0, 100, None, Some(end_timestamp));
let task = wrap(operator, input, self_address);
match self.dispatcher.send(task).await {
Ok(_) => (),
Expand Down Expand Up @@ -175,6 +185,8 @@ impl Handler<PullLogsResult> for HnswQueryOrchestrator {

match message {
Ok(logs) => {
// TODO: remove this after debugging
println!("Received logs: {:?}", logs);
let _ = result_channel.send(Ok(vec![vec![VectorQueryResult {
id: "abc".to_string(),
seq_id: BigInt::from(0),
Expand Down
12 changes: 10 additions & 2 deletions rust/worker/src/log/log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,10 +269,18 @@ impl Log for InMemoryLog {
batch_size: i32,
end_timestamp: Option<i64>,
) -> Result<Vec<Box<EmbeddingRecord>>, PullLogsError> {
let logs = self.logs.get(&collection_id).unwrap();
let end_timestamp = match end_timestamp {
Some(end_timestamp) => end_timestamp,
None => i64::MAX,
};

let logs = match self.logs.get(&collection_id) {
Some(logs) => logs,
None => return Ok(Vec::new()),
};
let mut result = Vec::new();
for i in offset..(offset + batch_size as i64) {
if i < logs.len() as i64 {
if i < logs.len() as i64 && logs[i as usize].log_id_ts <= end_timestamp {
result.push(logs[i as usize].record.clone());
}
}
Expand Down

0 comments on commit e6f3aec

Please sign in to comment.