
feat: add cursor statements #5094

Merged: 12 commits, Dec 6, 2024
2 changes: 2 additions & 0 deletions Cargo.lock


173 changes: 173 additions & 0 deletions src/common/recordbatch/src/cursor.rs
@@ -0,0 +1,173 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use futures::StreamExt;
use tokio::sync::Mutex;

use crate::error::Result;
use crate::recordbatch::merge_record_batches;
use crate::{RecordBatch, SendableRecordBatchStream};

struct Inner {
stream: SendableRecordBatchStream,
current_row_index: usize,
current_batch: Option<RecordBatch>,
total_rows_in_current_batch: usize,
}

/// A cursor on RecordBatchStream that fetches data batch by batch
pub struct RecordBatchStreamCursor {
inner: Mutex<Inner>,
}

impl RecordBatchStreamCursor {
pub fn new(stream: SendableRecordBatchStream) -> Self {
Self {
inner: Mutex::new(Inner {
stream,
current_row_index: 0,
current_batch: None,
total_rows_in_current_batch: 0,
}),
}
}

/// Take up to `size` rows from the `RecordBatchStream` and build a new
/// `RecordBatch` from them. Fewer rows, or an empty batch, are returned
/// once the stream is exhausted.
pub async fn take(&self, size: usize) -> Result<RecordBatch> {
let mut remaining_rows_to_take = size;
let mut accumulated_rows = Vec::new();

let mut inner = self.inner.lock().await;

while remaining_rows_to_take > 0 {
// Ensure we have a current batch or fetch the next one
if inner.current_batch.is_none()
|| inner.current_row_index >= inner.total_rows_in_current_batch
{
match inner.stream.next().await {
Some(Ok(batch)) => {
inner.total_rows_in_current_batch = batch.num_rows();
inner.current_batch = Some(batch);
inner.current_row_index = 0;
}
Some(Err(e)) => return Err(e),
None => {
// Stream is exhausted
break;
}
}
}

// If we still have no batch after attempting to fetch
let current_batch = match &inner.current_batch {
Some(batch) => batch,
None => break,
};

// Calculate how many rows we can take from this batch
let rows_to_take_from_batch = remaining_rows_to_take
.min(inner.total_rows_in_current_batch - inner.current_row_index);

// Slice the current batch to get the desired rows
let taken_batch =
current_batch.slice(inner.current_row_index, rows_to_take_from_batch)?;

// Add the taken batch to accumulated rows
accumulated_rows.push(taken_batch);

// Update cursor and remaining rows
inner.current_row_index += rows_to_take_from_batch;
remaining_rows_to_take -= rows_to_take_from_batch;
}

// If no rows were accumulated, return empty
if accumulated_rows.is_empty() {
return Ok(RecordBatch::new_empty(inner.stream.schema()));
}

// If only one batch was accumulated, return it directly
if accumulated_rows.len() == 1 {
return Ok(accumulated_rows.remove(0));
}

// Merge multiple batches
merge_record_batches(inner.stream.schema(), &accumulated_rows)
}
}

#[cfg(test)]
mod tests {
use std::sync::Arc;

use datatypes::prelude::ConcreteDataType;
use datatypes::schema::{ColumnSchema, Schema};
use datatypes::vectors::StringVector;

use super::*;
use crate::RecordBatches;

#[tokio::test]
async fn test_cursor() {
let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
"a",
ConcreteDataType::string_datatype(),
false,
)]));

let rbs = RecordBatches::try_from_columns(
schema.clone(),
vec![Arc::new(StringVector::from(vec!["hello", "world"])) as _],
)
.unwrap();

let cursor = RecordBatchStreamCursor::new(rbs.as_stream());
let result_rb = cursor.take(1).await.expect("take from cursor failed");
assert_eq!(result_rb.num_rows(), 1);

let result_rb = cursor.take(1).await.expect("take from cursor failed");
assert_eq!(result_rb.num_rows(), 1);

let result_rb = cursor.take(1).await.expect("take from cursor failed");
assert_eq!(result_rb.num_rows(), 0);

let rb = RecordBatch::new(
schema.clone(),
vec![Arc::new(StringVector::from(vec!["hello", "world"])) as _],
)
.unwrap();
let rbs2 =
RecordBatches::try_new(schema.clone(), vec![rb.clone(), rb.clone(), rb]).unwrap();
let cursor = RecordBatchStreamCursor::new(rbs2.as_stream());
let result_rb = cursor.take(3).await.expect("take from cursor failed");
assert_eq!(result_rb.num_rows(), 3);
let result_rb = cursor.take(2).await.expect("take from cursor failed");
assert_eq!(result_rb.num_rows(), 2);
let result_rb = cursor.take(2).await.expect("take from cursor failed");
assert_eq!(result_rb.num_rows(), 1);
let result_rb = cursor.take(2).await.expect("take from cursor failed");
assert_eq!(result_rb.num_rows(), 0);

let rb = RecordBatch::new(
schema.clone(),
vec![Arc::new(StringVector::from(vec!["hello", "world"])) as _],
)
.unwrap();
let rbs3 =
RecordBatches::try_new(schema.clone(), vec![rb.clone(), rb.clone(), rb]).unwrap();
let cursor = RecordBatchStreamCursor::new(rbs3.as_stream());
let result_rb = cursor.take(10).await.expect("take from cursor failed");
assert_eq!(result_rb.num_rows(), 6);
}
}
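
Taken together, the cursor turns a push-style `RecordBatchStream` into pull-style pagination: callers repeatedly `take` a fixed number of rows and stop at the first empty batch. A minimal usage sketch under that contract (the `drain_in_pages` helper is hypothetical and the import paths are assumptions for illustration, not part of this diff):

use common_recordbatch::cursor::RecordBatchStreamCursor;
use common_recordbatch::error::Result;
use common_recordbatch::SendableRecordBatchStream;

// Hypothetical helper: drain a stream in fixed-size pages and count the rows.
async fn drain_in_pages(stream: SendableRecordBatchStream, page_size: usize) -> Result<usize> {
    let cursor = RecordBatchStreamCursor::new(stream);
    let mut total = 0;
    loop {
        // `take` yields up to `page_size` rows; a short page is possible at
        // the tail of the stream, and an empty batch marks exhaustion.
        let page = cursor.take(page_size).await?;
        if page.num_rows() == 0 {
            break;
        }
        total += page.num_rows();
    }
    Ok(total)
}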
10 changes: 9 additions & 1 deletion src/common/recordbatch/src/error.rs
@@ -168,6 +168,13 @@ pub enum Error {
#[snafu(source)]
error: tokio::time::error::Elapsed,
},
#[snafu(display("RecordBatch slice index overflow: {visit_index} > {size}"))]
RecordBatchSliceIndexOverflow {
#[snafu(implicit)]
location: Location,
size: usize,
visit_index: usize,
},
}

impl ErrorExt for Error {
@@ -182,7 +189,8 @@ impl ErrorExt for Error {
| Error::Format { .. }
| Error::ToArrowScalar { .. }
| Error::ProjectArrowRecordBatch { .. }
| Error::PhysicalExpr { .. } => StatusCode::Internal,
| Error::PhysicalExpr { .. }
| Error::RecordBatchSliceIndexOverflow { .. } => StatusCode::Internal,

Error::PollStream { .. } => StatusCode::EngineExecuteQuery,

1 change: 1 addition & 0 deletions src/common/recordbatch/src/lib.rs
@@ -15,6 +15,7 @@
#![feature(never_type)]

pub mod adapter;
pub mod cursor;
pub mod error;
pub mod filter;
mod recordbatch;
121 changes: 120 additions & 1 deletion src/common/recordbatch/src/recordbatch.rs
@@ -23,7 +23,7 @@ use datatypes::value::Value;
use datatypes::vectors::{Helper, VectorRef};
use serde::ser::{Error, SerializeStruct};
use serde::{Serialize, Serializer};
use snafu::{OptionExt, ResultExt};
use snafu::{ensure, OptionExt, ResultExt};

use crate::error::{
self, CastVectorSnafu, ColumnNotExistsSnafu, DataTypesSnafu, ProjectArrowRecordBatchSnafu,
@@ -194,6 +194,19 @@ impl RecordBatch {
.map(|t| t.to_string())
.unwrap_or("failed to pretty display a record batch".to_string())
}

/// Return a sliced `RecordBatch` that starts at `offset` and contains `len` rows.
pub fn slice(&self, offset: usize, len: usize) -> Result<RecordBatch> {
ensure!(
offset + len <= self.num_rows(),
error::RecordBatchSliceIndexOverflowSnafu {
size: self.num_rows(),
visit_index: offset + len
}
);
let columns = self.columns.iter().map(|vector| vector.slice(offset, len));
RecordBatch::new(self.schema.clone(), columns)
}
}

impl Serialize for RecordBatch {
@@ -256,6 +269,36 @@ impl Iterator for RecordBatchRowIterator<'_> {
}
}

/// Merge multiple record batches into a single one.
pub fn merge_record_batches(schema: SchemaRef, batches: &[RecordBatch]) -> Result<RecordBatch> {
let batches_len = batches.len();
if batches_len == 0 {
return Ok(RecordBatch::new_empty(schema));
}

let n_rows = batches.iter().map(|b| b.num_rows()).sum();
let n_columns = schema.num_columns();
// Collect arrays from each batch
let mut merged_columns = Vec::with_capacity(n_columns);

for col_idx in 0..n_columns {
let mut acc = schema.column_schemas()[col_idx]
.data_type
.create_mutable_vector(n_rows);

for batch in batches {
let column = batch.column(col_idx);
acc.extend_slice_of(column.as_ref(), 0, column.len())
.context(error::DataTypesSnafu)?;
}

merged_columns.push(acc.to_vector());
}

// Create a new RecordBatch with merged columns
RecordBatch::new(schema, merged_columns)
}

#[cfg(test)]
mod tests {
use std::sync::Arc;
@@ -375,4 +418,80 @@ mod tests {

assert!(record_batch_iter.next().is_none());
}

#[test]
fn test_record_batch_slice() {
let column_schemas = vec![
ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
];
let schema = Arc::new(Schema::new(column_schemas));
let columns: Vec<VectorRef> = vec![
Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
Arc::new(StringVector::from(vec![
None,
Some("hello"),
Some("greptime"),
None,
])),
];
let recordbatch = RecordBatch::new(schema, columns).unwrap();
let recordbatch = recordbatch.slice(1, 2).expect("recordbatch slice");
let mut record_batch_iter = recordbatch.rows();
assert_eq!(
vec![Value::UInt32(2), Value::String("hello".into())],
record_batch_iter
.next()
.unwrap()
.into_iter()
.collect::<Vec<Value>>()
);

assert_eq!(
vec![Value::UInt32(3), Value::String("greptime".into())],
record_batch_iter
.next()
.unwrap()
.into_iter()
.collect::<Vec<Value>>()
);

assert!(record_batch_iter.next().is_none());

assert!(recordbatch.slice(1, 5).is_err());
}

#[test]
fn test_merge_record_batch() {
let column_schemas = vec![
ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
];
let schema = Arc::new(Schema::new(column_schemas));
let columns: Vec<VectorRef> = vec![
Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
Arc::new(StringVector::from(vec![
None,
Some("hello"),
Some("greptime"),
None,
])),
];
let recordbatch = RecordBatch::new(schema.clone(), columns).unwrap();

let columns: Vec<VectorRef> = vec![
Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
Arc::new(StringVector::from(vec![
None,
Some("hello"),
Some("greptime"),
None,
])),
];
let recordbatch2 = RecordBatch::new(schema.clone(), columns).unwrap();

let merged = merge_record_batches(schema.clone(), &[recordbatch, recordbatch2])
.expect("merge recordbatch");
assert_eq!(merged.num_rows(), 8);
}
}
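
For intuition on how the two new pieces compose: `slice` carves a contiguous window out of a batch and fails exactly when `offset + len` runs past the last row, while `merge_record_batches` stitches windows back together. A small sketch, assuming a 4-row `recordbatch` over the two-column `schema` built as in `test_record_batch_slice` above:

let head = recordbatch.slice(0, 2).unwrap(); // rows 0..2
let tail = recordbatch.slice(2, 2).unwrap(); // rows 2..4
// 3 + 2 = 5 > 4 rows, so this yields RecordBatchSliceIndexOverflow.
assert!(recordbatch.slice(3, 2).is_err());
// Stitching the two windows back together restores the original row count.
let merged = merge_record_batches(schema.clone(), &[head, tail]).unwrap();
assert_eq!(merged.num_rows(), 4);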
8 changes: 7 additions & 1 deletion src/frontend/src/instance.rs
@@ -487,7 +487,11 @@ pub fn check_permission(
// TODO(dennis): add a hook for admin commands.
Statement::Admin(_) => {}
// These are executed by query engine, and will be checked there.
Statement::Query(_) | Statement::Explain(_) | Statement::Tql(_) | Statement::Delete(_) => {}
Statement::Query(_)
| Statement::Explain(_)
| Statement::Tql(_)
| Statement::Delete(_)
| Statement::DeclareCursor(_) => {}
// database ops won't be checked
Statement::CreateDatabase(_)
| Statement::ShowDatabases(_)
@@ -580,6 +584,8 @@ pub fn check_permission(
Statement::TruncateTable(stmt) => {
validate_param(stmt.table_name(), query_ctx)?;
}
// Cursor operations are always allowed once the cursor has been created.
Statement::FetchCursor(_) | Statement::CloseCursor(_) => {}
}
Ok(())
}