From a7f3151e5ea65a2d80fe5b47bacbe2351ce49649 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Wed, 4 Dec 2024 18:55:31 +0800 Subject: [PATCH] test: add tests --- src/common/recordbatch/src/cursor.rs | 68 ++++++++++++++++++- src/common/recordbatch/src/error.rs | 8 ++- src/common/recordbatch/src/recordbatch.rs | 82 ++++++++++++++++++++++- src/frontend/src/instance.rs | 8 ++- tests-integration/tests/sql.rs | 38 +++++++++++ 5 files changed, 200 insertions(+), 4 deletions(-) diff --git a/src/common/recordbatch/src/cursor.rs b/src/common/recordbatch/src/cursor.rs index 93fefae7c441..a741953ccc25 100644 --- a/src/common/recordbatch/src/cursor.rs +++ b/src/common/recordbatch/src/cursor.rs @@ -26,6 +26,7 @@ struct Inner { total_rows_in_current_batch: usize, } +/// A cursor on RecordBatchStream that fetches data batch by batch pub struct RecordBatchStreamCursor { inner: Mutex, } @@ -91,7 +92,7 @@ impl RecordBatchStreamCursor { remaining_rows_to_take -= rows_to_take_from_batch; } - // If no rows were accumulated, return None + // If no rows were accumulated, return empty if accumulated_rows.is_empty() { return Ok(RecordBatch::new_empty(inner.stream.schema())); } @@ -105,3 +106,68 @@ impl RecordBatchStreamCursor { merge_record_batches(inner.stream.schema(), &accumulated_rows) } } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use datatypes::prelude::ConcreteDataType; + use datatypes::schema::{ColumnSchema, Schema}; + use datatypes::vectors::StringVector; + + use super::*; + use crate::RecordBatches; + + #[tokio::test] + async fn test_cursor() { + let schema = Arc::new(Schema::new(vec![ColumnSchema::new( + "a", + ConcreteDataType::string_datatype(), + false, + )])); + + let rbs = RecordBatches::try_from_columns( + schema.clone(), + vec![Arc::new(StringVector::from(vec!["hello", "world"])) as _], + ) + .unwrap(); + + let cursor = RecordBatchStreamCursor::new(rbs.as_stream()); + let result_rb = cursor.take(1).await.expect("take from cursor failed"); + assert_eq!(result_rb.num_rows(), 1); + + let result_rb = cursor.take(1).await.expect("take from cursor failed"); + assert_eq!(result_rb.num_rows(), 1); + + let result_rb = cursor.take(1).await.expect("take from cursor failed"); + assert_eq!(result_rb.num_rows(), 0); + + let rb = RecordBatch::new( + schema.clone(), + vec![Arc::new(StringVector::from(vec!["hello", "world"])) as _], + ) + .unwrap(); + let rbs2 = + RecordBatches::try_new(schema.clone(), vec![rb.clone(), rb.clone(), rb]).unwrap(); + let cursor = RecordBatchStreamCursor::new(rbs2.as_stream()); + let result_rb = cursor.take(3).await.expect("take from cursor failed"); + assert_eq!(result_rb.num_rows(), 3); + let result_rb = cursor.take(2).await.expect("take from cursor failed"); + assert_eq!(result_rb.num_rows(), 2); + let result_rb = cursor.take(2).await.expect("take from cursor failed"); + assert_eq!(result_rb.num_rows(), 1); + let result_rb = cursor.take(2).await.expect("take from cursor failed"); + assert_eq!(result_rb.num_rows(), 0); + + let rb = RecordBatch::new( + schema.clone(), + vec![Arc::new(StringVector::from(vec!["hello", "world"])) as _], + ) + .unwrap(); + let rbs3 = + RecordBatches::try_new(schema.clone(), vec![rb.clone(), rb.clone(), rb]).unwrap(); + let cursor = RecordBatchStreamCursor::new(rbs3.as_stream()); + let result_rb = cursor.take(10).await.expect("take from cursor failed"); + assert_eq!(result_rb.num_rows(), 6); + } +} diff --git a/src/common/recordbatch/src/error.rs b/src/common/recordbatch/src/error.rs index 6e038d1b7e70..d98131f40a5a 100644 --- a/src/common/recordbatch/src/error.rs +++ b/src/common/recordbatch/src/error.rs @@ -168,6 +168,11 @@ pub enum Error { #[snafu(source)] error: tokio::time::error::Elapsed, }, + #[snafu(display("RecordBatch slice index overflow"))] + RecordBatchSliceIndexOverflow { + #[snafu(implicit)] + location: Location, + }, } impl ErrorExt for Error { @@ -182,7 +187,8 @@ impl ErrorExt for Error { | Error::Format { .. } | Error::ToArrowScalar { .. } | Error::ProjectArrowRecordBatch { .. } - | Error::PhysicalExpr { .. } => StatusCode::Internal, + | Error::PhysicalExpr { .. } + | Error::RecordBatchSliceIndexOverflow { .. } => StatusCode::Internal, Error::PollStream { .. } => StatusCode::EngineExecuteQuery, diff --git a/src/common/recordbatch/src/recordbatch.rs b/src/common/recordbatch/src/recordbatch.rs index 2c3d4c09a1c4..04a13d759c28 100644 --- a/src/common/recordbatch/src/recordbatch.rs +++ b/src/common/recordbatch/src/recordbatch.rs @@ -23,7 +23,7 @@ use datatypes::value::Value; use datatypes::vectors::{Helper, VectorRef}; use serde::ser::{Error, SerializeStruct}; use serde::{Serialize, Serializer}; -use snafu::{OptionExt, ResultExt}; +use snafu::{ensure, OptionExt, ResultExt}; use crate::error::{ self, CastVectorSnafu, ColumnNotExistsSnafu, DataTypesSnafu, ProjectArrowRecordBatchSnafu, @@ -197,6 +197,10 @@ impl RecordBatch { /// Return a slice record batch starts from offset to len pub fn slice(&self, offset: usize, len: usize) -> Result { + ensure!( + offset + len <= self.num_rows(), + error::RecordBatchSliceIndexOverflowSnafu + ); let columns = self.columns.iter().map(|vector| vector.slice(offset, len)); RecordBatch::new(self.schema.clone(), columns) } @@ -411,4 +415,80 @@ mod tests { assert!(record_batch_iter.next().is_none()); } + + #[test] + fn test_record_batch_slice() { + let column_schemas = vec![ + ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false), + ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true), + ]; + let schema = Arc::new(Schema::new(column_schemas)); + let columns: Vec = vec![ + Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])), + Arc::new(StringVector::from(vec![ + None, + Some("hello"), + Some("greptime"), + None, + ])), + ]; + let recordbatch = RecordBatch::new(schema, columns).unwrap(); + let recordbatch = recordbatch.slice(1, 2).expect("recordbatch slice"); + let mut record_batch_iter = recordbatch.rows(); + assert_eq!( + vec![Value::UInt32(2), Value::String("hello".into())], + record_batch_iter + .next() + .unwrap() + .into_iter() + .collect::>() + ); + + assert_eq!( + vec![Value::UInt32(3), Value::String("greptime".into())], + record_batch_iter + .next() + .unwrap() + .into_iter() + .collect::>() + ); + + assert!(record_batch_iter.next().is_none()); + + assert!(recordbatch.slice(1, 5).is_err()); + } + + #[test] + fn test_merge_record_batch() { + let column_schemas = vec![ + ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false), + ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true), + ]; + let schema = Arc::new(Schema::new(column_schemas)); + let columns: Vec = vec![ + Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])), + Arc::new(StringVector::from(vec![ + None, + Some("hello"), + Some("greptime"), + None, + ])), + ]; + let recordbatch = RecordBatch::new(schema.clone(), columns).unwrap(); + + let columns: Vec = vec![ + Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])), + Arc::new(StringVector::from(vec![ + None, + Some("hello"), + Some("greptime"), + None, + ])), + ]; + let recordbatch2 = RecordBatch::new(schema.clone(), columns).unwrap(); + + let merged = merge_record_batches(schema.clone(), &[recordbatch, recordbatch2]) + .expect("merge recordbatch"); + assert_eq!(merged.num_rows(), 8); + } } diff --git a/src/frontend/src/instance.rs b/src/frontend/src/instance.rs index 6ffab3c1f619..b22647989753 100644 --- a/src/frontend/src/instance.rs +++ b/src/frontend/src/instance.rs @@ -487,7 +487,11 @@ pub fn check_permission( // TODO(dennis): add a hook for admin commands. Statement::Admin(_) => {} // These are executed by query engine, and will be checked there. - Statement::Query(_) | Statement::Explain(_) | Statement::Tql(_) | Statement::Delete(_) => {} + Statement::Query(_) + | Statement::Explain(_) + | Statement::Tql(_) + | Statement::Delete(_) + | Statement::DeclareCursor(_) => {} // database ops won't be checked Statement::CreateDatabase(_) | Statement::ShowDatabases(_) @@ -580,6 +584,8 @@ pub fn check_permission( Statement::TruncateTable(stmt) => { validate_param(stmt.table_name(), query_ctx)?; } + // cursor operations are always allowed once it's created + Statement::FetchCursor(_) | Statement::CloseCursor(_) => {} } Ok(()) } diff --git a/tests-integration/tests/sql.rs b/tests-integration/tests/sql.rs index f15e3743256d..d7dac6a05d21 100644 --- a/tests-integration/tests/sql.rs +++ b/tests-integration/tests/sql.rs @@ -72,6 +72,7 @@ macro_rules! sql_tests { test_postgres_parameter_inference, test_postgres_array_types, test_mysql_prepare_stmt_insert_timestamp, + test_declare_fetch_close_cursor, ); )* }; @@ -1198,3 +1199,40 @@ pub async fn test_postgres_array_types(store_type: StorageType) { let _ = fe_pg_server.shutdown().await; guard.remove_all().await; } + +pub async fn test_declare_fetch_close_cursor(store_type: StorageType) { + let (addr, mut guard, fe_pg_server) = setup_pg_server(store_type, "sql_inference").await; + + let (client, connection) = tokio_postgres::connect(&format!("postgres://{addr}/public"), NoTls) + .await + .unwrap(); + + let (tx, rx) = tokio::sync::oneshot::channel(); + tokio::spawn(async move { + connection.await.unwrap(); + tx.send(()).unwrap(); + }); + + client + .execute( + "DECLARE c1 CURSOR FOR SELECT * FROM numbers WHERE number > 2", + &[], + ) + .await + .expect("declare cursor"); + + let rows = client.query("FETCH 5 FROM c1", &[]).await.unwrap(); + assert_eq!(5, rows.len()); + + let rows = client.query("FETCH 100 FROM c1", &[]).await.unwrap(); + assert_eq!(92, rows.len()); + + client.execute("CLOSE c1", &[]).await.expect("close cursor"); + + // Shutdown the client. + drop(client); + rx.await.unwrap(); + + let _ = fe_pg_server.shutdown().await; + guard.remove_all().await; +}