Skip to content

Commit

Permalink
feat: add cursor statements (#5094)
Browse files Browse the repository at this point in the history
* feat: add sql parsers for cursor operations

* feat: cursor operator

* feat: implement RecordBatchStreamCursor

* feat: implement cursor storage and execution

* test: add tests

* chore: update docstring

* feat: add a temporary sql rewrite for cast in limit

this issue is described in #5097

* test: add more sql for cursor integration test

* feat: reject non-select query for cursor statement

* refactor: address review issues

* test: add empty result case

* feat: address review comments
  • Loading branch information
sunng87 authored and MichaelScofield committed Dec 27, 2024
1 parent 0df75ac commit ee9808b
Show file tree
Hide file tree
Showing 21 changed files with 786 additions and 5 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

173 changes: 173 additions & 0 deletions src/common/recordbatch/src/cursor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use futures::StreamExt;
use tokio::sync::Mutex;

use crate::error::Result;
use crate::recordbatch::merge_record_batches;
use crate::{RecordBatch, SendableRecordBatchStream};

struct Inner {
stream: SendableRecordBatchStream,
current_row_index: usize,
current_batch: Option<RecordBatch>,
total_rows_in_current_batch: usize,
}

/// A cursor on RecordBatchStream that fetches data batch by batch
pub struct RecordBatchStreamCursor {
inner: Mutex<Inner>,
}

impl RecordBatchStreamCursor {
pub fn new(stream: SendableRecordBatchStream) -> RecordBatchStreamCursor {
Self {
inner: Mutex::new(Inner {
stream,
current_row_index: 0,
current_batch: None,
total_rows_in_current_batch: 0,
}),
}
}

/// Take `size` of row from the `RecordBatchStream` and create a new
/// `RecordBatch` for these rows.
pub async fn take(&self, size: usize) -> Result<RecordBatch> {
let mut remaining_rows_to_take = size;
let mut accumulated_rows = Vec::new();

let mut inner = self.inner.lock().await;

while remaining_rows_to_take > 0 {
// Ensure we have a current batch or fetch the next one
if inner.current_batch.is_none()
|| inner.current_row_index >= inner.total_rows_in_current_batch
{
match inner.stream.next().await {
Some(Ok(batch)) => {
inner.total_rows_in_current_batch = batch.num_rows();
inner.current_batch = Some(batch);
inner.current_row_index = 0;
}
Some(Err(e)) => return Err(e),
None => {
// Stream is exhausted
break;
}
}
}

// If we still have no batch after attempting to fetch
let current_batch = match &inner.current_batch {
Some(batch) => batch,
None => break,
};

// Calculate how many rows we can take from this batch
let rows_to_take_from_batch = remaining_rows_to_take
.min(inner.total_rows_in_current_batch - inner.current_row_index);

// Slice the current batch to get the desired rows
let taken_batch =
current_batch.slice(inner.current_row_index, rows_to_take_from_batch)?;

// Add the taken batch to accumulated rows
accumulated_rows.push(taken_batch);

// Update cursor and remaining rows
inner.current_row_index += rows_to_take_from_batch;
remaining_rows_to_take -= rows_to_take_from_batch;
}

// If no rows were accumulated, return empty
if accumulated_rows.is_empty() {
return Ok(RecordBatch::new_empty(inner.stream.schema()));
}

// If only one batch was accumulated, return it directly
if accumulated_rows.len() == 1 {
return Ok(accumulated_rows.remove(0));
}

// Merge multiple batches
merge_record_batches(inner.stream.schema(), &accumulated_rows)
}
}

#[cfg(test)]
mod tests {
use std::sync::Arc;

use datatypes::prelude::ConcreteDataType;
use datatypes::schema::{ColumnSchema, Schema};
use datatypes::vectors::StringVector;

use super::*;
use crate::RecordBatches;

#[tokio::test]
async fn test_cursor() {
let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
"a",
ConcreteDataType::string_datatype(),
false,
)]));

let rbs = RecordBatches::try_from_columns(
schema.clone(),
vec![Arc::new(StringVector::from(vec!["hello", "world"])) as _],
)
.unwrap();

let cursor = RecordBatchStreamCursor::new(rbs.as_stream());
let result_rb = cursor.take(1).await.expect("take from cursor failed");
assert_eq!(result_rb.num_rows(), 1);

let result_rb = cursor.take(1).await.expect("take from cursor failed");
assert_eq!(result_rb.num_rows(), 1);

let result_rb = cursor.take(1).await.expect("take from cursor failed");
assert_eq!(result_rb.num_rows(), 0);

let rb = RecordBatch::new(
schema.clone(),
vec![Arc::new(StringVector::from(vec!["hello", "world"])) as _],
)
.unwrap();
let rbs2 =
RecordBatches::try_new(schema.clone(), vec![rb.clone(), rb.clone(), rb]).unwrap();
let cursor = RecordBatchStreamCursor::new(rbs2.as_stream());
let result_rb = cursor.take(3).await.expect("take from cursor failed");
assert_eq!(result_rb.num_rows(), 3);
let result_rb = cursor.take(2).await.expect("take from cursor failed");
assert_eq!(result_rb.num_rows(), 2);
let result_rb = cursor.take(2).await.expect("take from cursor failed");
assert_eq!(result_rb.num_rows(), 1);
let result_rb = cursor.take(2).await.expect("take from cursor failed");
assert_eq!(result_rb.num_rows(), 0);

let rb = RecordBatch::new(
schema.clone(),
vec![Arc::new(StringVector::from(vec!["hello", "world"])) as _],
)
.unwrap();
let rbs3 =
RecordBatches::try_new(schema.clone(), vec![rb.clone(), rb.clone(), rb]).unwrap();
let cursor = RecordBatchStreamCursor::new(rbs3.as_stream());
let result_rb = cursor.take(10).await.expect("take from cursor failed");
assert_eq!(result_rb.num_rows(), 6);
}
}
10 changes: 9 additions & 1 deletion src/common/recordbatch/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,13 @@ pub enum Error {
#[snafu(source)]
error: tokio::time::error::Elapsed,
},
#[snafu(display("RecordBatch slice index overflow: {visit_index} > {size}"))]
RecordBatchSliceIndexOverflow {
#[snafu(implicit)]
location: Location,
size: usize,
visit_index: usize,
},
}

impl ErrorExt for Error {
Expand All @@ -182,7 +189,8 @@ impl ErrorExt for Error {
| Error::Format { .. }
| Error::ToArrowScalar { .. }
| Error::ProjectArrowRecordBatch { .. }
| Error::PhysicalExpr { .. } => StatusCode::Internal,
| Error::PhysicalExpr { .. }
| Error::RecordBatchSliceIndexOverflow { .. } => StatusCode::Internal,

Error::PollStream { .. } => StatusCode::EngineExecuteQuery,

Expand Down
1 change: 1 addition & 0 deletions src/common/recordbatch/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#![feature(never_type)]

pub mod adapter;
pub mod cursor;
pub mod error;
pub mod filter;
mod recordbatch;
Expand Down
121 changes: 120 additions & 1 deletion src/common/recordbatch/src/recordbatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use datatypes::value::Value;
use datatypes::vectors::{Helper, VectorRef};
use serde::ser::{Error, SerializeStruct};
use serde::{Serialize, Serializer};
use snafu::{OptionExt, ResultExt};
use snafu::{ensure, OptionExt, ResultExt};

use crate::error::{
self, CastVectorSnafu, ColumnNotExistsSnafu, DataTypesSnafu, ProjectArrowRecordBatchSnafu,
Expand Down Expand Up @@ -194,6 +194,19 @@ impl RecordBatch {
.map(|t| t.to_string())
.unwrap_or("failed to pretty display a record batch".to_string())
}

/// Return a slice record batch starts from offset, with len rows
pub fn slice(&self, offset: usize, len: usize) -> Result<RecordBatch> {
ensure!(
offset + len <= self.num_rows(),
error::RecordBatchSliceIndexOverflowSnafu {
size: self.num_rows(),
visit_index: offset + len
}
);
let columns = self.columns.iter().map(|vector| vector.slice(offset, len));
RecordBatch::new(self.schema.clone(), columns)
}
}

impl Serialize for RecordBatch {
Expand Down Expand Up @@ -256,6 +269,36 @@ impl Iterator for RecordBatchRowIterator<'_> {
}
}

/// merge multiple recordbatch into a single
pub fn merge_record_batches(schema: SchemaRef, batches: &[RecordBatch]) -> Result<RecordBatch> {
let batches_len = batches.len();
if batches_len == 0 {
return Ok(RecordBatch::new_empty(schema));
}

let n_rows = batches.iter().map(|b| b.num_rows()).sum();
let n_columns = schema.num_columns();
// Collect arrays from each batch
let mut merged_columns = Vec::with_capacity(n_columns);

for col_idx in 0..n_columns {
let mut acc = schema.column_schemas()[col_idx]
.data_type
.create_mutable_vector(n_rows);

for batch in batches {
let column = batch.column(col_idx);
acc.extend_slice_of(column.as_ref(), 0, column.len())
.context(error::DataTypesSnafu)?;
}

merged_columns.push(acc.to_vector());
}

// Create a new RecordBatch with merged columns
RecordBatch::new(schema, merged_columns)
}

#[cfg(test)]
mod tests {
use std::sync::Arc;
Expand Down Expand Up @@ -375,4 +418,80 @@ mod tests {

assert!(record_batch_iter.next().is_none());
}

#[test]
fn test_record_batch_slice() {
let column_schemas = vec![
ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
];
let schema = Arc::new(Schema::new(column_schemas));
let columns: Vec<VectorRef> = vec![
Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
Arc::new(StringVector::from(vec![
None,
Some("hello"),
Some("greptime"),
None,
])),
];
let recordbatch = RecordBatch::new(schema, columns).unwrap();
let recordbatch = recordbatch.slice(1, 2).expect("recordbatch slice");
let mut record_batch_iter = recordbatch.rows();
assert_eq!(
vec![Value::UInt32(2), Value::String("hello".into())],
record_batch_iter
.next()
.unwrap()
.into_iter()
.collect::<Vec<Value>>()
);

assert_eq!(
vec![Value::UInt32(3), Value::String("greptime".into())],
record_batch_iter
.next()
.unwrap()
.into_iter()
.collect::<Vec<Value>>()
);

assert!(record_batch_iter.next().is_none());

assert!(recordbatch.slice(1, 5).is_err());
}

#[test]
fn test_merge_record_batch() {
let column_schemas = vec![
ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
];
let schema = Arc::new(Schema::new(column_schemas));
let columns: Vec<VectorRef> = vec![
Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
Arc::new(StringVector::from(vec![
None,
Some("hello"),
Some("greptime"),
None,
])),
];
let recordbatch = RecordBatch::new(schema.clone(), columns).unwrap();

let columns: Vec<VectorRef> = vec![
Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
Arc::new(StringVector::from(vec![
None,
Some("hello"),
Some("greptime"),
None,
])),
];
let recordbatch2 = RecordBatch::new(schema.clone(), columns).unwrap();

let merged = merge_record_batches(schema.clone(), &[recordbatch, recordbatch2])
.expect("merge recordbatch");
assert_eq!(merged.num_rows(), 8);
}
}
8 changes: 7 additions & 1 deletion src/frontend/src/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,11 @@ pub fn check_permission(
// TODO(dennis): add a hook for admin commands.
Statement::Admin(_) => {}
// These are executed by query engine, and will be checked there.
Statement::Query(_) | Statement::Explain(_) | Statement::Tql(_) | Statement::Delete(_) => {}
Statement::Query(_)
| Statement::Explain(_)
| Statement::Tql(_)
| Statement::Delete(_)
| Statement::DeclareCursor(_) => {}
// database ops won't be checked
Statement::CreateDatabase(_)
| Statement::ShowDatabases(_)
Expand Down Expand Up @@ -580,6 +584,8 @@ pub fn check_permission(
Statement::TruncateTable(stmt) => {
validate_param(stmt.table_name(), query_ctx)?;
}
// cursor operations are always allowed once it's created
Statement::FetchCursor(_) | Statement::CloseCursor(_) => {}
}
Ok(())
}
Expand Down
Loading

0 comments on commit ee9808b

Please sign in to comment.