Skip to content

Commit a464f8e

Browse files
aharpervciffyio
andauthored
Improve support for cursors for SQL Server (#1831)
Co-authored-by: Ifeanyi Ubah <[email protected]>
1 parent 483394c commit a464f8e

File tree

7 files changed

+289
-20
lines changed

7 files changed

+289
-20
lines changed

src/ast/mod.rs

+86-4
Original file line numberDiff line numberDiff line change
@@ -2228,7 +2228,33 @@ impl fmt::Display for IfStatement {
22282228
}
22292229
}
22302230

2231-
/// A block within a [Statement::Case] or [Statement::If]-like statement
2231+
/// A `WHILE` statement.
2232+
///
2233+
/// Example:
2234+
/// ```sql
2235+
/// WHILE @@FETCH_STATUS = 0
2236+
/// BEGIN
2237+
/// FETCH NEXT FROM c1 INTO @var1, @var2;
2238+
/// END
2239+
/// ```
2240+
///
2241+
/// [MsSql](https://learn.microsoft.com/en-us/sql/t-sql/language-elements/while-transact-sql)
2242+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
2243+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
2244+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
2245+
pub struct WhileStatement {
2246+
pub while_block: ConditionalStatementBlock,
2247+
}
2248+
2249+
impl fmt::Display for WhileStatement {
2250+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2251+
let WhileStatement { while_block } = self;
2252+
write!(f, "{while_block}")?;
2253+
Ok(())
2254+
}
2255+
}
2256+
2257+
/// A block within a [Statement::Case] or [Statement::If] or [Statement::While]-like statement
22322258
///
22332259
/// Example 1:
22342260
/// ```sql
@@ -2244,6 +2270,14 @@ impl fmt::Display for IfStatement {
22442270
/// ```sql
22452271
/// ELSE SELECT 1; SELECT 2;
22462272
/// ```
2273+
///
2274+
/// Example 4:
2275+
/// ```sql
2276+
/// WHILE @@FETCH_STATUS = 0
2277+
/// BEGIN
2278+
/// FETCH NEXT FROM c1 INTO @var1, @var2;
2279+
/// END
2280+
/// ```
22472281
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
22482282
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
22492283
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
@@ -2983,6 +3017,8 @@ pub enum Statement {
29833017
Case(CaseStatement),
29843018
/// An `IF` statement.
29853019
If(IfStatement),
3020+
/// A `WHILE` statement.
3021+
While(WhileStatement),
29863022
/// A `RAISE` statement.
29873023
Raise(RaiseStatement),
29883024
/// ```sql
@@ -3034,6 +3070,11 @@ pub enum Statement {
30343070
partition: Option<Box<Expr>>,
30353071
},
30363072
/// ```sql
3073+
/// OPEN cursor_name
3074+
/// ```
3075+
/// Opens a cursor.
3076+
Open(OpenStatement),
3077+
/// ```sql
30373078
/// CLOSE
30383079
/// ```
30393080
/// Closes the portal underlying an open cursor.
@@ -3413,6 +3454,7 @@ pub enum Statement {
34133454
/// Cursor name
34143455
name: Ident,
34153456
direction: FetchDirection,
3457+
position: FetchPosition,
34163458
/// Optional, It's possible to fetch rows form cursor to the table
34173459
into: Option<ObjectName>,
34183460
},
@@ -4235,11 +4277,10 @@ impl fmt::Display for Statement {
42354277
Statement::Fetch {
42364278
name,
42374279
direction,
4280+
position,
42384281
into,
42394282
} => {
4240-
write!(f, "FETCH {direction} ")?;
4241-
4242-
write!(f, "IN {name}")?;
4283+
write!(f, "FETCH {direction} {position} {name}")?;
42434284

42444285
if let Some(into) = into {
42454286
write!(f, " INTO {into}")?;
@@ -4329,6 +4370,9 @@ impl fmt::Display for Statement {
43294370
Statement::If(stmt) => {
43304371
write!(f, "{stmt}")
43314372
}
4373+
Statement::While(stmt) => {
4374+
write!(f, "{stmt}")
4375+
}
43324376
Statement::Raise(stmt) => {
43334377
write!(f, "{stmt}")
43344378
}
@@ -4498,6 +4542,7 @@ impl fmt::Display for Statement {
44984542
Ok(())
44994543
}
45004544
Statement::Delete(delete) => write!(f, "{delete}"),
4545+
Statement::Open(open) => write!(f, "{open}"),
45014546
Statement::Close { cursor } => {
45024547
write!(f, "CLOSE {cursor}")?;
45034548

@@ -6187,6 +6232,28 @@ impl fmt::Display for FetchDirection {
61876232
}
61886233
}
61896234

6235+
/// The "position" for a FETCH statement.
6236+
///
6237+
/// [MsSql](https://learn.microsoft.com/en-us/sql/t-sql/language-elements/fetch-transact-sql)
6238+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
6239+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
6240+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
6241+
pub enum FetchPosition {
6242+
From,
6243+
In,
6244+
}
6245+
6246+
impl fmt::Display for FetchPosition {
6247+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
6248+
match self {
6249+
FetchPosition::From => f.write_str("FROM")?,
6250+
FetchPosition::In => f.write_str("IN")?,
6251+
};
6252+
6253+
Ok(())
6254+
}
6255+
}
6256+
61906257
/// A privilege on a database object (table, sequence, etc.).
61916258
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
61926259
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
@@ -9354,6 +9421,21 @@ pub enum ReturnStatementValue {
93549421
Expr(Expr),
93559422
}
93569423

9424+
/// Represents an `OPEN` statement.
9425+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
9426+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
9427+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
9428+
pub struct OpenStatement {
9429+
/// Cursor name
9430+
pub cursor_name: Ident,
9431+
}
9432+
9433+
impl fmt::Display for OpenStatement {
9434+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
9435+
write!(f, "OPEN {}", self.cursor_name)
9436+
}
9437+
}
9438+
93579439
#[cfg(test)]
93589440
mod tests {
93599441
use super::*;

src/ast/spans.rs

+24-7
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,13 @@ use super::{
3131
FunctionArguments, GroupByExpr, HavingBound, IfStatement, IlikeSelectItem, Insert, Interpolate,
3232
InterpolateExpr, Join, JoinConstraint, JoinOperator, JsonPath, JsonPathElem, LateralView,
3333
LimitClause, MatchRecognizePattern, Measure, NamedWindowDefinition, ObjectName, ObjectNamePart,
34-
Offset, OnConflict, OnConflictAction, OnInsert, OrderBy, OrderByExpr, OrderByKind, Partition,
35-
PivotValueSource, ProjectionSelect, Query, RaiseStatement, RaiseStatementValue,
36-
ReferentialAction, RenameSelectItem, ReplaceSelectElement, ReplaceSelectItem, Select,
37-
SelectInto, SelectItem, SetExpr, SqlOption, Statement, Subscript, SymbolDefinition, TableAlias,
38-
TableAliasColumnDef, TableConstraint, TableFactor, TableObject, TableOptionsClustered,
39-
TableWithJoins, UpdateTableFromKind, Use, Value, Values, ViewColumnDef,
40-
WildcardAdditionalOptions, With, WithFill,
34+
Offset, OnConflict, OnConflictAction, OnInsert, OpenStatement, OrderBy, OrderByExpr,
35+
OrderByKind, Partition, PivotValueSource, ProjectionSelect, Query, RaiseStatement,
36+
RaiseStatementValue, ReferentialAction, RenameSelectItem, ReplaceSelectElement,
37+
ReplaceSelectItem, Select, SelectInto, SelectItem, SetExpr, SqlOption, Statement, Subscript,
38+
SymbolDefinition, TableAlias, TableAliasColumnDef, TableConstraint, TableFactor, TableObject,
39+
TableOptionsClustered, TableWithJoins, UpdateTableFromKind, Use, Value, Values, ViewColumnDef,
40+
WhileStatement, WildcardAdditionalOptions, With, WithFill,
4141
};
4242

4343
/// Given an iterator of spans, return the [Span::union] of all spans.
@@ -339,6 +339,7 @@ impl Spanned for Statement {
339339
} => source.span(),
340340
Statement::Case(stmt) => stmt.span(),
341341
Statement::If(stmt) => stmt.span(),
342+
Statement::While(stmt) => stmt.span(),
342343
Statement::Raise(stmt) => stmt.span(),
343344
Statement::Call(function) => function.span(),
344345
Statement::Copy {
@@ -365,6 +366,7 @@ impl Spanned for Statement {
365366
from_query: _,
366367
partition: _,
367368
} => Span::empty(),
369+
Statement::Open(open) => open.span(),
368370
Statement::Close { cursor } => match cursor {
369371
CloseCursor::All => Span::empty(),
370372
CloseCursor::Specific { name } => name.span,
@@ -776,6 +778,14 @@ impl Spanned for IfStatement {
776778
}
777779
}
778780

781+
impl Spanned for WhileStatement {
782+
fn span(&self) -> Span {
783+
let WhileStatement { while_block } = self;
784+
785+
while_block.span()
786+
}
787+
}
788+
779789
impl Spanned for ConditionalStatements {
780790
fn span(&self) -> Span {
781791
match self {
@@ -2297,6 +2307,13 @@ impl Spanned for BeginEndStatements {
22972307
}
22982308
}
22992309

2310+
impl Spanned for OpenStatement {
2311+
fn span(&self) -> Span {
2312+
let OpenStatement { cursor_name } = self;
2313+
cursor_name.span
2314+
}
2315+
}
2316+
23002317
#[cfg(test)]
23012318
pub mod tests {
23022319
use crate::dialect::{Dialect, GenericDialect, SnowflakeDialect};

src/keywords.rs

+2
Original file line numberDiff line numberDiff line change
@@ -985,6 +985,7 @@ define_keywords!(
985985
WHEN,
986986
WHENEVER,
987987
WHERE,
988+
WHILE,
988989
WIDTH_BUCKET,
989990
WINDOW,
990991
WITH,
@@ -1068,6 +1069,7 @@ pub const RESERVED_FOR_TABLE_ALIAS: &[Keyword] = &[
10681069
Keyword::SAMPLE,
10691070
Keyword::TABLESAMPLE,
10701071
Keyword::FROM,
1072+
Keyword::OPEN,
10711073
];
10721074

10731075
/// Can't be used as a column alias, so that `SELECT <expr> alias`

src/parser/mod.rs

+63-7
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,10 @@ impl<'a> Parser<'a> {
536536
self.prev_token();
537537
self.parse_if_stmt()
538538
}
539+
Keyword::WHILE => {
540+
self.prev_token();
541+
self.parse_while()
542+
}
539543
Keyword::RAISE => {
540544
self.prev_token();
541545
self.parse_raise_stmt()
@@ -570,6 +574,10 @@ impl<'a> Parser<'a> {
570574
Keyword::ALTER => self.parse_alter(),
571575
Keyword::CALL => self.parse_call(),
572576
Keyword::COPY => self.parse_copy(),
577+
Keyword::OPEN => {
578+
self.prev_token();
579+
self.parse_open()
580+
}
573581
Keyword::CLOSE => self.parse_close(),
574582
Keyword::SET => self.parse_set(),
575583
Keyword::SHOW => self.parse_show(),
@@ -700,8 +708,18 @@ impl<'a> Parser<'a> {
700708
}))
701709
}
702710

711+
/// Parse a `WHILE` statement.
712+
///
713+
/// See [Statement::While]
714+
fn parse_while(&mut self) -> Result<Statement, ParserError> {
715+
self.expect_keyword_is(Keyword::WHILE)?;
716+
let while_block = self.parse_conditional_statement_block(&[Keyword::END])?;
717+
718+
Ok(Statement::While(WhileStatement { while_block }))
719+
}
720+
703721
/// Parses an expression and associated list of statements
704-
/// belonging to a conditional statement like `IF` or `WHEN`.
722+
/// belonging to a conditional statement like `IF` or `WHEN` or `WHILE`.
705723
///
706724
/// Example:
707725
/// ```sql
@@ -716,20 +734,36 @@ impl<'a> Parser<'a> {
716734

717735
let condition = match &start_token.token {
718736
Token::Word(w) if w.keyword == Keyword::ELSE => None,
737+
Token::Word(w) if w.keyword == Keyword::WHILE => {
738+
let expr = self.parse_expr()?;
739+
Some(expr)
740+
}
719741
_ => {
720742
let expr = self.parse_expr()?;
721743
then_token = Some(AttachedToken(self.expect_keyword(Keyword::THEN)?));
722744
Some(expr)
723745
}
724746
};
725747

726-
let statements = self.parse_statement_list(terminal_keywords)?;
748+
let conditional_statements = if self.peek_keyword(Keyword::BEGIN) {
749+
let begin_token = self.expect_keyword(Keyword::BEGIN)?;
750+
let statements = self.parse_statement_list(terminal_keywords)?;
751+
let end_token = self.expect_keyword(Keyword::END)?;
752+
ConditionalStatements::BeginEnd(BeginEndStatements {
753+
begin_token: AttachedToken(begin_token),
754+
statements,
755+
end_token: AttachedToken(end_token),
756+
})
757+
} else {
758+
let statements = self.parse_statement_list(terminal_keywords)?;
759+
ConditionalStatements::Sequence { statements }
760+
};
727761

728762
Ok(ConditionalStatementBlock {
729763
start_token: AttachedToken(start_token),
730764
condition,
731765
then_token,
732-
conditional_statements: ConditionalStatements::Sequence { statements },
766+
conditional_statements,
733767
})
734768
}
735769

@@ -4467,11 +4501,16 @@ impl<'a> Parser<'a> {
44674501
) -> Result<Vec<Statement>, ParserError> {
44684502
let mut values = vec![];
44694503
loop {
4470-
if let Token::Word(w) = &self.peek_nth_token_ref(0).token {
4471-
if w.quote_style.is_none() && terminal_keywords.contains(&w.keyword) {
4472-
break;
4504+
match &self.peek_nth_token_ref(0).token {
4505+
Token::EOF => break,
4506+
Token::Word(w) => {
4507+
if w.quote_style.is_none() && terminal_keywords.contains(&w.keyword) {
4508+
break;
4509+
}
44734510
}
4511+
_ => {}
44744512
}
4513+
44754514
values.push(self.parse_statement()?);
44764515
self.expect_token(&Token::SemiColon)?;
44774516
}
@@ -6644,7 +6683,15 @@ impl<'a> Parser<'a> {
66446683
}
66456684
};
66466685

6647-
self.expect_one_of_keywords(&[Keyword::FROM, Keyword::IN])?;
6686+
let position = if self.peek_keyword(Keyword::FROM) {
6687+
self.expect_keyword(Keyword::FROM)?;
6688+
FetchPosition::From
6689+
} else if self.peek_keyword(Keyword::IN) {
6690+
self.expect_keyword(Keyword::IN)?;
6691+
FetchPosition::In
6692+
} else {
6693+
return parser_err!("Expected FROM or IN", self.peek_token().span.start);
6694+
};
66486695

66496696
let name = self.parse_identifier()?;
66506697

@@ -6657,6 +6704,7 @@ impl<'a> Parser<'a> {
66576704
Ok(Statement::Fetch {
66586705
name,
66596706
direction,
6707+
position,
66606708
into,
66616709
})
66626710
}
@@ -8770,6 +8818,14 @@ impl<'a> Parser<'a> {
87708818
})
87718819
}
87728820

8821+
/// Parse [Statement::Open]
8822+
fn parse_open(&mut self) -> Result<Statement, ParserError> {
8823+
self.expect_keyword(Keyword::OPEN)?;
8824+
Ok(Statement::Open(OpenStatement {
8825+
cursor_name: self.parse_identifier()?,
8826+
}))
8827+
}
8828+
87738829
pub fn parse_close(&mut self) -> Result<Statement, ParserError> {
87748830
let cursor = if self.parse_keyword(Keyword::ALL) {
87758831
CloseCursor::All

0 commit comments

Comments
 (0)