From 3133f3fb4e39edd94b58d2a350cc9380b3ff259d Mon Sep 17 00:00:00 2001
From: Ning Sun
Date: Fri, 6 Dec 2024 17:32:22 +0800
Subject: [PATCH] feat: add cursor statements (#5094)

* feat: add sql parsers for cursor operations

* feat: cursor operator

* feat: implement RecordBatchStreamCursor

* feat: implement cursor storage and execution

* test: add tests

* chore: update docstring

* feat: add a temporary sql rewrite for cast in limit

  this issue is described in #5097

* test: add more sql for cursor integration test

* feat: reject non-select query for cursor statement

* refactor: address review issues

* test: add empty result case

* feat: address review comments
---
 Cargo.lock                                |   2 +
 src/common/recordbatch/src/cursor.rs      | 173 ++++++++++++++++++++++
 src/common/recordbatch/src/error.rs       |  10 +-
 src/common/recordbatch/src/lib.rs         |   1 +
 src/common/recordbatch/src/recordbatch.rs | 121 ++++++++++++++-
 src/frontend/src/instance.rs              |   8 +-
 src/operator/src/error.rs                 |  10 +-
 src/operator/src/statement.rs             |  11 ++
 src/operator/src/statement/cursor.rs      |  98 ++++++++++++
 src/servers/src/postgres/fixtures.rs      |  17 +++
 src/servers/src/postgres/handler.rs       |   6 +
 src/session/Cargo.toml                    |   2 +
 src/session/src/context.rs                |  25 ++++
 src/session/src/lib.rs                    |   6 +
 src/sql/src/parser.rs                     |   6 +
 src/sql/src/parsers.rs                    |   1 +
 src/sql/src/parsers/cursor_parser.rs      | 157 ++++++++++++++++++++
 src/sql/src/statements.rs                 |   3 +-
 src/sql/src/statements/cursor.rs          |  60 ++++++++
 src/sql/src/statements/statement.rs       |  10 ++
 tests-integration/tests/sql.rs            |  64 ++++++++
 21 files changed, 786 insertions(+), 5 deletions(-)
 create mode 100644 src/common/recordbatch/src/cursor.rs
 create mode 100644 src/operator/src/statement/cursor.rs
 create mode 100644 src/sql/src/parsers/cursor_parser.rs
 create mode 100644 src/sql/src/statements/cursor.rs

diff --git a/Cargo.lock b/Cargo.lock
index f677ee269d4e..16a234728983 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -10987,9 +10987,11 @@ dependencies = [
  "common-catalog",
  "common-error",
  "common-macro",
+ "common-recordbatch",
  "common-telemetry",
  "common-time",
  "derive_builder 0.12.0",
+ "derive_more",
  "meter-core",
  "snafu 0.8.5",
  "sql",
diff --git a/src/common/recordbatch/src/cursor.rs b/src/common/recordbatch/src/cursor.rs
new file mode 100644
index 000000000000..a741953ccc25
--- /dev/null
+++ b/src/common/recordbatch/src/cursor.rs
@@ -0,0 +1,173 @@
+// Copyright 2023 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
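+
+//! A cursor over a [`SendableRecordBatchStream`]: `take(n)` pulls at most
+//! `n` rows and buffers the remainder of the current batch for the next call.
+//!
+//! Illustrative call pattern (hypothetical driver code, not part of this
+//! patch; `stream` stands in for any `SendableRecordBatchStream`):
+//!
+//! ```ignore
+//! let cursor = RecordBatchStreamCursor::new(stream);
+//! let first = cursor.take(100).await?;  // up to 100 rows
+//! let rest = cursor.take(100).await?;   // resumes where `first` stopped
+//! ```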
+
+use futures::StreamExt;
+use tokio::sync::Mutex;
+
+use crate::error::Result;
+use crate::recordbatch::merge_record_batches;
+use crate::{RecordBatch, SendableRecordBatchStream};
+
+struct Inner {
+    stream: SendableRecordBatchStream,
+    current_row_index: usize,
+    current_batch: Option<RecordBatch>,
+    total_rows_in_current_batch: usize,
+}
+
+/// A cursor on RecordBatchStream that fetches data batch by batch
+pub struct RecordBatchStreamCursor {
+    inner: Mutex<Inner>,
+}
+
+impl RecordBatchStreamCursor {
+    pub fn new(stream: SendableRecordBatchStream) -> RecordBatchStreamCursor {
+        Self {
+            inner: Mutex::new(Inner {
+                stream,
+                current_row_index: 0,
+                current_batch: None,
+                total_rows_in_current_batch: 0,
+            }),
+        }
+    }
+
+    /// Take up to `size` rows from the `RecordBatchStream` and create a new
+    /// `RecordBatch` from them.
+    pub async fn take(&self, size: usize) -> Result<RecordBatch> {
+        let mut remaining_rows_to_take = size;
+        let mut accumulated_rows = Vec::new();
+
+        let mut inner = self.inner.lock().await;
+
+        while remaining_rows_to_take > 0 {
+            // Ensure we have a current batch or fetch the next one
+            if inner.current_batch.is_none()
+                || inner.current_row_index >= inner.total_rows_in_current_batch
+            {
+                match inner.stream.next().await {
+                    Some(Ok(batch)) => {
+                        inner.total_rows_in_current_batch = batch.num_rows();
+                        inner.current_batch = Some(batch);
+                        inner.current_row_index = 0;
+                    }
+                    Some(Err(e)) => return Err(e),
+                    None => {
+                        // Stream is exhausted
+                        break;
+                    }
+                }
+            }
+
+            // If we still have no batch after attempting to fetch
+            let current_batch = match &inner.current_batch {
+                Some(batch) => batch,
+                None => break,
+            };
+
+            // Calculate how many rows we can take from this batch
+            let rows_to_take_from_batch = remaining_rows_to_take
+                .min(inner.total_rows_in_current_batch - inner.current_row_index);
+
+            // Slice the current batch to get the desired rows
+            let taken_batch =
+                current_batch.slice(inner.current_row_index, rows_to_take_from_batch)?;
+
+            // Add the taken batch to accumulated rows
+            accumulated_rows.push(taken_batch);
+
+            // Update cursor and remaining rows
+            inner.current_row_index += rows_to_take_from_batch;
+            remaining_rows_to_take -= rows_to_take_from_batch;
+        }
+
+        // If no rows were accumulated, return an empty batch
+        if accumulated_rows.is_empty() {
+            return Ok(RecordBatch::new_empty(inner.stream.schema()));
+        }
+
+        // If only one batch was accumulated, return it directly
+        if accumulated_rows.len() == 1 {
+            return Ok(accumulated_rows.remove(0));
+        }
+
+        // Merge multiple batches
+        merge_record_batches(inner.stream.schema(), &accumulated_rows)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use datatypes::prelude::ConcreteDataType;
+    use datatypes::schema::{ColumnSchema, Schema};
+    use datatypes::vectors::StringVector;
+
+    use super::*;
+    use crate::RecordBatches;
+
+    #[tokio::test]
+    async fn test_cursor() {
+        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
+            "a",
+            ConcreteDataType::string_datatype(),
+            false,
+        )]));
+
+        let rbs = RecordBatches::try_from_columns(
+            schema.clone(),
+            vec![Arc::new(StringVector::from(vec!["hello", "world"])) as _],
+        )
+        .unwrap();
+
+        let cursor = RecordBatchStreamCursor::new(rbs.as_stream());
+        let result_rb = cursor.take(1).await.expect("take from cursor failed");
+        assert_eq!(result_rb.num_rows(), 1);
+
+        let result_rb = cursor.take(1).await.expect("take from cursor failed");
+        assert_eq!(result_rb.num_rows(), 1);
+
+        let result_rb = cursor.take(1).await.expect("take from cursor failed");
+        assert_eq!(result_rb.num_rows(), 0);
+
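+        // Multi-batch streams below: `take` must slice within a batch and
+        // merge across batch boundaries (3 + 2 + 1 rows drain the 6 rows).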
+        let rb = RecordBatch::new(
+            schema.clone(),
+            vec![Arc::new(StringVector::from(vec!["hello", "world"])) as _],
+        )
+        .unwrap();
+        let rbs2 =
+            RecordBatches::try_new(schema.clone(), vec![rb.clone(), rb.clone(), rb]).unwrap();
+        let cursor = RecordBatchStreamCursor::new(rbs2.as_stream());
+        let result_rb = cursor.take(3).await.expect("take from cursor failed");
+        assert_eq!(result_rb.num_rows(), 3);
+        let result_rb = cursor.take(2).await.expect("take from cursor failed");
+        assert_eq!(result_rb.num_rows(), 2);
+        let result_rb = cursor.take(2).await.expect("take from cursor failed");
+        assert_eq!(result_rb.num_rows(), 1);
+        let result_rb = cursor.take(2).await.expect("take from cursor failed");
+        assert_eq!(result_rb.num_rows(), 0);
+
+        let rb = RecordBatch::new(
+            schema.clone(),
+            vec![Arc::new(StringVector::from(vec!["hello", "world"])) as _],
+        )
+        .unwrap();
+        let rbs3 =
+            RecordBatches::try_new(schema.clone(), vec![rb.clone(), rb.clone(), rb]).unwrap();
+        let cursor = RecordBatchStreamCursor::new(rbs3.as_stream());
+        let result_rb = cursor.take(10).await.expect("take from cursor failed");
+        assert_eq!(result_rb.num_rows(), 6);
+    }
+}
diff --git a/src/common/recordbatch/src/error.rs b/src/common/recordbatch/src/error.rs
index 6e038d1b7e70..6a1c61c0a0f0 100644
--- a/src/common/recordbatch/src/error.rs
+++ b/src/common/recordbatch/src/error.rs
@@ -168,6 +168,13 @@ pub enum Error {
         #[snafu(source)]
         error: tokio::time::error::Elapsed,
     },
+
+    #[snafu(display("RecordBatch slice index overflow: {visit_index} > {size}"))]
+    RecordBatchSliceIndexOverflow {
+        #[snafu(implicit)]
+        location: Location,
+        size: usize,
+        visit_index: usize,
+    },
 }
 
 impl ErrorExt for Error {
@@ -182,7 +189,8 @@ impl ErrorExt for Error {
             | Error::Format { .. }
             | Error::ToArrowScalar { .. }
             | Error::ProjectArrowRecordBatch { .. }
-            | Error::PhysicalExpr { .. } => StatusCode::Internal,
+            | Error::PhysicalExpr { .. }
+            | Error::RecordBatchSliceIndexOverflow { .. } => StatusCode::Internal,
 
             Error::PollStream { .. } => StatusCode::EngineExecuteQuery,
diff --git a/src/common/recordbatch/src/lib.rs b/src/common/recordbatch/src/lib.rs
index 0016e02e94ed..257b6f09732a 100644
--- a/src/common/recordbatch/src/lib.rs
+++ b/src/common/recordbatch/src/lib.rs
@@ -15,6 +15,7 @@
 #![feature(never_type)]
 
 pub mod adapter;
+pub mod cursor;
 pub mod error;
 pub mod filter;
 mod recordbatch;
diff --git a/src/common/recordbatch/src/recordbatch.rs b/src/common/recordbatch/src/recordbatch.rs
index 71f7f60685e5..4641cc0d9a60 100644
--- a/src/common/recordbatch/src/recordbatch.rs
+++ b/src/common/recordbatch/src/recordbatch.rs
@@ -23,7 +23,7 @@ use datatypes::value::Value;
 use datatypes::vectors::{Helper, VectorRef};
 use serde::ser::{Error, SerializeStruct};
 use serde::{Serialize, Serializer};
-use snafu::{OptionExt, ResultExt};
+use snafu::{ensure, OptionExt, ResultExt};
 
 use crate::error::{
     self, CastVectorSnafu, ColumnNotExistsSnafu, DataTypesSnafu, ProjectArrowRecordBatchSnafu,
@@ -194,6 +194,19 @@ impl RecordBatch {
             .map(|t| t.to_string())
             .unwrap_or("failed to pretty display a record batch".to_string())
     }
+
+    /// Return a sliced record batch that starts at `offset`, with `len` rows
+    pub fn slice(&self, offset: usize, len: usize) -> Result<RecordBatch> {
+        ensure!(
+            offset + len <= self.num_rows(),
+            error::RecordBatchSliceIndexOverflowSnafu {
+                size: self.num_rows(),
+                visit_index: offset + len
+            }
+        );
+        let columns = self.columns.iter().map(|vector| vector.slice(offset, len));
+        RecordBatch::new(self.schema.clone(), columns)
+    }
 }
 
 impl Serialize for RecordBatch {
@@ -256,6 +269,36 @@ impl Iterator for RecordBatchRowIterator<'_> {
     }
 }
 
+/// Merge multiple record batches into a single one
+pub fn merge_record_batches(schema: SchemaRef, batches: &[RecordBatch]) -> Result<RecordBatch> {
+    let batches_len = batches.len();
+    if batches_len == 0 {
+        return Ok(RecordBatch::new_empty(schema));
+    }
+
+    let n_rows = batches.iter().map(|b| b.num_rows()).sum();
+    let n_columns = schema.num_columns();
+    // Collect arrays from each batch
+    let mut merged_columns = Vec::with_capacity(n_columns);
+
+    for col_idx in 0..n_columns {
+        let mut acc = schema.column_schemas()[col_idx]
+            .data_type
+            .create_mutable_vector(n_rows);
+
+        for batch in batches {
+            let column = batch.column(col_idx);
+            acc.extend_slice_of(column.as_ref(), 0, column.len())
+                .context(error::DataTypesSnafu)?;
+        }
+
+        merged_columns.push(acc.to_vector());
+    }
+
+    // Create a new RecordBatch with merged columns
+    RecordBatch::new(schema, merged_columns)
+}
+
 #[cfg(test)]
 mod tests {
     use std::sync::Arc;
@@ -375,4 +418,80 @@ mod tests {
 
         assert!(record_batch_iter.next().is_none());
     }
+
+    #[test]
+    fn test_record_batch_slice() {
+        let column_schemas = vec![
+            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
+            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
+        ];
+        let schema = Arc::new(Schema::new(column_schemas));
+        let columns: Vec<VectorRef> = vec![
+            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
+            Arc::new(StringVector::from(vec![
+                None,
+                Some("hello"),
+                Some("greptime"),
+                None,
+            ])),
+        ];
+        let recordbatch = RecordBatch::new(schema, columns).unwrap();
+        let recordbatch = recordbatch.slice(1, 2).expect("recordbatch slice");
+        let mut record_batch_iter = recordbatch.rows();
+        assert_eq!(
+            vec![Value::UInt32(2), Value::String("hello".into())],
+            record_batch_iter
+                .next()
+                .unwrap()
+                .into_iter()
+                .collect::<Vec<Value>>()
+        );
+
+        assert_eq!(
+            vec![Value::UInt32(3), Value::String("greptime".into())],
+            record_batch_iter
+                .next()
+                .unwrap()
+                .into_iter()
+                .collect::<Vec<Value>>()
+        );
+
+        assert!(record_batch_iter.next().is_none());
+
+        assert!(recordbatch.slice(1, 5).is_err());
+    }
+
+    #[test]
+    fn test_merge_record_batch() {
+        let column_schemas = vec![
+            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
+            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
+        ];
+        let schema = Arc::new(Schema::new(column_schemas));
+        let columns: Vec<VectorRef> = vec![
+            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
+            Arc::new(StringVector::from(vec![
+                None,
+                Some("hello"),
+                Some("greptime"),
+                None,
+            ])),
+        ];
+        let recordbatch = RecordBatch::new(schema.clone(), columns).unwrap();
+
+        let columns: Vec<VectorRef> = vec![
+            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
+            Arc::new(StringVector::from(vec![
+                None,
+                Some("hello"),
+                Some("greptime"),
+                None,
+            ])),
+        ];
+        let recordbatch2 = RecordBatch::new(schema.clone(), columns).unwrap();
+
+        let merged = merge_record_batches(schema.clone(), &[recordbatch, recordbatch2])
+            .expect("merge recordbatch");
+        assert_eq!(merged.num_rows(), 8);
+    }
 }
diff --git a/src/frontend/src/instance.rs b/src/frontend/src/instance.rs
index ad387cc5dd96..b22bde96e0ff 100644
--- a/src/frontend/src/instance.rs
+++ b/src/frontend/src/instance.rs
@@ -487,7 +487,11 @@ pub fn check_permission(
         // TODO(dennis): add a hook for admin commands.
         Statement::Admin(_) => {}
         // These are executed by query engine, and will be checked there.
-        Statement::Query(_) | Statement::Explain(_) | Statement::Tql(_) | Statement::Delete(_) => {}
+        Statement::Query(_)
+        | Statement::Explain(_)
+        | Statement::Tql(_)
+        | Statement::Delete(_)
+        | Statement::DeclareCursor(_) => {}
         // database ops won't be checked
         Statement::CreateDatabase(_)
         | Statement::ShowDatabases(_)
@@ -580,6 +584,8 @@ pub fn check_permission(
         Statement::TruncateTable(stmt) => {
            validate_param(stmt.table_name(), query_ctx)?;
         }
+        // Cursor operations are always allowed once the cursor is created
+        Statement::FetchCursor(_) | Statement::CloseCursor(_) => {}
     }
     Ok(())
 }
diff --git a/src/operator/src/error.rs b/src/operator/src/error.rs
index 48bc7a81c221..3a5aae897399 100644
--- a/src/operator/src/error.rs
+++ b/src/operator/src/error.rs
@@ -786,6 +786,12 @@ pub enum Error {
         #[snafu(source)]
         error: Elapsed,
     },
+
+    #[snafu(display("Cursor {name} is not found"))]
+    CursorNotFound { name: String },
+
+    #[snafu(display("A cursor named {name} already exists"))]
+    CursorExists { name: String },
 }
 
 pub type Result<T> = std::result::Result<T, Error>;
@@ -825,7 +831,9 @@ impl ErrorExt for Error {
             | Error::FunctionArityMismatch { .. }
             | Error::InvalidPartition { .. }
             | Error::PhysicalExpr { .. }
-            | Error::InvalidJsonFormat { .. } => StatusCode::InvalidArguments,
+            | Error::InvalidJsonFormat { .. }
+            | Error::CursorNotFound { .. }
+            | Error::CursorExists { .. } => StatusCode::InvalidArguments,
 
             Error::TableAlreadyExists { .. } | Error::ViewAlreadyExists { .. } => {
                 StatusCode::TableAlreadyExists
diff --git a/src/operator/src/statement.rs b/src/operator/src/statement.rs
index 64417dbd6b0d..b3251ca6bf2c 100644
--- a/src/operator/src/statement.rs
+++ b/src/operator/src/statement.rs
@@ -16,6 +16,7 @@ mod admin;
 mod copy_database;
 mod copy_table_from;
 mod copy_table_to;
+mod cursor;
 mod ddl;
 mod describe;
 mod dml;
@@ -133,6 +134,16 @@ impl StatementExecutor {
                 self.plan_exec(QueryStatement::Sql(stmt), query_ctx).await
             }
 
+            Statement::DeclareCursor(declare_cursor) => {
+                self.declare_cursor(declare_cursor, query_ctx).await
+            }
+            Statement::FetchCursor(fetch_cursor) => {
+                self.fetch_cursor(fetch_cursor, query_ctx).await
+            }
+            Statement::CloseCursor(close_cursor) => {
+                self.close_cursor(close_cursor, query_ctx).await
+            }
+
             Statement::Insert(insert) => self.insert(insert, query_ctx).await,
 
             Statement::Tql(tql) => self.execute_tql(tql, query_ctx).await,
diff --git a/src/operator/src/statement/cursor.rs b/src/operator/src/statement/cursor.rs
new file mode 100644
index 000000000000..85de4ef36697
--- /dev/null
+++ b/src/operator/src/statement/cursor.rs
@@ -0,0 +1,98 @@
+// Copyright 2023 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use common_query::{Output, OutputData};
+use common_recordbatch::cursor::RecordBatchStreamCursor;
+use common_recordbatch::RecordBatches;
+use common_telemetry::tracing;
+use query::parser::QueryStatement;
+use session::context::QueryContextRef;
+use snafu::ResultExt;
+use sql::statements::cursor::{CloseCursor, DeclareCursor, FetchCursor};
+use sql::statements::statement::Statement;
+
+use crate::error::{self, Result};
+use crate::statement::StatementExecutor;
+
+impl StatementExecutor {
+    #[tracing::instrument(skip_all)]
+    pub(super) async fn declare_cursor(
+        &self,
+        declare_cursor: DeclareCursor,
+        query_ctx: QueryContextRef,
+    ) -> Result<Output> {
+        let cursor_name = declare_cursor.cursor_name.to_string();
+
+        if query_ctx.get_cursor(&cursor_name).is_some() {
+            error::CursorExistsSnafu {
+                name: cursor_name.to_string(),
+            }
+            .fail()?;
+        }
+
+        let query_stmt = Statement::Query(declare_cursor.query);
+
+        let output = self
+            .plan_exec(QueryStatement::Sql(query_stmt), query_ctx.clone())
+            .await?;
+        match output.data {
+            OutputData::RecordBatches(rb) => {
+                let rbs = rb.as_stream();
+                query_ctx.insert_cursor(cursor_name, RecordBatchStreamCursor::new(rbs));
+            }
+            OutputData::Stream(rbs) => {
+                query_ctx.insert_cursor(cursor_name, RecordBatchStreamCursor::new(rbs));
+            }
+            // Should not happen because the query type is ensured by the parser.
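+            // (The cursor parser only accepts SELECT/WITH queries for
+            // DECLARE ... CURSOR FOR; see cursor_parser.rs.)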
+            OutputData::AffectedRows(_) => error::NotSupportedSnafu {
+                feat: "Non-query statement on cursor",
+            }
+            .fail()?,
+        }
+
+        Ok(Output::new_with_affected_rows(0))
+    }
+
+    #[tracing::instrument(skip_all)]
+    pub(super) async fn fetch_cursor(
+        &self,
+        fetch_cursor: FetchCursor,
+        query_ctx: QueryContextRef,
+    ) -> Result<Output> {
+        let cursor_name = fetch_cursor.cursor_name.to_string();
+        let fetch_size = fetch_cursor.fetch_size;
+        if let Some(rb) = query_ctx.get_cursor(&cursor_name) {
+            let record_batch = rb
+                .take(fetch_size as usize)
+                .await
+                .context(error::BuildRecordBatchSnafu)?;
+            let record_batches =
+                RecordBatches::try_new(record_batch.schema.clone(), vec![record_batch])
+                    .context(error::BuildRecordBatchSnafu)?;
+            Ok(Output::new_with_record_batches(record_batches))
+        } else {
+            error::CursorNotFoundSnafu { name: cursor_name }.fail()
+        }
+    }
+
+    #[tracing::instrument(skip_all)]
+    pub(super) async fn close_cursor(
+        &self,
+        close_cursor: CloseCursor,
+        query_ctx: QueryContextRef,
+    ) -> Result<Output> {
+        query_ctx.remove_cursor(&close_cursor.cursor_name.to_string());
+        Ok(Output::new_with_affected_rows(0))
+    }
+}
diff --git a/src/servers/src/postgres/fixtures.rs b/src/servers/src/postgres/fixtures.rs
index 895f5c03e4a9..2ca3ad02eaa7 100644
--- a/src/servers/src/postgres/fixtures.rs
+++ b/src/servers/src/postgres/fixtures.rs
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+use std::borrow::Cow;
 use std::collections::HashMap;
 use std::sync::Arc;
 
@@ -112,6 +113,13 @@ pub(crate) fn process<'a>(query: &str, query_ctx: QueryContextRef) -> Option<Vec<Response<'a>>> {
+static LIMIT_CAST_PATTERN: Lazy<Regex> =
+    Lazy::new(|| Regex::new("(?i)(LIMIT\\s+\\d+)::bigint").unwrap());
+pub(crate) fn rewrite_sql(query: &str) -> Cow<'_, str> {
+    // TODO(sunng87): remove this when we upgrade datafusion to 43 or newer
+    LIMIT_CAST_PATTERN.replace_all(query, "$1")
+}
+
@@ -195,4 +203,13 @@ mod test {
         assert!(process("SHOW TABLES ", query_context.clone()).is_none());
         assert!(process("SET TIME_ZONE=utc ", query_context.clone()).is_none());
     }
+
+    #[test]
+    fn test_rewrite() {
+        let sql = "SELECT * FROM number LIMIT 1::bigint";
+        let sql2 = "SELECT * FROM number limit 1::BIGINT";
+
+        assert_eq!("SELECT * FROM number LIMIT 1", rewrite_sql(sql));
+        assert_eq!("SELECT * FROM number limit 1", rewrite_sql(sql2));
+    }
 }
diff --git a/src/servers/src/postgres/handler.rs b/src/servers/src/postgres/handler.rs
index 522c558cdc71..e2e46534b5a1 100644
--- a/src/servers/src/postgres/handler.rs
+++ b/src/servers/src/postgres/handler.rs
@@ -70,6 +70,9 @@ impl SimpleQueryHandler for PostgresServerHandlerInner {
             return Ok(vec![Response::EmptyQuery]);
         }
 
+        let query = fixtures::rewrite_sql(query);
+        let query = query.as_ref();
+
         if let Some(resps) = fixtures::process(query, query_ctx.clone()) {
             send_warning_opt(client, query_ctx).await?;
             Ok(resps)
@@ -229,6 +232,9 @@ impl QueryParser for DefaultQueryParser {
             });
         }
 
+        let sql = fixtures::rewrite_sql(sql);
+        let sql = sql.as_ref();
+
         let mut stmts =
             ParserContext::create_with_dialect(sql, &PostgreSqlDialect {}, ParseOptions::default())
                 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
diff --git a/src/session/Cargo.toml b/src/session/Cargo.toml
index b6dbb0095546..f15d3b2609b3 100644
--- a/src/session/Cargo.toml
+++ b/src/session/Cargo.toml
@@ -17,9 +17,11 @@ auth.workspace = true
 common-catalog.workspace = true
 common-error.workspace = true
 common-macro.workspace = true
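+# cursor support: record batch types, plus a Debug derive that can skip fields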
+common-recordbatch.workspace = true
 common-telemetry.workspace = true
 common-time.workspace = true
 derive_builder.workspace = true
+derive_more = { version = "1", default-features = false, features = ["debug"] }
 meter-core.workspace = true
 snafu.workspace = true
 sql.workspace = true
diff --git a/src/session/src/context.rs b/src/session/src/context.rs
index 4e681253c100..1c621b3ab711 100644
--- a/src/session/src/context.rs
+++ b/src/session/src/context.rs
@@ -23,6 +23,8 @@ use arc_swap::ArcSwap;
 use auth::UserInfoRef;
 use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
 use common_catalog::{build_db_string, parse_catalog_and_schema_from_db_string};
+use common_recordbatch::cursor::RecordBatchStreamCursor;
+use common_telemetry::warn;
 use common_time::timezone::parse_timezone;
 use common_time::Timezone;
 use derive_builder::Builder;
@@ -34,6 +36,8 @@ use crate::MutableInner;
 pub type QueryContextRef = Arc<QueryContext>;
 pub type ConnInfoRef = Arc<ConnInfo>;
 
+const CURSOR_COUNT_WARNING_LIMIT: usize = 10;
+
 #[derive(Debug, Builder, Clone)]
 #[builder(pattern = "owned")]
 #[builder(build_fn(skip))]
@@ -299,6 +303,27 @@ impl QueryContext {
     pub fn set_query_timeout(&self, timeout: Duration) {
         self.mutable_session_data.write().unwrap().query_timeout = Some(timeout);
     }
+
+    pub fn insert_cursor(&self, name: String, rb: RecordBatchStreamCursor) {
+        let mut guard = self.mutable_session_data.write().unwrap();
+        guard.cursors.insert(name, Arc::new(rb));
+
+        let cursor_count = guard.cursors.len();
+        if cursor_count > CURSOR_COUNT_WARNING_LIMIT {
+            warn!("Current connection has {} open cursors", cursor_count);
+        }
+    }
+
+    pub fn remove_cursor(&self, name: &str) {
+        let mut guard = self.mutable_session_data.write().unwrap();
+        guard.cursors.remove(name);
+    }
+
+    pub fn get_cursor(&self, name: &str) -> Option<Arc<RecordBatchStreamCursor>> {
+        let guard = self.mutable_session_data.read().unwrap();
+        let rb = guard.cursors.get(name);
+        rb.cloned()
+    }
 }
 
 impl QueryContextBuilder {
diff --git a/src/session/src/lib.rs b/src/session/src/lib.rs
index 5ddaae7eb579..f553fef58c42 100644
--- a/src/session/src/lib.rs
+++ b/src/session/src/lib.rs
@@ -16,6 +16,7 @@ pub mod context;
 pub mod session_config;
 pub mod table_name;
 
+use std::collections::HashMap;
 use std::net::SocketAddr;
 use std::sync::{Arc, RwLock};
 use std::time::Duration;
@@ -23,9 +24,11 @@ use std::time::Duration;
 use auth::UserInfoRef;
 use common_catalog::build_db_string;
 use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
+use common_recordbatch::cursor::RecordBatchStreamCursor;
 use common_time::timezone::get_timezone;
 use common_time::Timezone;
 use context::{ConfigurationVariables, QueryContextBuilder};
+use derive_more::Debug;
 
 use crate::context::{Channel, ConnInfo, QueryContextRef};
 
@@ -47,6 +50,8 @@ pub(crate) struct MutableInner {
     user_info: UserInfoRef,
     timezone: Timezone,
     query_timeout: Option<Duration>,
+    #[debug(skip)]
+    pub(crate) cursors: HashMap<String, Arc<RecordBatchStreamCursor>>,
 }
 
 impl Default for MutableInner {
@@ -56,6 +61,7 @@ impl Default for MutableInner {
             user_info: auth::userinfo_by_name(None),
             timezone: get_timezone(None).clone(),
             query_timeout: None,
+            cursors: HashMap::with_capacity(0),
         }
     }
 }
diff --git a/src/sql/src/parser.rs b/src/sql/src/parser.rs
index bf62a1ad9b67..da03031bc44e 100644
--- a/src/sql/src/parser.rs
+++ b/src/sql/src/parser.rs
@@ -167,6 +167,12 @@ impl ParserContext<'_> {
                 self.parse_tql()
             }
 
+            Keyword::DECLARE => self.parse_declare_cursor(),
+
+            Keyword::FETCH => self.parse_fetch_cursor(),
+
+            Keyword::CLOSE => self.parse_close_cursor(),
+
             Keyword::USE => {
                 let _ = self.parser.next_token();
diff --git a/src/sql/src/parsers.rs b/src/sql/src/parsers.rs
index 2ae0697231c5..26f3ae9903d7 100644
--- a/src/sql/src/parsers.rs
+++ b/src/sql/src/parsers.rs
@@ -16,6 +16,7 @@ pub(crate) mod admin_parser;
 mod alter_parser;
 pub(crate) mod copy_parser;
 pub(crate) mod create_parser;
+pub(crate) mod cursor_parser;
 pub(crate) mod deallocate_parser;
 pub(crate) mod delete_parser;
 pub(crate) mod describe_parser;
diff --git a/src/sql/src/parsers/cursor_parser.rs b/src/sql/src/parsers/cursor_parser.rs
new file mode 100644
index 000000000000..706f820c189e
--- /dev/null
+++ b/src/sql/src/parsers/cursor_parser.rs
@@ -0,0 +1,157 @@
+// Copyright 2023 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use snafu::{ensure, ResultExt};
+use sqlparser::keywords::Keyword;
+use sqlparser::tokenizer::Token;
+
+use crate::error::{self, Result};
+use crate::parser::ParserContext;
+use crate::statements::cursor::{CloseCursor, DeclareCursor, FetchCursor};
+use crate::statements::statement::Statement;
+
+impl ParserContext<'_> {
+    pub(crate) fn parse_declare_cursor(&mut self) -> Result<Statement> {
+        let _ = self.parser.expect_keyword(Keyword::DECLARE);
+        let cursor_name = self
+            .parser
+            .parse_object_name(false)
+            .context(error::SyntaxSnafu)?;
+        let _ = self
+            .parser
+            .expect_keywords(&[Keyword::CURSOR, Keyword::FOR]);
+
+        let mut is_select = false;
+        if let Token::Word(w) = self.parser.peek_token().token {
+            match w.keyword {
+                Keyword::SELECT | Keyword::WITH => {
+                    is_select = true;
+                }
+                _ => {}
+            }
+        };
+        ensure!(
+            is_select,
+            error::InvalidSqlSnafu {
+                msg: "Expect select query in cursor statement".to_string(),
+            }
+        );
+
+        let query_stmt = self.parse_query()?;
+        match query_stmt {
+            Statement::Query(query) => Ok(Statement::DeclareCursor(DeclareCursor {
+                cursor_name: ParserContext::canonicalize_object_name(cursor_name),
+                query,
+            })),
+            _ => error::InvalidSqlSnafu {
+                msg: format!("Expect query, found {}", query_stmt),
+            }
+            .fail(),
+        }
+    }
+
+    pub(crate) fn parse_fetch_cursor(&mut self) -> Result<Statement> {
+        let _ = self.parser.expect_keyword(Keyword::FETCH);
+
+        let fetch_size = self
+            .parser
+            .parse_literal_uint()
+            .context(error::SyntaxSnafu)?;
+        let _ = self.parser.parse_keyword(Keyword::FROM);
+
+        let cursor_name = self
+            .parser
+            .parse_object_name(false)
+            .context(error::SyntaxSnafu)?;
+
+        Ok(Statement::FetchCursor(FetchCursor {
+            cursor_name: ParserContext::canonicalize_object_name(cursor_name),
+            fetch_size,
+        }))
+    }
+
+    pub(crate) fn parse_close_cursor(&mut self) -> Result<Statement> {
+        let _ = self.parser.expect_keyword(Keyword::CLOSE);
+        let cursor_name = self
+            .parser
+            .parse_object_name(false)
+            .context(error::SyntaxSnafu)?;
+
+        Ok(Statement::CloseCursor(CloseCursor {
+            cursor_name: ParserContext::canonicalize_object_name(cursor_name),
+        }))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+
+    use super::*;
+    use crate::dialect::GreptimeDbDialect;
+    use crate::parser::ParseOptions;
+
+    #[test]
+    fn test_parse_declare_cursor() {
+        let sql = "DECLARE c1 CURSOR FOR\nSELECT * FROM numbers";
+        let result =
+            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
+                .unwrap();
+
+        if let Statement::DeclareCursor(dc) = &result[0] {
+            assert_eq!("c1", dc.cursor_name.to_string());
+            assert_eq!(
+                "DECLARE c1 CURSOR FOR SELECT * FROM numbers",
+                dc.to_string()
+            );
+        } else {
+            panic!("Unexpected statement");
+        }
+
+        let sql = "DECLARE c1 CURSOR FOR\nINSERT INTO numbers VALUES (1);";
+        let result =
+            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default());
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn test_parse_fetch_cursor() {
+        let sql = "FETCH 1000 FROM c1";
+        let result =
+            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
+                .unwrap();
+
+        if let Statement::FetchCursor(fc) = &result[0] {
+            assert_eq!("c1", fc.cursor_name.to_string());
+            assert_eq!("1000", fc.fetch_size.to_string());
+            assert_eq!(sql, fc.to_string());
+        } else {
+            panic!("Unexpected statement")
+        }
+    }
+
+    #[test]
+    fn test_parse_close_cursor() {
+        let sql = "CLOSE c1";
+        let result =
+            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
+                .unwrap();
+
+        if let Statement::CloseCursor(cc) = &result[0] {
+            assert_eq!("c1", cc.cursor_name.to_string());
+            assert_eq!(sql, cc.to_string());
+        } else {
+            panic!("Unexpected statement")
+        }
+    }
+}
diff --git a/src/sql/src/statements.rs b/src/sql/src/statements.rs
index 3e1e505a9b1b..25cc3bf7e5be 100644
--- a/src/sql/src/statements.rs
+++ b/src/sql/src/statements.rs
@@ -16,6 +16,7 @@ pub mod admin;
 pub mod alter;
 pub mod copy;
 pub mod create;
+pub mod cursor;
 pub mod delete;
 pub mod describe;
 pub mod drop;
@@ -224,7 +225,7 @@ pub fn sql_number_to_value(data_type: &ConcreteDataType, n: &str) -> Result<Value> {
 
-fn parse_sql_number<R: FromStr + std::fmt::Debug>(n: &str) -> Result<Value>
+pub(crate) fn parse_sql_number<R: FromStr + std::fmt::Debug>(n: &str) -> Result<Value>
 where
     <R as FromStr>::Err: std::fmt::Debug,
 {
diff --git a/src/sql/src/statements/cursor.rs b/src/sql/src/statements/cursor.rs
new file mode 100644
index 000000000000..72ef4cdcae98
--- /dev/null
+++ b/src/sql/src/statements/cursor.rs
@@ -0,0 +1,60 @@
+// Copyright 2023 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
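+
+//! Statement AST nodes for cursor operations: `DECLARE .. CURSOR FOR ..`,
+//! `FETCH .. FROM ..` and `CLOSE ..`. Each node implements `Display`, so a
+//! parsed statement such as `FETCH 1000 FROM c1` formats back to the same
+//! SQL text (see the parser tests above).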
+
+use std::fmt::Display;
+
+use sqlparser::ast::ObjectName;
+use sqlparser_derive::{Visit, VisitMut};
+
+use super::query::Query;
+
+/// Represents a DECLARE CURSOR statement
+///
+/// This statement will carry a SQL query
+#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut)]
+pub struct DeclareCursor {
+    pub cursor_name: ObjectName,
+    pub query: Box<Query>,
+}
+
+impl Display for DeclareCursor {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "DECLARE {} CURSOR FOR {}", self.cursor_name, self.query)
+    }
+}
+
+/// Represents a FETCH FROM cursor statement
+#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut)]
+pub struct FetchCursor {
+    pub cursor_name: ObjectName,
+    pub fetch_size: u64,
+}
+
+impl Display for FetchCursor {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "FETCH {} FROM {}", self.fetch_size, self.cursor_name)
+    }
+}
+
+/// Represents a CLOSE cursor statement
+#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut)]
+pub struct CloseCursor {
+    pub cursor_name: ObjectName,
+}
+
+impl Display for CloseCursor {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "CLOSE {}", self.cursor_name)
+    }
+}
diff --git a/src/sql/src/statements/statement.rs b/src/sql/src/statements/statement.rs
index 0c4b324cd63f..8ad391a00dd2 100644
--- a/src/sql/src/statements/statement.rs
+++ b/src/sql/src/statements/statement.rs
@@ -24,6 +24,7 @@ use crate::statements::alter::{AlterDatabase, AlterTable};
 use crate::statements::create::{
     CreateDatabase, CreateExternalTable, CreateFlow, CreateTable, CreateTableLike, CreateView,
 };
+use crate::statements::cursor::{CloseCursor, DeclareCursor, FetchCursor};
 use crate::statements::delete::Delete;
 use crate::statements::describe::DescribeTable;
 use crate::statements::drop::{DropDatabase, DropFlow, DropTable, DropView};
@@ -118,6 +119,12 @@ pub enum Statement {
     Use(String),
     // Admin statement(extension)
     Admin(Admin),
+    // DECLARE ... CURSOR FOR ...
+    DeclareCursor(DeclareCursor),
+    // FETCH ... FROM ...
+    FetchCursor(FetchCursor),
+    // CLOSE
+    CloseCursor(CloseCursor),
 }
 
 impl Display for Statement {
@@ -165,6 +172,9 @@ impl Display for Statement {
             Statement::CreateView(s) => s.fmt(f),
             Statement::Use(s) => s.fmt(f),
             Statement::Admin(admin) => admin.fmt(f),
+            Statement::DeclareCursor(s) => s.fmt(f),
+            Statement::FetchCursor(s) => s.fmt(f),
+            Statement::CloseCursor(s) => s.fmt(f),
         }
     }
 }
diff --git a/tests-integration/tests/sql.rs b/tests-integration/tests/sql.rs
index f15e3743256d..303a49ac9b01 100644
--- a/tests-integration/tests/sql.rs
+++ b/tests-integration/tests/sql.rs
@@ -72,6 +72,7 @@ macro_rules! sql_tests {
             test_postgres_parameter_inference,
             test_postgres_array_types,
             test_mysql_prepare_stmt_insert_timestamp,
+            test_declare_fetch_close_cursor,
         );
         )*
     };
@@ -1198,3 +1199,66 @@ pub async fn test_postgres_array_types(store_type: StorageType) {
     let _ = fe_pg_server.shutdown().await;
     guard.remove_all().await;
 }
+
+pub async fn test_declare_fetch_close_cursor(store_type: StorageType) {
+    let (addr, mut guard, fe_pg_server) = setup_pg_server(store_type, "sql_inference").await;
+
+    let (client, connection) = tokio_postgres::connect(&format!("postgres://{addr}/public"), NoTls)
+        .await
+        .unwrap();
+
+    let (tx, rx) = tokio::sync::oneshot::channel();
+    tokio::spawn(async move {
+        connection.await.unwrap();
+        tx.send(()).unwrap();
+    });
+
+    client
+        .execute(
+            "DECLARE c1 CURSOR FOR SELECT * FROM numbers WHERE number > 2 LIMIT 50::bigint",
+            &[],
+        )
+        .await
+        .expect("declare cursor");
+
+    // duplicated cursor
+    assert!(client
+        .execute("DECLARE c1 CURSOR FOR SELECT 1", &[],)
+        .await
+        .is_err());
+
+    let rows = client.query("FETCH 5 FROM c1", &[]).await.unwrap();
+    assert_eq!(5, rows.len());
+
+    let rows = client.query("FETCH 100 FROM c1", &[]).await.unwrap();
+    assert_eq!(45, rows.len());
+
+    let rows = client.query("FETCH 100 FROM c1", &[]).await.unwrap();
+    assert_eq!(0, rows.len());
+
+    client.execute("CLOSE c1", &[]).await.expect("close cursor");
+
+    // cursor not found
+    let result = client.query("FETCH 100 FROM c1", &[]).await;
+    assert!(result.is_err());
+
+    client
+        .execute(
+            "DECLARE c2 CURSOR FOR SELECT * FROM numbers WHERE number < 0",
+            &[],
+        )
+        .await
+        .expect("declare cursor");
+
+    let rows = client.query("FETCH 5 FROM c2", &[]).await.unwrap();
+    assert_eq!(0, rows.len());
+
+    client.execute("CLOSE c2", &[]).await.expect("close cursor");
+
+    // Shutdown the client.
+    drop(client);
+    rx.await.unwrap();
+
+    let _ = fe_pg_server.shutdown().await;
+    guard.remove_all().await;
+}
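
For reference, the cursor flow exercised by `test_declare_fetch_close_cursor` above, as it would look from psql (illustrative transcript, not part of the patch; the row counts follow from the `LIMIT 50` in the declared query):

    DECLARE c1 CURSOR FOR SELECT * FROM numbers WHERE number > 2 LIMIT 50::bigint;
    FETCH 5 FROM c1;    -- returns 5 rows
    FETCH 100 FROM c1;  -- returns the remaining 45 rows
    FETCH 100 FROM c1;  -- returns 0 rows; the cursor is exhausted
    CLOSE c1;
    FETCH 100 FROM c1;  -- error: cursor c1 is not found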