diff --git a/src/common/recordbatch/src/cursor.rs b/src/common/recordbatch/src/cursor.rs new file mode 100644 index 000000000000..93fefae7c441 --- /dev/null +++ b/src/common/recordbatch/src/cursor.rs @@ -0,0 +1,107 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use futures::StreamExt; +use tokio::sync::Mutex; + +use crate::error::Result; +use crate::recordbatch::merge_record_batches; +use crate::{RecordBatch, SendableRecordBatchStream}; + +struct Inner { + stream: SendableRecordBatchStream, + current_row_index: usize, + current_batch: Option, + total_rows_in_current_batch: usize, +} + +pub struct RecordBatchStreamCursor { + inner: Mutex, +} + +impl RecordBatchStreamCursor { + pub fn new(stream: SendableRecordBatchStream) -> RecordBatchStreamCursor { + Self { + inner: Mutex::new(Inner { + stream, + current_row_index: 0, + current_batch: None, + total_rows_in_current_batch: 0, + }), + } + } + + /// Take up to `size` rows from the stream and return them as a single + /// `RecordBatch`; the result is shorter (or empty) if the stream runs out. 
+ pub async fn take(&self, size: usize) -> Result { + let mut remaining_rows_to_take = size; + let mut accumulated_rows = Vec::new(); + + let mut inner = self.inner.lock().await; + + while remaining_rows_to_take > 0 { + // Fetch the next batch when there is no current batch or it is fully consumed + if inner.current_batch.is_none() + || inner.current_row_index >= inner.total_rows_in_current_batch + { + match inner.stream.next().await { + Some(Ok(batch)) => { + inner.total_rows_in_current_batch = batch.num_rows(); + inner.current_batch = Some(batch); + inner.current_row_index = 0; + } + Some(Err(e)) => return Err(e), + None => { + // Stream is exhausted + break; + } + } + } + + // If we still have no batch after attempting to fetch + let current_batch = match &inner.current_batch { + Some(batch) => batch, + None => break, + }; + + // Calculate how many rows we can take from this batch + let rows_to_take_from_batch = remaining_rows_to_take + .min(inner.total_rows_in_current_batch - inner.current_row_index); + + // Slice the current batch to get the desired rows + let taken_batch = + current_batch.slice(inner.current_row_index, rows_to_take_from_batch)?; + + // Add the taken batch to accumulated rows + accumulated_rows.push(taken_batch); + + // Update cursor and remaining rows + inner.current_row_index += rows_to_take_from_batch; + remaining_rows_to_take -= rows_to_take_from_batch; + } + + // If no rows were accumulated, return an empty batch with the stream's schema + if accumulated_rows.is_empty() { + return Ok(RecordBatch::new_empty(inner.stream.schema())); + } + + // If only one batch was accumulated, return it directly + if accumulated_rows.len() == 1 { + return Ok(accumulated_rows.remove(0)); + } + + // Merge multiple batches + merge_record_batches(inner.stream.schema(), &accumulated_rows) + } +} diff --git a/src/common/recordbatch/src/lib.rs b/src/common/recordbatch/src/lib.rs index 0016e02e94ed..257b6f09732a 100644 --- a/src/common/recordbatch/src/lib.rs +++ b/src/common/recordbatch/src/lib.rs @@ -15,6 +15,7 @@ 
#![feature(never_type)] pub mod adapter; +pub mod cursor; pub mod error; pub mod filter; mod recordbatch; diff --git a/src/common/recordbatch/src/recordbatch.rs b/src/common/recordbatch/src/recordbatch.rs index 71f7f60685e5..2c3d4c09a1c4 100644 --- a/src/common/recordbatch/src/recordbatch.rs +++ b/src/common/recordbatch/src/recordbatch.rs @@ -194,6 +194,12 @@ impl RecordBatch { .map(|t| t.to_string()) .unwrap_or("failed to pretty display a record batch".to_string()) } + + /// Return a new record batch made by slicing every column with `(offset, len)`, i.e. `len` rows starting at row `offset`. + pub fn slice(&self, offset: usize, len: usize) -> Result { + let columns = self.columns.iter().map(|vector| vector.slice(offset, len)); + RecordBatch::new(self.schema.clone(), columns) + } } impl Serialize for RecordBatch { @@ -256,6 +262,36 @@ impl Iterator for RecordBatchRowIterator<'_> { } } +/// Merge multiple record batches into one batch with `schema`; an empty input yields an empty batch. +pub fn merge_record_batches(schema: SchemaRef, batches: &[RecordBatch]) -> Result { + let batches_len = batches.len(); + if batches_len == 0 { + return Ok(RecordBatch::new_empty(schema)); + } + + let n_rows = batches.iter().map(|b| b.num_rows()).sum(); + let n_columns = schema.num_columns(); + // Holds one merged vector per schema column + let mut merged_columns = Vec::with_capacity(n_columns); + + for col_idx in 0..n_columns { + let mut acc = schema.column_schemas()[col_idx] + .data_type + .create_mutable_vector(n_rows); + + for batch in batches { + let column = batch.column(col_idx); + acc.extend_slice_of(column.as_ref(), 0, column.len()) + .context(error::DataTypesSnafu)?; + } + + merged_columns.push(acc.to_vector()); + } + + // Create a new RecordBatch with merged columns + RecordBatch::new(schema, merged_columns) +} + #[cfg(test)] mod tests { use std::sync::Arc;