Skip to content

Commit

Permalink
feat: implement RecordBatchStreamCursor
Browse files Browse the repository at this point in the history
  • Loading branch information
sunng87 committed Dec 4, 2024
1 parent a1d3be8 commit 34ede04
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 0 deletions.
107 changes: 107 additions & 0 deletions src/common/recordbatch/src/cursor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use futures::StreamExt;
use tokio::sync::Mutex;

use crate::error::Result;
use crate::recordbatch::merge_record_batches;
use crate::{RecordBatch, SendableRecordBatchStream};

/// Mutable cursor state, kept behind the mutex in `RecordBatchStreamCursor`.
struct Inner {
    /// Source of record batches; polled whenever the current batch is used up.
    stream: SendableRecordBatchStream,
    /// Index of the next row to hand out from `current_batch`.
    current_row_index: usize,
    /// The batch most recently pulled from `stream`; `None` until the first fetch.
    current_batch: Option<RecordBatch>,
    /// Row count of `current_batch`, recorded when the batch is fetched.
    total_rows_in_current_batch: usize,
}

/// A cursor over a record batch stream that hands out rows in caller-sized
/// chunks, independent of how the stream itself splits rows into batches.
pub struct RecordBatchStreamCursor {
    /// Stream and read position, guarded by an async mutex so the cursor can
    /// be advanced through a shared reference.
    inner: Mutex<Inner>,
}

impl RecordBatchStreamCursor {
    /// Creates a cursor over `stream`.
    ///
    /// No batch is loaded up front; the first batch is pulled from the stream
    /// by the first call to `take` that asks for at least one row.
    pub fn new(stream: SendableRecordBatchStream) -> RecordBatchStreamCursor {
        Self {
            inner: Mutex::new(Inner {
                stream,
                current_row_index: 0,
                current_batch: None,
                total_rows_in_current_batch: 0,
            }),
        }
    }

    /// Takes up to `size` rows from the stream, starting at the cursor's
    /// current position, and returns them as a single `RecordBatch`.
    ///
    /// Rows left over in the batch currently held by the cursor are served
    /// first; further batches are pulled from the stream only as needed. The
    /// returned batch holds fewer than `size` rows when the stream ends
    /// early, and is an empty batch (built from the stream's schema) when no
    /// rows could be taken at all, including when `size` is 0.
    ///
    /// # Errors
    ///
    /// Returns the error yielded by the stream, or any error raised while
    /// slicing or merging batches. Rows consumed earlier in the same call are
    /// dropped in that case, but the cursor position has already advanced
    /// past them.
    pub async fn take(&self, size: usize) -> Result<RecordBatch> {
        let mut guard = self.inner.lock().await;
        // Reborrow as a plain `&mut Inner` so individual fields can be
        // borrowed independently of each other below.
        let inner = &mut *guard;

        let mut remaining = size;
        let mut pieces = Vec::new();

        while remaining > 0 {
            // Pull the next batch when nothing is loaded yet or the loaded
            // batch has been fully consumed.
            let needs_fetch = inner.current_batch.is_none()
                || inner.current_row_index >= inner.total_rows_in_current_batch;
            if needs_fetch {
                match inner.stream.next().await {
                    Some(Ok(batch)) => {
                        inner.total_rows_in_current_batch = batch.num_rows();
                        inner.current_row_index = 0;
                        inner.current_batch = Some(batch);
                    }
                    Some(Err(e)) => return Err(e),
                    // Stream is exhausted: return whatever was gathered.
                    None => break,
                }
            }

            let batch = match inner.current_batch.as_ref() {
                Some(batch) => batch,
                None => break,
            };

            // Rows still unread in the current batch, capped by the request.
            let available = inner.total_rows_in_current_batch - inner.current_row_index;
            let count = remaining.min(available);
            if count == 0 {
                // The stream produced a zero-row batch. Skip it instead of
                // collecting an empty slice, which would otherwise force a
                // needless merge further down.
                continue;
            }

            pieces.push(batch.slice(inner.current_row_index, count)?);

            inner.current_row_index += count;
            remaining -= count;
        }

        match pieces.len() {
            // Nothing was taken: hand back an empty batch with the stream's schema.
            0 => Ok(RecordBatch::new_empty(inner.stream.schema())),
            // A single slice already is the result; no copy required.
            1 => Ok(pieces.swap_remove(0)),
            // The request spanned several batches: concatenate the slices.
            _ => merge_record_batches(inner.stream.schema(), &pieces),
        }
    }
}
1 change: 1 addition & 0 deletions src/common/recordbatch/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#![feature(never_type)]

pub mod adapter;
pub mod cursor;
pub mod error;
pub mod filter;
mod recordbatch;
Expand Down
36 changes: 36 additions & 0 deletions src/common/recordbatch/src/recordbatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,12 @@ impl RecordBatch {
.map(|t| t.to_string())
.unwrap_or("failed to pretty display a record batch".to_string())
}

/// Returns a new record batch holding `len` rows of this batch, starting at
/// row `offset`. The schema is shared with the original batch.
///
/// NOTE(review): `offset + len` is not validated here; what happens when the
/// range exceeds the row count depends on the vectors' `slice` — confirm.
pub fn slice(&self, offset: usize, len: usize) -> Result<RecordBatch> {
    let schema = self.schema.clone();
    let sliced_columns: Vec<_> = self
        .columns
        .iter()
        .map(|column| column.slice(offset, len))
        .collect();
    RecordBatch::new(schema, sliced_columns)
}
}

impl Serialize for RecordBatch {
Expand Down Expand Up @@ -256,6 +262,36 @@ impl Iterator for RecordBatchRowIterator<'_> {
}
}

/// Concatenates `batches` row-wise into a single `RecordBatch` with the given
/// `schema`.
///
/// Returns an empty batch when `batches` is empty. Each output column is built
/// by appending the matching column of every input batch, in order; a failure
/// while appending is reported as a data-types error.
pub fn merge_record_batches(schema: SchemaRef, batches: &[RecordBatch]) -> Result<RecordBatch> {
    if batches.is_empty() {
        return Ok(RecordBatch::new_empty(schema));
    }

    // Total row count, used to size every column builder once up front.
    let total_rows: usize = batches.iter().map(|batch| batch.num_rows()).sum();

    let merged_columns = (0..schema.num_columns())
        .map(|idx| -> Result<_> {
            let mut builder = schema.column_schemas()[idx]
                .data_type
                .create_mutable_vector(total_rows);

            for batch in batches {
                let column = batch.column(idx);
                builder
                    .extend_slice_of(column.as_ref(), 0, column.len())
                    .context(error::DataTypesSnafu)?;
            }

            Ok(builder.to_vector())
        })
        .collect::<Result<Vec<_>>>()?;

    RecordBatch::new(schema, merged_columns)
}

#[cfg(test)]
mod tests {
use std::sync::Arc;
Expand Down

0 comments on commit 34ede04

Please sign in to comment.