Skip to content

Commit

Permalink
test: add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sunng87 committed Dec 4, 2024
1 parent 780fc03 commit a7f3151
Show file tree
Hide file tree
Showing 5 changed files with 200 additions and 4 deletions.
68 changes: 67 additions & 1 deletion src/common/recordbatch/src/cursor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ struct Inner {
total_rows_in_current_batch: usize,
}

/// A cursor on RecordBatchStream that fetches data batch by batch
pub struct RecordBatchStreamCursor {
// Cursor state guarded by a mutex; methods go through the lock to reach it.
// NOTE(review): `Inner` appears to hold the underlying stream plus per-batch
// row bookkeeping (e.g. `total_rows_in_current_batch`) — confirm against its definition.
inner: Mutex<Inner>,
}
Expand Down Expand Up @@ -91,7 +92,7 @@ impl RecordBatchStreamCursor {
remaining_rows_to_take -= rows_to_take_from_batch;
}

// If no rows were accumulated, return None
// If no rows were accumulated, return empty
if accumulated_rows.is_empty() {
return Ok(RecordBatch::new_empty(inner.stream.schema()));
}
Expand All @@ -105,3 +106,68 @@ impl RecordBatchStreamCursor {
merge_record_batches(inner.stream.schema(), &accumulated_rows)
}
}

#[cfg(test)]
mod tests {
use std::sync::Arc;

use datatypes::prelude::ConcreteDataType;
use datatypes::schema::{ColumnSchema, Schema};
use datatypes::vectors::StringVector;

use super::*;
use crate::RecordBatches;

#[tokio::test]
async fn test_cursor() {
    // Single non-nullable string column shared by every scenario below.
    let column = ColumnSchema::new("a", ConcreteDataType::string_datatype(), false);
    let schema = Arc::new(Schema::new(vec![column]));

    // Builds a fresh two-row batch containing "hello" and "world".
    let two_row_batch = || {
        RecordBatch::new(
            schema.clone(),
            vec![Arc::new(StringVector::from(vec!["hello", "world"])) as _],
        )
        .unwrap()
    };

    // Scenario 1: one batch of two rows, consumed one row at a time. The third
    // take finds the stream exhausted and yields an empty batch.
    let single = RecordBatches::try_from_columns(
        schema.clone(),
        vec![Arc::new(StringVector::from(vec!["hello", "world"])) as _],
    )
    .unwrap();
    let cursor = RecordBatchStreamCursor::new(single.as_stream());
    for expected_rows in [1, 1, 0] {
        let taken = cursor.take(1).await.expect("take from cursor failed");
        assert_eq!(taken.num_rows(), expected_rows);
    }

    // Scenario 2: three batches (six rows in total), with takes that cross batch
    // boundaries and eventually run past the end of the stream.
    let triple =
        RecordBatches::try_new(schema.clone(), vec![two_row_batch(); 3]).unwrap();
    let cursor = RecordBatchStreamCursor::new(triple.as_stream());
    for (requested, expected_rows) in [(3, 3), (2, 2), (2, 1), (2, 0)] {
        let taken = cursor
            .take(requested)
            .await
            .expect("take from cursor failed");
        assert_eq!(taken.num_rows(), expected_rows);
    }

    // Scenario 3: requesting more rows than the stream holds returns all of them.
    let triple =
        RecordBatches::try_new(schema.clone(), vec![two_row_batch(); 3]).unwrap();
    let cursor = RecordBatchStreamCursor::new(triple.as_stream());
    let taken = cursor.take(10).await.expect("take from cursor failed");
    assert_eq!(taken.num_rows(), 6);
}
}
8 changes: 7 additions & 1 deletion src/common/recordbatch/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,11 @@ pub enum Error {
#[snafu(source)]
error: tokio::time::error::Elapsed,
},
#[snafu(display("RecordBatch slice index overflow"))]
RecordBatchSliceIndexOverflow {
#[snafu(implicit)]
location: Location,
},
}

impl ErrorExt for Error {
Expand All @@ -182,7 +187,8 @@ impl ErrorExt for Error {
| Error::Format { .. }
| Error::ToArrowScalar { .. }
| Error::ProjectArrowRecordBatch { .. }
| Error::PhysicalExpr { .. } => StatusCode::Internal,
| Error::PhysicalExpr { .. }
| Error::RecordBatchSliceIndexOverflow { .. } => StatusCode::Internal,

Error::PollStream { .. } => StatusCode::EngineExecuteQuery,

Expand Down
82 changes: 81 additions & 1 deletion src/common/recordbatch/src/recordbatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use datatypes::value::Value;
use datatypes::vectors::{Helper, VectorRef};
use serde::ser::{Error, SerializeStruct};
use serde::{Serialize, Serializer};
use snafu::{OptionExt, ResultExt};
use snafu::{ensure, OptionExt, ResultExt};

use crate::error::{
self, CastVectorSnafu, ColumnNotExistsSnafu, DataTypesSnafu, ProjectArrowRecordBatchSnafu,
Expand Down Expand Up @@ -197,6 +197,10 @@ impl RecordBatch {

/// Returns a new [`RecordBatch`] holding `len` rows of this batch, starting at row `offset`.
///
/// # Errors
///
/// Returns a `RecordBatchSliceIndexOverflow` error when the requested range
/// (`len` rows beginning at `offset`) does not fit within this batch. Any error
/// raised while assembling the sliced batch is propagated as well.
pub fn slice(&self, offset: usize, len: usize) -> Result<RecordBatch> {
    let num_rows = self.num_rows();
    // Compare `len` against the rows remaining after `offset` rather than computing
    // `offset + len`: that sum can overflow `usize` for very large arguments, which
    // would panic in debug builds or wrap around and bypass the check in release.
    ensure!(
        offset <= num_rows && len <= num_rows - offset,
        error::RecordBatchSliceIndexOverflowSnafu
    );
    let columns = self.columns.iter().map(|vector| vector.slice(offset, len));
    RecordBatch::new(self.schema.clone(), columns)
}
Expand Down Expand Up @@ -411,4 +415,80 @@ mod tests {

assert!(record_batch_iter.next().is_none());
}

#[test]
fn test_record_batch_slice() {
    // Two columns: non-nullable u32 numbers and nullable strings.
    let schema = Arc::new(Schema::new(vec![
        ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
        ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
    ]));
    let numbers: VectorRef = Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4]));
    let strings: VectorRef = Arc::new(StringVector::from(vec![
        None,
        Some("hello"),
        Some("greptime"),
        None,
    ]));
    let full_batch = RecordBatch::new(schema, vec![numbers, strings]).unwrap();

    // Keep the two middle rows (indices 1 and 2) of the four-row batch.
    let sliced = full_batch.slice(1, 2).expect("recordbatch slice");

    let expected_rows = [
        vec![Value::UInt32(2), Value::String("hello".into())],
        vec![Value::UInt32(3), Value::String("greptime".into())],
    ];
    let mut rows = sliced.rows();
    for expected in expected_rows {
        let actual: Vec<Value> = rows.next().unwrap().into_iter().collect();
        assert_eq!(expected, actual);
    }
    assert!(rows.next().is_none());

    // Slicing past the end of the (already sliced, two-row) batch must fail.
    assert!(sliced.slice(1, 5).is_err());
}

#[test]
fn test_merge_record_batch() {
    // Two columns: non-nullable u32 numbers and nullable strings.
    let schema = Arc::new(Schema::new(vec![
        ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
        ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
    ]));

    // Builds one four-row batch; called twice to get two identical inputs.
    let four_row_batch = || {
        let numbers: VectorRef = Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4]));
        let strings: VectorRef = Arc::new(StringVector::from(vec![
            None,
            Some("hello"),
            Some("greptime"),
            None,
        ]));
        RecordBatch::new(schema.clone(), vec![numbers, strings]).unwrap()
    };
    let batches = [four_row_batch(), four_row_batch()];

    // Merging two four-row batches yields a single eight-row batch.
    let merged = merge_record_batches(schema.clone(), &batches).expect("merge recordbatch");
    assert_eq!(merged.num_rows(), 8);
}
}
8 changes: 7 additions & 1 deletion src/frontend/src/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,11 @@ pub fn check_permission(
// TODO(dennis): add a hook for admin commands.
Statement::Admin(_) => {}
// These are executed by query engine, and will be checked there.
Statement::Query(_) | Statement::Explain(_) | Statement::Tql(_) | Statement::Delete(_) => {}
Statement::Query(_)
| Statement::Explain(_)
| Statement::Tql(_)
| Statement::Delete(_)
| Statement::DeclareCursor(_) => {}
// database ops won't be checked
Statement::CreateDatabase(_)
| Statement::ShowDatabases(_)
Expand Down Expand Up @@ -580,6 +584,8 @@ pub fn check_permission(
Statement::TruncateTable(stmt) => {
validate_param(stmt.table_name(), query_ctx)?;
}
// cursor operations are always allowed once it's created
Statement::FetchCursor(_) | Statement::CloseCursor(_) => {}
}
Ok(())
}
Expand Down
38 changes: 38 additions & 0 deletions tests-integration/tests/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ macro_rules! sql_tests {
test_postgres_parameter_inference,
test_postgres_array_types,
test_mysql_prepare_stmt_insert_timestamp,
test_declare_fetch_close_cursor,
);
)*
};
Expand Down Expand Up @@ -1198,3 +1199,40 @@ pub async fn test_postgres_array_types(store_type: StorageType) {
let _ = fe_pg_server.shutdown().await;
guard.remove_all().await;
}

/// Integration test for cursor statements over the Postgres protocol: declares a
/// cursor, fetches from it twice, closes it, then tears the server down.
pub async fn test_declare_fetch_close_cursor(store_type: StorageType) {
let (addr, mut guard, fe_pg_server) = setup_pg_server(store_type, "sql_inference").await;

let (client, connection) = tokio_postgres::connect(&format!("postgres://{addr}/public"), NoTls)
.await
.unwrap();

// Drive the connection on a background task. `tx` fires once the connection
// future completes, which the test waits for after dropping the client below.
let (tx, rx) = tokio::sync::oneshot::channel();
tokio::spawn(async move {
connection.await.unwrap();
tx.send(()).unwrap();
});

client
.execute(
"DECLARE c1 CURSOR FOR SELECT * FROM numbers WHERE number > 2",
&[],
)
.await
.expect("declare cursor");

// The first fetch returns exactly the requested number of rows.
let rows = client.query("FETCH 5 FROM c1", &[]).await.unwrap();
assert_eq!(5, rows.len());

// The second fetch asks for more rows than are left and gets only the remainder.
// NOTE(review): 5 + 92 implies 97 rows match `number > 2`, presumably because
// `numbers` holds 0..=99 — confirm against the table definition.
let rows = client.query("FETCH 100 FROM c1", &[]).await.unwrap();
assert_eq!(92, rows.len());

client.execute("CLOSE c1", &[]).await.expect("close cursor");

// Shutdown the client.
drop(client);
// Wait for the background connection task to signal that it has finished.
rx.await.unwrap();

let _ = fe_pg_server.shutdown().await;
guard.remove_all().await;
}

0 comments on commit a7f3151

Please sign in to comment.