Skip to content

Commit

Permalink
feat: implement cursor storage and execution
Browse files Browse the repository at this point in the history
  • Loading branch information
sunng87 committed Dec 4, 2024
1 parent 34ede04 commit 780fc03
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 19 deletions.
6 changes: 5 additions & 1 deletion src/operator/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,9 @@ pub enum Error {
#[snafu(source)]
error: Elapsed,
},

#[snafu(display("Cursor {name} is not found"))]
CursorNotFound { name: String },
}

pub type Result<T> = std::result::Result<T, Error>;
Expand Down Expand Up @@ -825,7 +828,8 @@ impl ErrorExt for Error {
| Error::FunctionArityMismatch { .. }
| Error::InvalidPartition { .. }
| Error::PhysicalExpr { .. }
| Error::InvalidJsonFormat { .. } => StatusCode::InvalidArguments,
| Error::InvalidJsonFormat { .. }
| Error::CursorNotFound { .. } => StatusCode::InvalidArguments,

Error::TableAlreadyExists { .. } | Error::ViewAlreadyExists { .. } => {
StatusCode::TableAlreadyExists
Expand Down
41 changes: 36 additions & 5 deletions src/operator/src/statement/cursor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,43 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use common_query::Output;
use common_query::{Output, OutputData};
use common_recordbatch::cursor::RecordBatchStreamCursor;
use common_recordbatch::RecordBatches;
use query::parser::QueryStatement;
use session::context::QueryContextRef;
use snafu::ResultExt;
use sql::statements::cursor::{CloseCursor, DeclareCursor, FetchCursor};
use sql::statements::statement::Statement;

use crate::error::{self, Result};
use crate::statement::StatementExecutor;

use crate::error::Result;

impl StatementExecutor {
pub(super) async fn declare_cursor(
&self,
declare_cursor: DeclareCursor,
query_ctx: QueryContextRef,
) -> Result<Output> {
let cursor_name = fetch_cursor.cursor_name.to_string();
let cursor_name = declare_cursor.cursor_name.to_string();
let query_stmt = Statement::Query(declare_cursor.query);

let output = self
.plan_exec(QueryStatement::Sql(query_stmt), query_ctx.clone())
.await?;
match output.data {
OutputData::RecordBatches(rb) => {
let rbs = rb.as_stream();
query_ctx.insert_cursor(cursor_name, RecordBatchStreamCursor::new(rbs));
}
OutputData::Stream(rbs) => {
query_ctx.insert_cursor(cursor_name, RecordBatchStreamCursor::new(rbs));
}
OutputData::AffectedRows(_) => error::NotSupportedSnafu {
feat: "Non-query statement on cursor",
}
.fail()?,
}

Ok(Output::new_with_affected_rows(0))
}
Expand All @@ -37,9 +59,18 @@ impl StatementExecutor {
query_ctx: QueryContextRef,
) -> Result<Output> {
let cursor_name = fetch_cursor.cursor_name.to_string();
let fetch_size = fetch_cursor.fetch_size;
if let Some(rb) = query_ctx.get_cursor(&cursor_name) {
let record_batch = rb
.take(fetch_size as usize)
.await
.context(error::BuildRecordBatchSnafu)?;
let record_batches =
RecordBatches::try_new(record_batch.schema.clone(), vec![record_batch])
.context(error::BuildRecordBatchSnafu)?;
Ok(Output::new_with_record_batches(record_batches))
} else {
todo!("cursor not found")
error::CursorNotFoundSnafu { name: cursor_name }.fail()
}
}

Expand Down
6 changes: 3 additions & 3 deletions src/session/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use arc_swap::ArcSwap;
use auth::UserInfoRef;
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_catalog::{build_db_string, parse_catalog_and_schema_from_db_string};
use common_recordbatch::SendableRecordBatchStream;
use common_recordbatch::cursor::RecordBatchStreamCursor;
use common_time::timezone::parse_timezone;
use common_time::Timezone;
use derive_builder::Builder;
Expand Down Expand Up @@ -301,7 +301,7 @@ impl QueryContext {
self.mutable_session_data.write().unwrap().query_timeout = Some(timeout);
}

pub fn insert_cursor(&self, name: String, rb: SendableRecordBatchStream) {
pub fn insert_cursor(&self, name: String, rb: RecordBatchStreamCursor) {
let mut guard = self.mutable_session_data.write().unwrap();
guard.cursors.insert(name, Arc::new(rb));
}
Expand All @@ -311,7 +311,7 @@ impl QueryContext {
guard.cursors.remove(name);
}

pub fn get_cursor(&self, name: &str) -> Option<Arc<SendableRecordBatchStream>> {
pub fn get_cursor(&self, name: &str) -> Option<Arc<RecordBatchStreamCursor>> {
let guard = self.mutable_session_data.read().unwrap();
let rb = guard.cursors.get(name);
rb.cloned()
Expand Down
6 changes: 3 additions & 3 deletions src/session/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ pub mod context;
pub mod session_config;
pub mod table_name;

use derive_more::Debug;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::{Arc, RwLock};
Expand All @@ -25,10 +24,11 @@ use std::time::Duration;
use auth::UserInfoRef;
use common_catalog::build_db_string;
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_recordbatch::SendableRecordBatchStream;
use common_recordbatch::cursor::RecordBatchStreamCursor;
use common_time::timezone::get_timezone;
use common_time::Timezone;
use context::{ConfigurationVariables, QueryContextBuilder};
use derive_more::Debug;

use crate::context::{Channel, ConnInfo, QueryContextRef};

Expand All @@ -51,7 +51,7 @@ pub(crate) struct MutableInner {
timezone: Timezone,
query_timeout: Option<Duration>,
#[debug(skip)]
pub(crate) cursors: HashMap<String, Arc<SendableRecordBatchStream>>,
pub(crate) cursors: HashMap<String, Arc<RecordBatchStreamCursor>>,
}

impl Default for MutableInner {
Expand Down
10 changes: 6 additions & 4 deletions src/sql/src/parsers/cursor_parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,10 @@ impl ParserContext<'_> {
pub(crate) fn parse_fetch_cursor(&mut self) -> Result<Statement> {
let _ = self.parser.expect_keyword(Keyword::FETCH);

let fetch_size = self.parser.parse_expr().context(error::SyntaxSnafu)?;

let fetch_size = self
.parser
.parse_literal_uint()
.context(error::SyntaxSnafu)?;
let _ = self.parser.parse_keyword(Keyword::FROM);

let cursor_name = self
Expand Down Expand Up @@ -76,9 +78,9 @@ impl ParserContext<'_> {
#[cfg(test)]
mod tests {

use crate::{dialect::GreptimeDbDialect, parser::ParseOptions};

use super::*;
use crate::dialect::GreptimeDbDialect;
use crate::parser::ParseOptions;

#[test]
fn test_parse_declare_cursor() {
Expand Down
2 changes: 1 addition & 1 deletion src/sql/src/statements.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ pub fn sql_number_to_value(data_type: &ConcreteDataType, n: &str) -> Result<Valu
// TODO(hl): also Date/DateTime
}

fn parse_sql_number<R: FromStr + std::fmt::Debug>(n: &str) -> Result<R>
pub(crate) fn parse_sql_number<R: FromStr + std::fmt::Debug>(n: &str) -> Result<R>
where
<R as FromStr>::Err: std::fmt::Debug,
{
Expand Down
4 changes: 2 additions & 2 deletions src/sql/src/statements/cursor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

use std::fmt::Display;

use sqlparser::ast::{Expr, ObjectName};
use sqlparser::ast::ObjectName;
use sqlparser_derive::{Visit, VisitMut};

use super::query::Query;
Expand All @@ -34,7 +34,7 @@ impl Display for DeclareCursor {
#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut)]
pub struct FetchCursor {
pub cursor_name: ObjectName,
pub fetch_size: Expr,
pub fetch_size: u64,
}

impl Display for FetchCursor {
Expand Down

0 comments on commit 780fc03

Please sign in to comment.