diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 790a39bdb..0573240a2 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -3208,7 +3208,7 @@ pub enum Statement { /// Table confs options: Vec, /// Cache table as a Query - query: Option, + query: Option>, }, /// ```sql /// UNCACHE TABLE [ IF EXISTS ] @@ -6883,7 +6883,7 @@ impl fmt::Display for MacroArg { #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] pub enum MacroDefinition { Expr(Expr), - Table(Query), + Table(Box), } impl fmt::Display for MacroDefinition { diff --git a/src/ast/query.rs b/src/ast/query.rs index ec0198674..dc5966e5e 100644 --- a/src/ast/query.rs +++ b/src/ast/query.rs @@ -1103,7 +1103,7 @@ pub enum PivotValueSource { /// Pivot on all values returned by a subquery. /// /// See . - Subquery(Query), + Subquery(Box), } impl fmt::Display for PivotValueSource { diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index be97f929b..28e7ac7d1 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -486,9 +486,9 @@ pub trait Dialect: Debug + Any { fn parse_column_option( &self, _parser: &mut Parser, - ) -> Option, ParserError>> { + ) -> Result, ParserError>>, ParserError> { // return None to fall back to the default behavior - None + Ok(None) } /// Decide the lexical Precedence of operators. diff --git a/src/dialect/snowflake.rs b/src/dialect/snowflake.rs index 7c80f0461..d9331d952 100644 --- a/src/dialect/snowflake.rs +++ b/src/dialect/snowflake.rs @@ -156,7 +156,7 @@ impl Dialect for SnowflakeDialect { fn parse_column_option( &self, parser: &mut Parser, - ) -> Option, ParserError>> { + ) -> Result, ParserError>>, ParserError> { parser.maybe_parse(|parser| { let with = parser.parse_keyword(Keyword::WITH); @@ -247,7 +247,7 @@ pub fn parse_create_table( builder = builder.comment(parser.parse_optional_inline_comment()?); } Keyword::AS => { - let query = parser.parse_boxed_query()?; + let query = parser.parse_query()?; builder = builder.query(Some(query)); break; } diff --git a/src/parser/alter.rs b/src/parser/alter.rs index 28fdaf764..534105790 100644 --- a/src/parser/alter.rs +++ b/src/parser/alter.rs @@ -192,7 +192,7 @@ impl<'a> Parser<'a> { let _ = self.parse_keyword(Keyword::WITH); // option let mut options = vec![]; - while let Some(opt) = self.maybe_parse(|parser| parser.parse_pg_role_option()) { + while let Some(opt) = self.maybe_parse(|parser| parser.parse_pg_role_option())? { options.push(opt); } // check option diff --git a/src/parser/mod.rs b/src/parser/mod.rs index a1079f6f7..a9a5b1df4 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -478,7 +478,7 @@ impl<'a> Parser<'a> { Keyword::ANALYZE => self.parse_analyze(), Keyword::SELECT | Keyword::WITH | Keyword::VALUES => { self.prev_token(); - self.parse_boxed_query().map(Statement::Query) + self.parse_query().map(Statement::Query) } Keyword::TRUNCATE => self.parse_truncate(), Keyword::ATTACH => { @@ -551,7 +551,7 @@ impl<'a> Parser<'a> { }, Token::LParen => { self.prev_token(); - self.parse_boxed_query().map(Statement::Query) + self.parse_query().map(Statement::Query) } _ => self.expected("an SQL statement", next_token), } @@ -662,7 +662,7 @@ impl<'a> Parser<'a> { }; parser.expect_keyword(Keyword::PARTITIONS)?; Ok(pa) - }) + })? .unwrap_or_default(); Ok(Statement::Msck { repair, @@ -829,7 +829,7 @@ impl<'a> Parser<'a> { columns = self .maybe_parse(|parser| { parser.parse_comma_separated(|p| p.parse_identifier(false)) - }) + })? .unwrap_or_default(); for_columns = true } @@ -986,7 +986,7 @@ impl<'a> Parser<'a> { value: parser.parse_literal_string()?, }), } - }); + })?; if let Some(expr) = opt_expr { return Ok(expr); @@ -1061,7 +1061,7 @@ impl<'a> Parser<'a> { && !dialect_of!(self is ClickHouseDialect | DatabricksDialect) => { self.expect_token(&Token::LParen)?; - let query = self.parse_boxed_query()?; + let query = self.parse_query()?; self.expect_token(&Token::RParen)?; Ok(Expr::Function(Function { name: ObjectName(vec![w.to_ident()]), @@ -1228,7 +1228,7 @@ impl<'a> Parser<'a> { Token::LParen => { let expr = if let Some(expr) = self.try_parse_expr_sub_query()? { expr - } else if let Some(lambda) = self.try_parse_lambda() { + } else if let Some(lambda) = self.try_parse_lambda()? { return Ok(lambda); } else { let exprs = self.parse_comma_separated(Parser::parse_expr)?; @@ -1307,12 +1307,12 @@ impl<'a> Parser<'a> { return Ok(None); } - Ok(Some(Expr::Subquery(self.parse_boxed_query()?))) + Ok(Some(Expr::Subquery(self.parse_query()?))) } - fn try_parse_lambda(&mut self) -> Option { + fn try_parse_lambda(&mut self) -> Result, ParserError> { if !self.dialect.supports_lambda_functions() { - return None; + return Ok(None); } self.maybe_parse(|p| { let params = p.parse_comma_separated(|p| p.parse_identifier(false))?; @@ -1332,7 +1332,7 @@ impl<'a> Parser<'a> { // Snowflake permits a subquery to be passed as an argument without // an enclosing set of parens if it's the only argument. if dialect_of!(self is SnowflakeDialect) && self.peek_sub_query() { - let subquery = self.parse_boxed_query()?; + let subquery = self.parse_query()?; self.expect_token(&Token::RParen)?; return Ok(Expr::Function(Function { name, @@ -1697,7 +1697,7 @@ impl<'a> Parser<'a> { self.expect_token(&Token::LParen)?; let exists_node = Expr::Exists { negated, - subquery: self.parse_boxed_query()?, + subquery: self.parse_query()?, }; self.expect_token(&Token::RParen)?; Ok(exists_node) @@ -1777,7 +1777,7 @@ impl<'a> Parser<'a> { expr: Box::new(expr), r#in: Box::new(from), }) - }); + })?; match position_expr { Some(expr) => Ok(expr), // Snowflake supports `position` as an ordinary function call @@ -3032,7 +3032,7 @@ impl<'a> Parser<'a> { self.prev_token(); Expr::InSubquery { expr: Box::new(expr), - subquery: self.parse_boxed_query()?, + subquery: self.parse_query()?, negated, } } else { @@ -3513,17 +3513,19 @@ impl<'a> Parser<'a> { } /// Run a parser method `f`, reverting back to the current position if unsuccessful. - #[must_use] - pub fn maybe_parse(&mut self, mut f: F) -> Option + pub fn maybe_parse(&mut self, mut f: F) -> Result, ParserError> where F: FnMut(&mut Parser) -> Result, { let index = self.index; - if let Ok(t) = f(self) { - Some(t) - } else { - self.index = index; - None + match f(self) { + Ok(t) => Ok(Some(t)), + // Unwind stack if limit exceeded + Err(ParserError::RecursionLimitExceeded) => Err(ParserError::RecursionLimitExceeded), + Err(_) => { + self.index = index; + Ok(None) + } } } @@ -3759,7 +3761,7 @@ impl<'a> Parser<'a> { } /// Parse 'AS' before as query,such as `WITH XXX AS SELECT XXX` oer `CACHE TABLE AS SELECT XXX` - pub fn parse_as_query(&mut self) -> Result<(bool, Query), ParserError> { + pub fn parse_as_query(&mut self) -> Result<(bool, Box), ParserError> { match self.peek_token().token { Token::Word(word) => match word.keyword { Keyword::AS => { @@ -4523,7 +4525,7 @@ impl<'a> Parser<'a> { }; self.expect_keyword(Keyword::AS)?; - let query = self.parse_boxed_query()?; + let query = self.parse_query()?; // Optional `WITH [ CASCADED | LOCAL ] CHECK OPTION` is widely supported here. let with_no_schema_binding = dialect_of!(self is RedshiftSqlDialect | GenericDialect) @@ -5102,7 +5104,7 @@ impl<'a> Parser<'a> { self.expect_keyword(Keyword::FOR)?; - let query = Some(self.parse_boxed_query()?); + let query = Some(self.parse_query()?); Ok(Statement::Declare { stmts: vec![Declare { @@ -5196,7 +5198,7 @@ impl<'a> Parser<'a> { match self.peek_token().token { Token::Word(w) if w.keyword == Keyword::SELECT => ( Some(DeclareType::Cursor), - Some(self.parse_boxed_query()?), + Some(self.parse_query()?), None, None, ), @@ -5889,7 +5891,7 @@ impl<'a> Parser<'a> { // Parse optional `AS ( query )` let query = if self.parse_keyword(Keyword::AS) { - Some(self.parse_boxed_query()?) + Some(self.parse_query()?) } else { None }; @@ -6109,7 +6111,7 @@ impl<'a> Parser<'a> { } pub fn parse_optional_column_option(&mut self) -> Result, ParserError> { - if let Some(option) = self.dialect.parse_column_option(self) { + if let Some(option) = self.dialect.parse_column_option(self)? { return option; } @@ -6483,7 +6485,7 @@ impl<'a> Parser<'a> { } // optional index name - let index_name = self.parse_optional_indent(); + let index_name = self.parse_optional_indent()?; let index_type = self.parse_optional_using_then_index_type()?; let columns = self.parse_parenthesized_column_list(Mandatory, false)?; @@ -6504,7 +6506,7 @@ impl<'a> Parser<'a> { self.expect_keyword(Keyword::KEY)?; // optional index name - let index_name = self.parse_optional_indent(); + let index_name = self.parse_optional_indent()?; let index_type = self.parse_optional_using_then_index_type()?; let columns = self.parse_parenthesized_column_list(Mandatory, false)?; @@ -6566,7 +6568,7 @@ impl<'a> Parser<'a> { let name = match self.peek_token().token { Token::Word(word) if word.keyword == Keyword::USING => None, - _ => self.parse_optional_indent(), + _ => self.parse_optional_indent()?, }; let index_type = self.parse_optional_using_then_index_type()?; @@ -6597,7 +6599,7 @@ impl<'a> Parser<'a> { let index_type_display = self.parse_index_type_display(); - let opt_index_name = self.parse_optional_indent(); + let opt_index_name = self.parse_optional_indent()?; let columns = self.parse_parenthesized_column_list(Mandatory, false)?; @@ -6679,7 +6681,7 @@ impl<'a> Parser<'a> { /// Parse `[ident]`, mostly `ident` is name, like: /// `window_name`, `index_name`, ... - pub fn parse_optional_indent(&mut self) -> Option { + pub fn parse_optional_indent(&mut self) -> Result, ParserError> { self.maybe_parse(|parser| parser.parse_identifier(false)) } @@ -7278,7 +7280,7 @@ impl<'a> Parser<'a> { let with_options = self.parse_options(Keyword::WITH)?; self.expect_keyword(Keyword::AS)?; - let query = self.parse_boxed_query()?; + let query = self.parse_query()?; Ok(Statement::AlterView { name, @@ -7317,7 +7319,7 @@ impl<'a> Parser<'a> { pub fn parse_copy(&mut self) -> Result { let source; if self.consume_token(&Token::LParen) { - source = CopySource::Query(self.parse_boxed_query()?); + source = CopySource::Query(self.parse_query()?); self.expect_token(&Token::RParen)?; } else { let table_name = self.parse_object_name(false)?; @@ -7361,7 +7363,7 @@ impl<'a> Parser<'a> { self.expect_token(&Token::RParen)?; } let mut legacy_options = vec![]; - while let Some(opt) = self.maybe_parse(|parser| parser.parse_copy_legacy_option()) { + while let Some(opt) = self.maybe_parse(|parser| parser.parse_copy_legacy_option())? { legacy_options.push(opt); } let values = if let CopyTarget::Stdin = target { @@ -7453,7 +7455,7 @@ impl<'a> Parser<'a> { Some(Keyword::CSV) => CopyLegacyOption::Csv({ let mut opts = vec![]; while let Some(opt) = - self.maybe_parse(|parser| parser.parse_copy_legacy_csv_option()) + self.maybe_parse(|parser| parser.parse_copy_legacy_csv_option())? { opts.push(opt); } @@ -8035,7 +8037,7 @@ impl<'a> Parser<'a> { // Keyword::ARRAY syntax from above while self.consume_token(&Token::LBracket) { let size = if dialect_of!(self is GenericDialect | DuckDbDialect | PostgreSqlDialect) { - self.maybe_parse(|p| p.parse_literal_uint()) + self.maybe_parse(|p| p.parse_literal_uint())? } else { None }; @@ -8712,7 +8714,7 @@ impl<'a> Parser<'a> { } } - match self.maybe_parse(|parser| parser.parse_statement()) { + match self.maybe_parse(|parser| parser.parse_statement())? { Some(Statement::Explain { .. }) | Some(Statement::ExplainTable { .. }) => Err( ParserError::ParserError("Explain must be root of the plan".to_string()), ), @@ -8751,20 +8753,11 @@ impl<'a> Parser<'a> { } } - /// Call's [`Self::parse_query`] returning a `Box`'ed result. - /// - /// This function can be used to reduce the stack size required in debug - /// builds. Instead of `sizeof(Query)` only a pointer (`Box`) - /// is used. - pub fn parse_boxed_query(&mut self) -> Result, ParserError> { - self.parse_query().map(Box::new) - } - /// Parse a query expression, i.e. a `SELECT` statement optionally /// preceded with some `WITH` CTE declarations and optionally followed /// by `ORDER BY`. Unlike some other parse_... methods, this one doesn't /// expect the initial keyword to be already consumed - pub fn parse_query(&mut self) -> Result { + pub fn parse_query(&mut self) -> Result, ParserError> { let _guard = self.recursion_counter.try_decrease()?; let with = if self.parse_keyword(Keyword::WITH) { Some(With { @@ -8787,7 +8780,8 @@ impl<'a> Parser<'a> { for_clause: None, settings: None, format_clause: None, - }) + } + .into()) } else if self.parse_keyword(Keyword::UPDATE) { Ok(Query { with, @@ -8801,9 +8795,10 @@ impl<'a> Parser<'a> { for_clause: None, settings: None, format_clause: None, - }) + } + .into()) } else { - let body = self.parse_boxed_query_body(self.dialect.prec_unknown())?; + let body = self.parse_query_body(self.dialect.prec_unknown())?; let order_by = self.parse_optional_order_by()?; @@ -8885,7 +8880,8 @@ impl<'a> Parser<'a> { for_clause, settings, format_clause, - }) + } + .into()) } } @@ -9022,7 +9018,7 @@ impl<'a> Parser<'a> { } } self.expect_token(&Token::LParen)?; - let query = self.parse_boxed_query()?; + let query = self.parse_query()?; self.expect_token(&Token::RParen)?; let alias = TableAlias { name, @@ -9046,7 +9042,7 @@ impl<'a> Parser<'a> { } } self.expect_token(&Token::LParen)?; - let query = self.parse_boxed_query()?; + let query = self.parse_query()?; self.expect_token(&Token::RParen)?; let alias = TableAlias { name, columns }; Cte { @@ -9062,15 +9058,6 @@ impl<'a> Parser<'a> { Ok(cte) } - /// Call's [`Self::parse_query_body`] returning a `Box`'ed result. - /// - /// This function can be used to reduce the stack size required in debug - /// builds. Instead of `sizeof(QueryBody)` only a pointer (`Box`) - /// is used. - fn parse_boxed_query_body(&mut self, precedence: u8) -> Result, ParserError> { - self.parse_query_body(precedence).map(Box::new) - } - /// Parse a "query body", which is an expression with roughly the /// following grammar: /// ```sql @@ -9079,17 +9066,14 @@ impl<'a> Parser<'a> { /// subquery ::= query_body [ order_by_limit ] /// set_operation ::= query_body { 'UNION' | 'EXCEPT' | 'INTERSECT' } [ 'ALL' ] query_body /// ``` - /// - /// If you need `Box` then maybe there is sense to use `parse_boxed_query_body` - /// due to prevent stack overflow in debug building(to reserve less memory on stack). - pub fn parse_query_body(&mut self, precedence: u8) -> Result { + pub fn parse_query_body(&mut self, precedence: u8) -> Result, ParserError> { // We parse the expression using a Pratt parser, as in `parse_expr()`. // Start by parsing a restricted SELECT or a `(subquery)`: let expr = if self.parse_keyword(Keyword::SELECT) { SetExpr::Select(self.parse_select().map(Box::new)?) } else if self.consume_token(&Token::LParen) { // CTEs are not allowed here, but the parser currently accepts them - let subquery = self.parse_boxed_query()?; + let subquery = self.parse_query()?; self.expect_token(&Token::RParen)?; SetExpr::Query(subquery) } else if self.parse_keyword(Keyword::VALUES) { @@ -9114,7 +9098,7 @@ impl<'a> Parser<'a> { &mut self, mut expr: SetExpr, precedence: u8, - ) -> Result { + ) -> Result, ParserError> { loop { // The query can be optionally followed by a set operator: let op = self.parse_set_operator(&self.peek_token().token); @@ -9135,11 +9119,11 @@ impl<'a> Parser<'a> { left: Box::new(expr), op: op.unwrap(), set_quantifier, - right: self.parse_boxed_query_body(next_precedence)?, + right: self.parse_query_body(next_precedence)?, }; } - Ok(expr) + Ok(expr.into()) } pub fn parse_set_operator(&mut self, token: &Token) -> Option { @@ -9466,7 +9450,7 @@ impl<'a> Parser<'a> { if let Some(Keyword::HIVEVAR) = modifier { self.expect_token(&Token::Colon)?; } else if let Some(set_role_stmt) = - self.maybe_parse(|parser| parser.parse_set_role(modifier)) + self.maybe_parse(|parser| parser.parse_set_role(modifier))? { return Ok(set_role_stmt); } @@ -9932,7 +9916,7 @@ impl<'a> Parser<'a> { // subquery, followed by the closing ')', and the alias of the derived table. // In the example above this is case (3). if let Some(mut table) = - self.maybe_parse(|parser| parser.parse_derived_table_factor(NotLateral)) + self.maybe_parse(|parser| parser.parse_derived_table_factor(NotLateral))? { while let Some(kw) = self.parse_one_of_keywords(&[Keyword::PIVOT, Keyword::UNPIVOT]) { @@ -10462,7 +10446,7 @@ impl<'a> Parser<'a> { &mut self, lateral: IsLateral, ) -> Result { - let subquery = self.parse_boxed_query()?; + let subquery = self.parse_query()?; self.expect_token(&Token::RParen)?; let alias = self.parse_optional_table_alias(keywords::RESERVED_FOR_TABLE_ALIAS)?; Ok(TableFactor::Derived { @@ -10836,7 +10820,7 @@ impl<'a> Parser<'a> { } else { None }; - let source = self.parse_boxed_query()?; + let source = self.parse_query()?; Ok(Statement::Directory { local, path, @@ -10872,7 +10856,7 @@ impl<'a> Parser<'a> { vec![] }; - let source = Some(self.parse_boxed_query()?); + let source = Some(self.parse_query()?); (columns, partitioned, after_columns, source) }; @@ -11786,7 +11770,7 @@ impl<'a> Parser<'a> { pub fn parse_unload(&mut self) -> Result { self.expect_token(&Token::LParen)?; - let query = self.parse_boxed_query()?; + let query = self.parse_query()?; self.expect_token(&Token::RParen)?; self.expect_keyword(Keyword::TO)?; @@ -12130,7 +12114,9 @@ impl<'a> Parser<'a> { pub fn parse_window_spec(&mut self) -> Result { let window_name = match self.peek_token().token { - Token::Word(word) if word.keyword == Keyword::NoKeyword => self.parse_optional_indent(), + Token::Word(word) if word.keyword == Keyword::NoKeyword => { + self.parse_optional_indent()? + } _ => None, }; @@ -12342,10 +12328,8 @@ mod tests { #[test] fn test_ansii_character_string_types() { // Character string types: - let dialect = TestedDialects { - dialects: vec![Box::new(GenericDialect {}), Box::new(AnsiDialect {})], - options: None, - }; + let dialect = + TestedDialects::new(vec![Box::new(GenericDialect {}), Box::new(AnsiDialect {})]); test_parse_data_type!(dialect, "CHARACTER", DataType::Character(None)); @@ -12472,10 +12456,8 @@ mod tests { #[test] fn test_ansii_character_large_object_types() { // Character large object types: - let dialect = TestedDialects { - dialects: vec![Box::new(GenericDialect {}), Box::new(AnsiDialect {})], - options: None, - }; + let dialect = + TestedDialects::new(vec![Box::new(GenericDialect {}), Box::new(AnsiDialect {})]); test_parse_data_type!( dialect, @@ -12505,10 +12487,9 @@ mod tests { #[test] fn test_parse_custom_types() { - let dialect = TestedDialects { - dialects: vec![Box::new(GenericDialect {}), Box::new(AnsiDialect {})], - options: None, - }; + let dialect = + TestedDialects::new(vec![Box::new(GenericDialect {}), Box::new(AnsiDialect {})]); + test_parse_data_type!( dialect, "GEOMETRY", @@ -12537,10 +12518,8 @@ mod tests { #[test] fn test_ansii_exact_numeric_types() { // Exact numeric types: - let dialect = TestedDialects { - dialects: vec![Box::new(GenericDialect {}), Box::new(AnsiDialect {})], - options: None, - }; + let dialect = + TestedDialects::new(vec![Box::new(GenericDialect {}), Box::new(AnsiDialect {})]); test_parse_data_type!(dialect, "NUMERIC", DataType::Numeric(ExactNumberInfo::None)); @@ -12588,10 +12567,8 @@ mod tests { #[test] fn test_ansii_date_type() { // Datetime types: - let dialect = TestedDialects { - dialects: vec![Box::new(GenericDialect {}), Box::new(AnsiDialect {})], - options: None, - }; + let dialect = + TestedDialects::new(vec![Box::new(GenericDialect {}), Box::new(AnsiDialect {})]); test_parse_data_type!(dialect, "DATE", DataType::Date); @@ -12700,10 +12677,8 @@ mod tests { }}; } - let dialect = TestedDialects { - dialects: vec![Box::new(GenericDialect {}), Box::new(MySqlDialect {})], - options: None, - }; + let dialect = + TestedDialects::new(vec![Box::new(GenericDialect {}), Box::new(MySqlDialect {})]); test_parse_table_constraint!( dialect, @@ -12822,10 +12797,7 @@ mod tests { #[test] fn test_parse_multipart_identifier_positive() { - let dialect = TestedDialects { - dialects: vec![Box::new(GenericDialect {})], - options: None, - }; + let dialect = TestedDialects::new(vec![Box::new(GenericDialect {})]); // parse multipart with quotes let expected = vec![ diff --git a/src/test_utils.rs b/src/test_utils.rs index e588b3506..b35fc45c2 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -44,6 +44,7 @@ use pretty_assertions::assert_eq; pub struct TestedDialects { pub dialects: Vec>, pub options: Option, + pub recursion_limit: Option, } impl TestedDialects { @@ -52,16 +53,38 @@ impl TestedDialects { Self { dialects, options: None, + recursion_limit: None, } } + pub fn new_with_options(dialects: Vec>, options: ParserOptions) -> Self { + Self { + dialects, + options: Some(options), + recursion_limit: None, + } + } + + pub fn with_recursion_limit(mut self, recursion_limit: usize) -> Self { + self.recursion_limit = Some(recursion_limit); + self + } + fn new_parser<'a>(&self, dialect: &'a dyn Dialect) -> Parser<'a> { let parser = Parser::new(dialect); - if let Some(options) = &self.options { + let parser = if let Some(options) = &self.options { parser.with_options(options.clone()) } else { parser - } + }; + + let parser = if let Some(recursion_limit) = &self.recursion_limit { + parser.with_recursion_limit(*recursion_limit) + } else { + parser + }; + + parser } /// Run the given function for all of `self.dialects`, assert that they diff --git a/tests/sqlparser_bigquery.rs b/tests/sqlparser_bigquery.rs index 63517fe57..2bf470f71 100644 --- a/tests/sqlparser_bigquery.rs +++ b/tests/sqlparser_bigquery.rs @@ -40,10 +40,10 @@ fn parse_literal_string() { r#""""triple-double\"escaped""", "#, r#""""triple-double"unescaped""""#, ); - let dialect = TestedDialects { - dialects: vec![Box::new(BigQueryDialect {})], - options: Some(ParserOptions::new().with_unescape(false)), - }; + let dialect = TestedDialects::new_with_options( + vec![Box::new(BigQueryDialect {})], + ParserOptions::new().with_unescape(false), + ); let select = dialect.verified_only_select(sql); assert_eq!(10, select.projection.len()); assert_eq!( @@ -1936,17 +1936,14 @@ fn parse_big_query_declare() { } fn bigquery() -> TestedDialects { - TestedDialects { - dialects: vec![Box::new(BigQueryDialect {})], - options: None, - } + TestedDialects::new(vec![Box::new(BigQueryDialect {})]) } fn bigquery_and_generic() -> TestedDialects { - TestedDialects { - dialects: vec![Box::new(BigQueryDialect {}), Box::new(GenericDialect {})], - options: None, - } + TestedDialects::new(vec![ + Box::new(BigQueryDialect {}), + Box::new(GenericDialect {}), + ]) } #[test] diff --git a/tests/sqlparser_clickhouse.rs b/tests/sqlparser_clickhouse.rs index e30c33678..f8c349a37 100644 --- a/tests/sqlparser_clickhouse.rs +++ b/tests/sqlparser_clickhouse.rs @@ -1613,15 +1613,12 @@ fn parse_explain_table() { } fn clickhouse() -> TestedDialects { - TestedDialects { - dialects: vec![Box::new(ClickHouseDialect {})], - options: None, - } + TestedDialects::new(vec![Box::new(ClickHouseDialect {})]) } fn clickhouse_and_generic() -> TestedDialects { - TestedDialects { - dialects: vec![Box::new(ClickHouseDialect {}), Box::new(GenericDialect {})], - options: None, - } + TestedDialects::new(vec![ + Box::new(ClickHouseDialect {}), + Box::new(GenericDialect {}), + ]) } diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 5683bcf91..a2eb5070d 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -341,19 +341,16 @@ fn parse_update() { #[test] fn parse_update_set_from() { let sql = "UPDATE t1 SET name = t2.name FROM (SELECT name, id FROM t1 GROUP BY id) AS t2 WHERE t1.id = t2.id"; - let dialects = TestedDialects { - dialects: vec![ - Box::new(GenericDialect {}), - Box::new(DuckDbDialect {}), - Box::new(PostgreSqlDialect {}), - Box::new(BigQueryDialect {}), - Box::new(SnowflakeDialect {}), - Box::new(RedshiftSqlDialect {}), - Box::new(MsSqlDialect {}), - Box::new(SQLiteDialect {}), - ], - options: None, - }; + let dialects = TestedDialects::new(vec![ + Box::new(GenericDialect {}), + Box::new(DuckDbDialect {}), + Box::new(PostgreSqlDialect {}), + Box::new(BigQueryDialect {}), + Box::new(SnowflakeDialect {}), + Box::new(RedshiftSqlDialect {}), + Box::new(MsSqlDialect {}), + Box::new(SQLiteDialect {}), + ]); let stmt = dialects.verified_stmt(sql); assert_eq!( stmt, @@ -1051,10 +1048,7 @@ fn test_eof_after_as() { #[test] fn test_no_infix_error() { - let dialects = TestedDialects { - dialects: vec![Box::new(ClickHouseDialect {})], - options: None, - }; + let dialects = TestedDialects::new(vec![Box::new(ClickHouseDialect {})]); let res = dialects.parse_sql_statements("ASSERT-URA<<"); assert_eq!( @@ -1182,23 +1176,20 @@ fn parse_null_in_select() { #[test] fn parse_exponent_in_select() -> Result<(), ParserError> { // all except Hive, as it allows numbers to start an identifier - let dialects = TestedDialects { - dialects: vec![ - Box::new(AnsiDialect {}), - Box::new(BigQueryDialect {}), - Box::new(ClickHouseDialect {}), - Box::new(DuckDbDialect {}), - Box::new(GenericDialect {}), - // Box::new(HiveDialect {}), - Box::new(MsSqlDialect {}), - Box::new(MySqlDialect {}), - Box::new(PostgreSqlDialect {}), - Box::new(RedshiftSqlDialect {}), - Box::new(SnowflakeDialect {}), - Box::new(SQLiteDialect {}), - ], - options: None, - }; + let dialects = TestedDialects::new(vec![ + Box::new(AnsiDialect {}), + Box::new(BigQueryDialect {}), + Box::new(ClickHouseDialect {}), + Box::new(DuckDbDialect {}), + Box::new(GenericDialect {}), + // Box::new(HiveDialect {}), + Box::new(MsSqlDialect {}), + Box::new(MySqlDialect {}), + Box::new(PostgreSqlDialect {}), + Box::new(RedshiftSqlDialect {}), + Box::new(SnowflakeDialect {}), + Box::new(SQLiteDialect {}), + ]); let sql = "SELECT 10e-20, 1e3, 1e+3, 1e3a, 1e, 0.5e2"; let mut select = dialects.parse_sql_statements(sql)?; @@ -1271,14 +1262,12 @@ fn parse_escaped_single_quote_string_predicate_with_no_escape() { let sql = "SELECT id, fname, lname FROM customer \ WHERE salary <> 'Jim''s salary'"; - let ast = TestedDialects { - dialects: vec![Box::new(MySqlDialect {})], - options: Some( - ParserOptions::new() - .with_trailing_commas(true) - .with_unescape(false), - ), - } + let ast = TestedDialects::new_with_options( + vec![Box::new(MySqlDialect {})], + ParserOptions::new() + .with_trailing_commas(true) + .with_unescape(false), + ) .verified_only_select(sql); assert_eq!( @@ -1400,10 +1389,10 @@ fn parse_mod() { } fn pg_and_generic() -> TestedDialects { - TestedDialects { - dialects: vec![Box::new(PostgreSqlDialect {}), Box::new(GenericDialect {})], - options: None, - } + TestedDialects::new(vec![ + Box::new(PostgreSqlDialect {}), + Box::new(GenericDialect {}), + ]) } #[test] @@ -1868,14 +1857,13 @@ fn parse_string_agg() { /// selects all dialects but PostgreSQL pub fn all_dialects_but_pg() -> TestedDialects { - TestedDialects { - dialects: all_dialects() + TestedDialects::new( + all_dialects() .dialects .into_iter() .filter(|x| !x.is::()) .collect(), - options: None, - } + ) } #[test] @@ -2691,17 +2679,14 @@ fn parse_listagg() { #[test] fn parse_array_agg_func() { - let supported_dialects = TestedDialects { - dialects: vec![ - Box::new(GenericDialect {}), - Box::new(DuckDbDialect {}), - Box::new(PostgreSqlDialect {}), - Box::new(MsSqlDialect {}), - Box::new(AnsiDialect {}), - Box::new(HiveDialect {}), - ], - options: None, - }; + let supported_dialects = TestedDialects::new(vec![ + Box::new(GenericDialect {}), + Box::new(DuckDbDialect {}), + Box::new(PostgreSqlDialect {}), + Box::new(MsSqlDialect {}), + Box::new(AnsiDialect {}), + Box::new(HiveDialect {}), + ]); for sql in [ "SELECT ARRAY_AGG(x ORDER BY x) AS a FROM T", @@ -2716,16 +2701,13 @@ fn parse_array_agg_func() { #[test] fn parse_agg_with_order_by() { - let supported_dialects = TestedDialects { - dialects: vec![ - Box::new(GenericDialect {}), - Box::new(PostgreSqlDialect {}), - Box::new(MsSqlDialect {}), - Box::new(AnsiDialect {}), - Box::new(HiveDialect {}), - ], - options: None, - }; + let supported_dialects = TestedDialects::new(vec![ + Box::new(GenericDialect {}), + Box::new(PostgreSqlDialect {}), + Box::new(MsSqlDialect {}), + Box::new(AnsiDialect {}), + Box::new(HiveDialect {}), + ]); for sql in [ "SELECT FIRST_VALUE(x ORDER BY x) AS a FROM T", @@ -2739,17 +2721,14 @@ fn parse_agg_with_order_by() { #[test] fn parse_window_rank_function() { - let supported_dialects = TestedDialects { - dialects: vec![ - Box::new(GenericDialect {}), - Box::new(PostgreSqlDialect {}), - Box::new(MsSqlDialect {}), - Box::new(AnsiDialect {}), - Box::new(HiveDialect {}), - Box::new(SnowflakeDialect {}), - ], - options: None, - }; + let supported_dialects = TestedDialects::new(vec![ + Box::new(GenericDialect {}), + Box::new(PostgreSqlDialect {}), + Box::new(MsSqlDialect {}), + Box::new(AnsiDialect {}), + Box::new(HiveDialect {}), + Box::new(SnowflakeDialect {}), + ]); for sql in [ "SELECT column1, column2, FIRST_VALUE(column2) OVER (PARTITION BY column1 ORDER BY column2 NULLS LAST) AS column2_first FROM t1", @@ -2761,10 +2740,10 @@ fn parse_window_rank_function() { supported_dialects.verified_stmt(sql); } - let supported_dialects_nulls = TestedDialects { - dialects: vec![Box::new(MsSqlDialect {}), Box::new(SnowflakeDialect {})], - options: None, - }; + let supported_dialects_nulls = TestedDialects::new(vec![ + Box::new(MsSqlDialect {}), + Box::new(SnowflakeDialect {}), + ]); for sql in [ "SELECT column1, column2, FIRST_VALUE(column2) IGNORE NULLS OVER (PARTITION BY column1 ORDER BY column2 NULLS LAST) AS column2_first FROM t1", @@ -3321,10 +3300,7 @@ fn parse_create_table_hive_array() { true, ), ] { - let dialects = TestedDialects { - dialects, - options: None, - }; + let dialects = TestedDialects::new(dialects); let sql = format!( "CREATE TABLE IF NOT EXISTS something (name INT, val {})", @@ -3374,14 +3350,11 @@ fn parse_create_table_hive_array() { } // SnowflakeDialect using array different - let dialects = TestedDialects { - dialects: vec![ - Box::new(PostgreSqlDialect {}), - Box::new(HiveDialect {}), - Box::new(MySqlDialect {}), - ], - options: None, - }; + let dialects = TestedDialects::new(vec![ + Box::new(PostgreSqlDialect {}), + Box::new(HiveDialect {}), + Box::new(MySqlDialect {}), + ]); let sql = "CREATE TABLE IF NOT EXISTS something (name int, val array Result<(), Par #[test] fn parse_create_table_with_options() { - let generic = TestedDialects { - dialects: vec![Box::new(GenericDialect {})], - options: None, - }; + let generic = TestedDialects::new(vec![Box::new(GenericDialect {})]); let sql = "CREATE TABLE t (c INT) WITH (foo = 'bar', a = 123)"; match generic.verified_stmt(sql) { @@ -3695,10 +3662,7 @@ fn parse_create_table_clone() { #[test] fn parse_create_table_trailing_comma() { - let dialect = TestedDialects { - dialects: vec![Box::new(DuckDbDialect {})], - options: None, - }; + let dialect = TestedDialects::new(vec![Box::new(DuckDbDialect {})]); let sql = "CREATE TABLE foo (bar int,);"; dialect.one_statement_parses_to(sql, "CREATE TABLE foo (bar INT)"); @@ -4040,15 +4004,12 @@ fn parse_alter_table_add_column() { #[test] fn parse_alter_table_add_column_if_not_exists() { - let dialects = TestedDialects { - dialects: vec![ - Box::new(PostgreSqlDialect {}), - Box::new(BigQueryDialect {}), - Box::new(GenericDialect {}), - Box::new(DuckDbDialect {}), - ], - options: None, - }; + let dialects = TestedDialects::new(vec![ + Box::new(PostgreSqlDialect {}), + Box::new(BigQueryDialect {}), + Box::new(GenericDialect {}), + Box::new(DuckDbDialect {}), + ]); match alter_table_op(dialects.verified_stmt("ALTER TABLE tab ADD IF NOT EXISTS foo TEXT")) { AlterTableOperation::AddColumn { if_not_exists, .. } => { @@ -4191,10 +4152,7 @@ fn parse_alter_table_alter_column_type() { _ => unreachable!(), } - let dialect = TestedDialects { - dialects: vec![Box::new(GenericDialect {})], - options: None, - }; + let dialect = TestedDialects::new(vec![Box::new(GenericDialect {})]); let res = dialect.parse_sql_statements(&format!("{alter_stmt} ALTER COLUMN is_active TYPE TEXT")); @@ -4611,15 +4569,12 @@ fn parse_window_functions() { #[test] fn parse_named_window_functions() { - let supported_dialects = TestedDialects { - dialects: vec![ - Box::new(GenericDialect {}), - Box::new(PostgreSqlDialect {}), - Box::new(MySqlDialect {}), - Box::new(BigQueryDialect {}), - ], - options: None, - }; + let supported_dialects = TestedDialects::new(vec![ + Box::new(GenericDialect {}), + Box::new(PostgreSqlDialect {}), + Box::new(MySqlDialect {}), + Box::new(BigQueryDialect {}), + ]); let sql = "SELECT row_number() OVER (w ORDER BY dt DESC), \ sum(foo) OVER (win PARTITION BY a, b ORDER BY c, d \ @@ -5684,10 +5639,10 @@ fn parse_unnest_in_from_clause() { let select = dialects.verified_only_select(sql); assert_eq!(select.from, want); } - let dialects = TestedDialects { - dialects: vec![Box::new(BigQueryDialect {}), Box::new(GenericDialect {})], - options: None, - }; + let dialects = TestedDialects::new(vec![ + Box::new(BigQueryDialect {}), + Box::new(GenericDialect {}), + ]); // 1. both Alias and WITH OFFSET clauses. chk( "expr", @@ -6670,22 +6625,20 @@ fn parse_trim() { ); //keep Snowflake/BigQuery TRIM syntax failing - let all_expected_snowflake = TestedDialects { - dialects: vec![ - //Box::new(GenericDialect {}), - Box::new(PostgreSqlDialect {}), - Box::new(MsSqlDialect {}), - Box::new(AnsiDialect {}), - //Box::new(SnowflakeDialect {}), - Box::new(HiveDialect {}), - Box::new(RedshiftSqlDialect {}), - Box::new(MySqlDialect {}), - //Box::new(BigQueryDialect {}), - Box::new(SQLiteDialect {}), - Box::new(DuckDbDialect {}), - ], - options: None, - }; + let all_expected_snowflake = TestedDialects::new(vec![ + //Box::new(GenericDialect {}), + Box::new(PostgreSqlDialect {}), + Box::new(MsSqlDialect {}), + Box::new(AnsiDialect {}), + //Box::new(SnowflakeDialect {}), + Box::new(HiveDialect {}), + Box::new(RedshiftSqlDialect {}), + Box::new(MySqlDialect {}), + //Box::new(BigQueryDialect {}), + Box::new(SQLiteDialect {}), + Box::new(DuckDbDialect {}), + ]); + assert_eq!( ParserError::ParserError("Expected: ), found: 'a'".to_owned()), all_expected_snowflake @@ -8582,20 +8535,17 @@ fn test_lock_nonblock() { #[test] fn test_placeholder() { - let dialects = TestedDialects { - dialects: vec![ - Box::new(GenericDialect {}), - Box::new(DuckDbDialect {}), - Box::new(PostgreSqlDialect {}), - Box::new(MsSqlDialect {}), - Box::new(AnsiDialect {}), - Box::new(BigQueryDialect {}), - Box::new(SnowflakeDialect {}), - // Note: `$` is the starting word for the HiveDialect identifier - // Box::new(sqlparser::dialect::HiveDialect {}), - ], - options: None, - }; + let dialects = TestedDialects::new(vec![ + Box::new(GenericDialect {}), + Box::new(DuckDbDialect {}), + Box::new(PostgreSqlDialect {}), + Box::new(MsSqlDialect {}), + Box::new(AnsiDialect {}), + Box::new(BigQueryDialect {}), + Box::new(SnowflakeDialect {}), + // Note: `$` is the starting word for the HiveDialect identifier + // Box::new(sqlparser::dialect::HiveDialect {}), + ]); let sql = "SELECT * FROM student WHERE id = $Id1"; let ast = dialects.verified_only_select(sql); assert_eq!( @@ -8621,21 +8571,18 @@ fn test_placeholder() { }), ); - let dialects = TestedDialects { - dialects: vec![ - Box::new(GenericDialect {}), - Box::new(DuckDbDialect {}), - // Note: `?` is for jsonb operators in PostgreSqlDialect - // Box::new(PostgreSqlDialect {}), - Box::new(MsSqlDialect {}), - Box::new(AnsiDialect {}), - Box::new(BigQueryDialect {}), - Box::new(SnowflakeDialect {}), - // Note: `$` is the starting word for the HiveDialect identifier - // Box::new(sqlparser::dialect::HiveDialect {}), - ], - options: None, - }; + let dialects = TestedDialects::new(vec![ + Box::new(GenericDialect {}), + Box::new(DuckDbDialect {}), + // Note: `?` is for jsonb operators in PostgreSqlDialect + // Box::new(PostgreSqlDialect {}), + Box::new(MsSqlDialect {}), + Box::new(AnsiDialect {}), + Box::new(BigQueryDialect {}), + Box::new(SnowflakeDialect {}), + // Note: `$` is the starting word for the HiveDialect identifier + // Box::new(sqlparser::dialect::HiveDialect {}), + ]); let sql = "SELECT * FROM student WHERE id = ?"; let ast = dialects.verified_only_select(sql); assert_eq!( @@ -9023,7 +8970,7 @@ fn parse_cache_table() { value: Expr::Value(number("0.88")), }, ], - query: Some(query.clone()), + query: Some(query.clone().into()), } ); @@ -9048,7 +8995,7 @@ fn parse_cache_table() { value: Expr::Value(number("0.88")), }, ], - query: Some(query.clone()), + query: Some(query.clone().into()), } ); @@ -9059,7 +9006,7 @@ fn parse_cache_table() { table_name: ObjectName(vec![Ident::with_quote('\'', cache_table_name)]), has_as: false, options: vec![], - query: Some(query.clone()), + query: Some(query.clone().into()), } ); @@ -9070,7 +9017,7 @@ fn parse_cache_table() { table_name: ObjectName(vec![Ident::with_quote('\'', cache_table_name)]), has_as: true, options: vec![], - query: Some(query), + query: Some(query.into()), } ); @@ -9243,14 +9190,11 @@ fn parse_with_recursion_limit() { #[test] fn parse_escaped_string_with_unescape() { fn assert_mysql_query_value(sql: &str, quoted: &str) { - let stmt = TestedDialects { - dialects: vec![ - Box::new(MySqlDialect {}), - Box::new(BigQueryDialect {}), - Box::new(SnowflakeDialect {}), - ], - options: None, - } + let stmt = TestedDialects::new(vec![ + Box::new(MySqlDialect {}), + Box::new(BigQueryDialect {}), + Box::new(SnowflakeDialect {}), + ]) .one_statement_parses_to(sql, ""); match stmt { @@ -9283,14 +9227,14 @@ fn parse_escaped_string_with_unescape() { #[test] fn parse_escaped_string_without_unescape() { fn assert_mysql_query_value(sql: &str, quoted: &str) { - let stmt = TestedDialects { - dialects: vec![ + let stmt = TestedDialects::new_with_options( + vec![ Box::new(MySqlDialect {}), Box::new(BigQueryDialect {}), Box::new(SnowflakeDialect {}), ], - options: Some(ParserOptions::new().with_unescape(false)), - } + ParserOptions::new().with_unescape(false), + ) .one_statement_parses_to(sql, ""); match stmt { @@ -9558,17 +9502,14 @@ fn make_where_clause(num: usize) -> String { #[test] fn parse_non_latin_identifiers() { - let supported_dialects = TestedDialects { - dialects: vec![ - Box::new(GenericDialect {}), - Box::new(DuckDbDialect {}), - Box::new(PostgreSqlDialect {}), - Box::new(MsSqlDialect {}), - Box::new(RedshiftSqlDialect {}), - Box::new(MySqlDialect {}), - ], - options: None, - }; + let supported_dialects = TestedDialects::new(vec![ + Box::new(GenericDialect {}), + Box::new(DuckDbDialect {}), + Box::new(PostgreSqlDialect {}), + Box::new(MsSqlDialect {}), + Box::new(RedshiftSqlDialect {}), + Box::new(MySqlDialect {}), + ]); supported_dialects.verified_stmt("SELECT a.説明 FROM test.public.inter01 AS a"); supported_dialects.verified_stmt("SELECT a.説明 FROM inter01 AS a, inter01_transactions AS b WHERE a.説明 = b.取引 GROUP BY a.説明"); @@ -9582,10 +9523,7 @@ fn parse_non_latin_identifiers() { fn parse_trailing_comma() { // At the moment, DuckDB is the only dialect that allows // trailing commas anywhere in the query - let trailing_commas = TestedDialects { - dialects: vec![Box::new(DuckDbDialect {})], - options: None, - }; + let trailing_commas = TestedDialects::new(vec![Box::new(DuckDbDialect {})]); trailing_commas.one_statement_parses_to( "SELECT album_id, name, FROM track", @@ -9624,10 +9562,7 @@ fn parse_trailing_comma() { trailing_commas.verified_stmt(r#"SELECT "from" FROM "from""#); // doesn't allow any trailing commas - let trailing_commas = TestedDialects { - dialects: vec![Box::new(GenericDialect {})], - options: None, - }; + let trailing_commas = TestedDialects::new(vec![Box::new(GenericDialect {})]); assert_eq!( trailing_commas @@ -9656,10 +9591,10 @@ fn parse_trailing_comma() { #[test] fn parse_projection_trailing_comma() { // Some dialects allow trailing commas only in the projection - let trailing_commas = TestedDialects { - dialects: vec![Box::new(SnowflakeDialect {}), Box::new(BigQueryDialect {})], - options: None, - }; + let trailing_commas = TestedDialects::new(vec![ + Box::new(SnowflakeDialect {}), + Box::new(BigQueryDialect {}), + ]); trailing_commas.one_statement_parses_to( "SELECT album_id, name, FROM track", @@ -9946,14 +9881,11 @@ fn test_release_savepoint() { #[test] fn test_comment_hash_syntax() { - let dialects = TestedDialects { - dialects: vec![ - Box::new(BigQueryDialect {}), - Box::new(SnowflakeDialect {}), - Box::new(MySqlDialect {}), - ], - options: None, - }; + let dialects = TestedDialects::new(vec![ + Box::new(BigQueryDialect {}), + Box::new(SnowflakeDialect {}), + Box::new(MySqlDialect {}), + ]); let sql = r#" # comment SELECT a, b, c # , d, e @@ -10013,10 +9945,10 @@ fn test_buffer_reuse() { #[test] fn parse_map_access_expr() { let sql = "users[-1][safe_offset(2)]"; - let dialects = TestedDialects { - dialects: vec![Box::new(BigQueryDialect {}), Box::new(ClickHouseDialect {})], - options: None, - }; + let dialects = TestedDialects::new(vec![ + Box::new(BigQueryDialect {}), + Box::new(ClickHouseDialect {}), + ]); let expr = dialects.verified_expr(sql); let expected = Expr::MapAccess { column: Expr::Identifier(Ident::new("users")).into(), @@ -10591,16 +10523,13 @@ fn test_match_recognize_patterns() { #[test] fn test_select_wildcard_with_replace() { let sql = r#"SELECT * REPLACE (lower(city) AS city) FROM addresses"#; - let dialects = TestedDialects { - dialects: vec![ - Box::new(GenericDialect {}), - Box::new(BigQueryDialect {}), - Box::new(ClickHouseDialect {}), - Box::new(SnowflakeDialect {}), - Box::new(DuckDbDialect {}), - ], - options: None, - }; + let dialects = TestedDialects::new(vec![ + Box::new(GenericDialect {}), + Box::new(BigQueryDialect {}), + Box::new(ClickHouseDialect {}), + Box::new(SnowflakeDialect {}), + Box::new(DuckDbDialect {}), + ]); let select = dialects.verified_only_select(sql); let expected = SelectItem::Wildcard(WildcardAdditionalOptions { opt_replace: Some(ReplaceSelectItem { @@ -10657,14 +10586,11 @@ fn test_select_wildcard_with_replace() { #[test] fn parse_sized_list() { - let dialects = TestedDialects { - dialects: vec![ - Box::new(GenericDialect {}), - Box::new(PostgreSqlDialect {}), - Box::new(DuckDbDialect {}), - ], - options: None, - }; + let dialects = TestedDialects::new(vec![ + Box::new(GenericDialect {}), + Box::new(PostgreSqlDialect {}), + Box::new(DuckDbDialect {}), + ]); let sql = r#"CREATE TABLE embeddings (data FLOAT[1536])"#; dialects.verified_stmt(sql); let sql = r#"CREATE TABLE embeddings (data FLOAT[1536][3])"#; @@ -10675,14 +10601,11 @@ fn parse_sized_list() { #[test] fn insert_into_with_parentheses() { - let dialects = TestedDialects { - dialects: vec![ - Box::new(SnowflakeDialect {}), - Box::new(RedshiftSqlDialect {}), - Box::new(GenericDialect {}), - ], - options: None, - }; + let dialects = TestedDialects::new(vec![ + Box::new(SnowflakeDialect {}), + Box::new(RedshiftSqlDialect {}), + Box::new(GenericDialect {}), + ]); dialects.verified_stmt("INSERT INTO t1 (id, name) (SELECT t2.id, t2.name FROM t2)"); } @@ -10850,14 +10773,11 @@ fn parse_within_group() { #[test] fn tests_select_values_without_parens() { - let dialects = TestedDialects { - dialects: vec![ - Box::new(GenericDialect {}), - Box::new(SnowflakeDialect {}), - Box::new(DatabricksDialect {}), - ], - options: None, - }; + let dialects = TestedDialects::new(vec![ + Box::new(GenericDialect {}), + Box::new(SnowflakeDialect {}), + Box::new(DatabricksDialect {}), + ]); let sql = "SELECT * FROM VALUES (1, 2), (2,3) AS tbl (id, val)"; let canonical = "SELECT * FROM (VALUES (1, 2), (2, 3)) AS tbl (id, val)"; dialects.verified_only_select_with_canonical(sql, canonical); @@ -10865,14 +10785,12 @@ fn tests_select_values_without_parens() { #[test] fn tests_select_values_without_parens_and_set_op() { - let dialects = TestedDialects { - dialects: vec![ - Box::new(GenericDialect {}), - Box::new(SnowflakeDialect {}), - Box::new(DatabricksDialect {}), - ], - options: None, - }; + let dialects = TestedDialects::new(vec![ + Box::new(GenericDialect {}), + Box::new(SnowflakeDialect {}), + Box::new(DatabricksDialect {}), + ]); + let sql = "SELECT id + 1, name FROM VALUES (1, 'Apple'), (2, 'Banana'), (3, 'Orange') AS fruits (id, name) UNION ALL SELECT 5, 'Strawberry'"; let canonical = "SELECT id + 1, name FROM (VALUES (1, 'Apple'), (2, 'Banana'), (3, 'Orange')) AS fruits (id, name) UNION ALL SELECT 5, 'Strawberry'"; let query = dialects.verified_query_with_canonical(sql, canonical); diff --git a/tests/sqlparser_databricks.rs b/tests/sqlparser_databricks.rs index 7dcfee68a..7b917bd06 100644 --- a/tests/sqlparser_databricks.rs +++ b/tests/sqlparser_databricks.rs @@ -24,17 +24,14 @@ use test_utils::*; mod test_utils; fn databricks() -> TestedDialects { - TestedDialects { - dialects: vec![Box::new(DatabricksDialect {})], - options: None, - } + TestedDialects::new(vec![Box::new(DatabricksDialect {})]) } fn databricks_and_generic() -> TestedDialects { - TestedDialects { - dialects: vec![Box::new(DatabricksDialect {}), Box::new(GenericDialect {})], - options: None, - } + TestedDialects::new(vec![ + Box::new(DatabricksDialect {}), + Box::new(GenericDialect {}), + ]) } #[test] diff --git a/tests/sqlparser_duckdb.rs b/tests/sqlparser_duckdb.rs index 4703f4b60..a4109b0a3 100644 --- a/tests/sqlparser_duckdb.rs +++ b/tests/sqlparser_duckdb.rs @@ -24,17 +24,14 @@ use sqlparser::ast::*; use sqlparser::dialect::{DuckDbDialect, GenericDialect}; fn duckdb() -> TestedDialects { - TestedDialects { - dialects: vec![Box::new(DuckDbDialect {})], - options: None, - } + TestedDialects::new(vec![Box::new(DuckDbDialect {})]) } fn duckdb_and_generic() -> TestedDialects { - TestedDialects { - dialects: vec![Box::new(DuckDbDialect {}), Box::new(GenericDialect {})], - options: None, - } + TestedDialects::new(vec![ + Box::new(DuckDbDialect {}), + Box::new(GenericDialect {}), + ]) } #[test] @@ -242,7 +239,7 @@ fn test_create_table_macro() { MacroArg::new("col1_value"), MacroArg::new("col2_value"), ]), - definition: MacroDefinition::Table(duckdb().verified_query(query)), + definition: MacroDefinition::Table(duckdb().verified_query(query).into()), }; assert_eq!(expected, macro_); } diff --git a/tests/sqlparser_hive.rs b/tests/sqlparser_hive.rs index 069500bf6..10bd374c0 100644 --- a/tests/sqlparser_hive.rs +++ b/tests/sqlparser_hive.rs @@ -418,10 +418,7 @@ fn parse_create_function() { } // Test error in dialect that doesn't support parsing CREATE FUNCTION - let unsupported_dialects = TestedDialects { - dialects: vec![Box::new(MsSqlDialect {})], - options: None, - }; + let unsupported_dialects = TestedDialects::new(vec![Box::new(MsSqlDialect {})]); assert_eq!( unsupported_dialects.parse_sql_statements(sql).unwrap_err(), @@ -538,15 +535,9 @@ fn parse_use() { } fn hive() -> TestedDialects { - TestedDialects { - dialects: vec![Box::new(HiveDialect {})], - options: None, - } + TestedDialects::new(vec![Box::new(HiveDialect {})]) } fn hive_and_generic() -> TestedDialects { - TestedDialects { - dialects: vec![Box::new(HiveDialect {}), Box::new(GenericDialect {})], - options: None, - } + TestedDialects::new(vec![Box::new(HiveDialect {}), Box::new(GenericDialect {})]) } diff --git a/tests/sqlparser_mssql.rs b/tests/sqlparser_mssql.rs index 58765f6c0..0223e2915 100644 --- a/tests/sqlparser_mssql.rs +++ b/tests/sqlparser_mssql.rs @@ -1030,14 +1030,8 @@ fn parse_create_table_with_identity_column() { } fn ms() -> TestedDialects { - TestedDialects { - dialects: vec![Box::new(MsSqlDialect {})], - options: None, - } + TestedDialects::new(vec![Box::new(MsSqlDialect {})]) } fn ms_and_generic() -> TestedDialects { - TestedDialects { - dialects: vec![Box::new(MsSqlDialect {}), Box::new(GenericDialect {})], - options: None, - } + TestedDialects::new(vec![Box::new(MsSqlDialect {}), Box::new(GenericDialect {})]) } diff --git a/tests/sqlparser_mysql.rs b/tests/sqlparser_mysql.rs index 19dbda21f..db5b9ec8d 100644 --- a/tests/sqlparser_mysql.rs +++ b/tests/sqlparser_mysql.rs @@ -944,11 +944,7 @@ fn parse_quote_identifiers() { fn parse_escaped_quote_identifiers_with_escape() { let sql = "SELECT `quoted `` identifier`"; assert_eq!( - TestedDialects { - dialects: vec![Box::new(MySqlDialect {})], - options: None, - } - .verified_stmt(sql), + TestedDialects::new(vec![Box::new(MySqlDialect {})]).verified_stmt(sql), Statement::Query(Box::new(Query { with: None, body: Box::new(SetExpr::Select(Box::new(Select { @@ -991,13 +987,13 @@ fn parse_escaped_quote_identifiers_with_escape() { fn parse_escaped_quote_identifiers_with_no_escape() { let sql = "SELECT `quoted `` identifier`"; assert_eq!( - TestedDialects { - dialects: vec![Box::new(MySqlDialect {})], - options: Some(ParserOptions { + TestedDialects::new_with_options( + vec![Box::new(MySqlDialect {})], + ParserOptions { trailing_commas: false, unescape: false, - }), - } + } + ) .verified_stmt(sql), Statement::Query(Box::new(Query { with: None, @@ -1041,11 +1037,7 @@ fn parse_escaped_quote_identifiers_with_no_escape() { fn parse_escaped_backticks_with_escape() { let sql = "SELECT ```quoted identifier```"; assert_eq!( - TestedDialects { - dialects: vec![Box::new(MySqlDialect {})], - options: None, - } - .verified_stmt(sql), + TestedDialects::new(vec![Box::new(MySqlDialect {})]).verified_stmt(sql), Statement::Query(Box::new(Query { with: None, body: Box::new(SetExpr::Select(Box::new(Select { @@ -1088,10 +1080,10 @@ fn parse_escaped_backticks_with_escape() { fn parse_escaped_backticks_with_no_escape() { let sql = "SELECT ```quoted identifier```"; assert_eq!( - TestedDialects { - dialects: vec![Box::new(MySqlDialect {})], - options: Some(ParserOptions::new().with_unescape(false)), - } + TestedDialects::new_with_options( + vec![Box::new(MySqlDialect {})], + ParserOptions::new().with_unescape(false) + ) .verified_stmt(sql), Statement::Query(Box::new(Query { with: None, @@ -1144,55 +1136,26 @@ fn parse_unterminated_escape() { #[test] fn check_roundtrip_of_escaped_string() { - let options = Some(ParserOptions::new().with_unescape(false)); - - TestedDialects { - dialects: vec![Box::new(MySqlDialect {})], - options: options.clone(), - } - .verified_stmt(r"SELECT 'I\'m fine'"); - TestedDialects { - dialects: vec![Box::new(MySqlDialect {})], - options: options.clone(), - } - .verified_stmt(r#"SELECT 'I''m fine'"#); - TestedDialects { - dialects: vec![Box::new(MySqlDialect {})], - options: options.clone(), - } - .verified_stmt(r"SELECT 'I\\\'m fine'"); - TestedDialects { - dialects: vec![Box::new(MySqlDialect {})], - options: options.clone(), - } - .verified_stmt(r"SELECT 'I\\\'m fine'"); - - TestedDialects { - dialects: vec![Box::new(MySqlDialect {})], - options: options.clone(), - } - .verified_stmt(r#"SELECT "I\"m fine""#); - TestedDialects { - dialects: vec![Box::new(MySqlDialect {})], - options: options.clone(), - } - .verified_stmt(r#"SELECT "I""m fine""#); - TestedDialects { - dialects: vec![Box::new(MySqlDialect {})], - options: options.clone(), - } - .verified_stmt(r#"SELECT "I\\\"m fine""#); - TestedDialects { - dialects: vec![Box::new(MySqlDialect {})], - options: options.clone(), - } - .verified_stmt(r#"SELECT "I\\\"m fine""#); - - TestedDialects { - dialects: vec![Box::new(MySqlDialect {})], - options, - } - .verified_stmt(r#"SELECT "I'm ''fine''""#); + let options = ParserOptions::new().with_unescape(false); + + TestedDialects::new_with_options(vec![Box::new(MySqlDialect {})], options.clone()) + .verified_stmt(r"SELECT 'I\'m fine'"); + TestedDialects::new_with_options(vec![Box::new(MySqlDialect {})], options.clone()) + .verified_stmt(r#"SELECT 'I''m fine'"#); + TestedDialects::new_with_options(vec![Box::new(MySqlDialect {})], options.clone()) + .verified_stmt(r"SELECT 'I\\\'m fine'"); + TestedDialects::new_with_options(vec![Box::new(MySqlDialect {})], options.clone()) + .verified_stmt(r"SELECT 'I\\\'m fine'"); + TestedDialects::new_with_options(vec![Box::new(MySqlDialect {})], options.clone()) + .verified_stmt(r#"SELECT "I\"m fine""#); + TestedDialects::new_with_options(vec![Box::new(MySqlDialect {})], options.clone()) + .verified_stmt(r#"SELECT "I""m fine""#); + TestedDialects::new_with_options(vec![Box::new(MySqlDialect {})], options.clone()) + .verified_stmt(r#"SELECT "I\\\"m fine""#); + TestedDialects::new_with_options(vec![Box::new(MySqlDialect {})], options.clone()) + .verified_stmt(r#"SELECT "I\\\"m fine""#); + TestedDialects::new_with_options(vec![Box::new(MySqlDialect {})], options.clone()) + .verified_stmt(r#"SELECT "I'm ''fine''""#); } #[test] @@ -2624,17 +2587,11 @@ fn parse_create_table_with_fulltext_definition_should_not_accept_constraint_name } fn mysql() -> TestedDialects { - TestedDialects { - dialects: vec![Box::new(MySqlDialect {})], - options: None, - } + TestedDialects::new(vec![Box::new(MySqlDialect {})]) } fn mysql_and_generic() -> TestedDialects { - TestedDialects { - dialects: vec![Box::new(MySqlDialect {}), Box::new(GenericDialect {})], - options: None, - } + TestedDialects::new(vec![Box::new(MySqlDialect {}), Box::new(GenericDialect {})]) } #[test] diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index bd37214ce..b9b3811ba 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -2973,17 +2973,14 @@ fn parse_on_commit() { } fn pg() -> TestedDialects { - TestedDialects { - dialects: vec![Box::new(PostgreSqlDialect {})], - options: None, - } + TestedDialects::new(vec![Box::new(PostgreSqlDialect {})]) } fn pg_and_generic() -> TestedDialects { - TestedDialects { - dialects: vec![Box::new(PostgreSqlDialect {}), Box::new(GenericDialect {})], - options: None, - } + TestedDialects::new(vec![ + Box::new(PostgreSqlDialect {}), + Box::new(GenericDialect {}), + ]) } #[test] diff --git a/tests/sqlparser_redshift.rs b/tests/sqlparser_redshift.rs index eeba37957..a25d50605 100644 --- a/tests/sqlparser_redshift.rs +++ b/tests/sqlparser_redshift.rs @@ -171,17 +171,14 @@ fn parse_delimited_identifiers() { } fn redshift() -> TestedDialects { - TestedDialects { - dialects: vec![Box::new(RedshiftSqlDialect {})], - options: None, - } + TestedDialects::new(vec![Box::new(RedshiftSqlDialect {})]) } fn redshift_and_generic() -> TestedDialects { - TestedDialects { - dialects: vec![Box::new(RedshiftSqlDialect {}), Box::new(GenericDialect {})], - options: None, - } + TestedDialects::new(vec![ + Box::new(RedshiftSqlDialect {}), + Box::new(GenericDialect {}), + ]) } #[test] diff --git a/tests/sqlparser_snowflake.rs b/tests/sqlparser_snowflake.rs index d7e967ffe..c17c7b958 100644 --- a/tests/sqlparser_snowflake.rs +++ b/tests/sqlparser_snowflake.rs @@ -854,10 +854,8 @@ fn parse_sf_create_or_replace_view_with_comment_missing_equal() { #[test] fn parse_sf_create_or_replace_with_comment_for_snowflake() { let sql = "CREATE OR REPLACE VIEW v COMMENT = 'hello, world' AS SELECT 1"; - let dialect = test_utils::TestedDialects { - dialects: vec![Box::new(SnowflakeDialect {}) as Box], - options: None, - }; + let dialect = + test_utils::TestedDialects::new(vec![Box::new(SnowflakeDialect {}) as Box]); match dialect.verified_stmt(sql) { Statement::CreateView { @@ -1250,24 +1248,25 @@ fn test_array_agg_func() { } fn snowflake() -> TestedDialects { - TestedDialects { - dialects: vec![Box::new(SnowflakeDialect {})], - options: None, - } + TestedDialects::new(vec![Box::new(SnowflakeDialect {})]) +} + +fn snowflake_with_recursion_limit(recursion_limit: usize) -> TestedDialects { + TestedDialects::new(vec![Box::new(SnowflakeDialect {})]).with_recursion_limit(recursion_limit) } fn snowflake_without_unescape() -> TestedDialects { - TestedDialects { - dialects: vec![Box::new(SnowflakeDialect {})], - options: Some(ParserOptions::new().with_unescape(false)), - } + TestedDialects::new_with_options( + vec![Box::new(SnowflakeDialect {})], + ParserOptions::new().with_unescape(false), + ) } fn snowflake_and_generic() -> TestedDialects { - TestedDialects { - dialects: vec![Box::new(SnowflakeDialect {}), Box::new(GenericDialect {})], - options: None, - } + TestedDialects::new(vec![ + Box::new(SnowflakeDialect {}), + Box::new(GenericDialect {}), + ]) } #[test] @@ -2759,3 +2758,26 @@ fn parse_view_column_descriptions() { _ => unreachable!(), }; } + +#[test] +fn test_parentheses_overflow() { + let max_nesting_level: usize = 30; + + // Verify the recursion check is not too wasteful... (num of parentheses - 2 is acceptable) + let slack = 2; + let l_parens = "(".repeat(max_nesting_level - slack); + let r_parens = ")".repeat(max_nesting_level - slack); + let sql = format!("SELECT * FROM {l_parens}a.b.c{r_parens}"); + let parsed = + snowflake_with_recursion_limit(max_nesting_level).parse_sql_statements(sql.as_str()); + assert_eq!(parsed.err(), None); + + // Verify the recursion check triggers... (num of parentheses - 1 is acceptable) + let slack = 1; + let l_parens = "(".repeat(max_nesting_level - slack); + let r_parens = ")".repeat(max_nesting_level - slack); + let sql = format!("SELECT * FROM {l_parens}a.b.c{r_parens}"); + let parsed = + snowflake_with_recursion_limit(max_nesting_level).parse_sql_statements(sql.as_str()); + assert_eq!(parsed.err(), Some(ParserError::RecursionLimitExceeded)); +} diff --git a/tests/sqlparser_sqlite.rs b/tests/sqlparser_sqlite.rs index d3e670e32..6f8bbb2d8 100644 --- a/tests/sqlparser_sqlite.rs +++ b/tests/sqlparser_sqlite.rs @@ -529,14 +529,13 @@ fn parse_start_transaction_with_modifier() { sqlite_and_generic().one_statement_parses_to("BEGIN IMMEDIATE", "BEGIN IMMEDIATE TRANSACTION"); sqlite_and_generic().one_statement_parses_to("BEGIN EXCLUSIVE", "BEGIN EXCLUSIVE TRANSACTION"); - let unsupported_dialects = TestedDialects { - dialects: all_dialects() + let unsupported_dialects = TestedDialects::new( + all_dialects() .dialects .into_iter() .filter(|x| !(x.is::() || x.is::())) .collect(), - options: None, - }; + ); let res = unsupported_dialects.parse_sql_statements("BEGIN DEFERRED"); assert_eq!( ParserError::ParserError("Expected: end of statement, found: DEFERRED".to_string()), @@ -571,22 +570,16 @@ fn test_dollar_identifier_as_placeholder() { } fn sqlite() -> TestedDialects { - TestedDialects { - dialects: vec![Box::new(SQLiteDialect {})], - options: None, - } + TestedDialects::new(vec![Box::new(SQLiteDialect {})]) } fn sqlite_with_options(options: ParserOptions) -> TestedDialects { - TestedDialects { - dialects: vec![Box::new(SQLiteDialect {})], - options: Some(options), - } + TestedDialects::new_with_options(vec![Box::new(SQLiteDialect {})], options) } fn sqlite_and_generic() -> TestedDialects { - TestedDialects { - dialects: vec![Box::new(SQLiteDialect {}), Box::new(GenericDialect {})], - options: None, - } + TestedDialects::new(vec![ + Box::new(SQLiteDialect {}), + Box::new(GenericDialect {}), + ]) }