diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index 625f9ce0a..1e786e146 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -124,6 +124,10 @@ pub trait Dialect: Debug + Any { fn supports_substring_from_for_expr(&self) -> bool { true } + /// Returns true if the dialect supports `(NOT) IN ()` expressions + fn supports_in_empty_list(&self) -> bool { + false + } /// Dialect-specific prefix parser override fn parse_prefix(&self, _parser: &mut Parser) -> Option<Result<Expr, ParserError>> { // return None to fall back to the default behavior diff --git a/src/dialect/sqlite.rs b/src/dialect/sqlite.rs index 68515d24f..c9e9ab185 100644 --- a/src/dialect/sqlite.rs +++ b/src/dialect/sqlite.rs @@ -52,4 +52,8 @@ impl Dialect for SQLiteDialect { None } } + + fn supports_in_empty_list(&self) -> bool { + true + } } diff --git a/src/parser/mod.rs b/src/parser/mod.rs index f83f019ea..5123c2957 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -2124,7 +2124,11 @@ impl<'a> Parser<'a> { } else { Expr::InList { expr: Box::new(expr), - list: self.parse_comma_separated(Parser::parse_expr)?, + list: if self.dialect.supports_in_empty_list() { + self.parse_comma_separated0(Parser::parse_expr)? + } else { + self.parse_comma_separated(Parser::parse_expr)? + }, negated, } }; @@ -2460,29 +2464,69 @@ impl<'a> Parser<'a> { let mut values = vec![]; loop { values.push(f(self)?); + if !self.consume_token(&Token::Comma) + || self.options.trailing_commas + && Self::is_comma_separated_end(&self.peek_token().token) + { + break; + } + } + Ok(values) + } + + /// Parse a comma-separated list of 0+ items accepted by `F` + pub fn parse_comma_separated0<T, F>(&mut self, mut f: F) -> Result<Vec<T>, ParserError> + where + F: FnMut(&mut Parser<'a>) -> Result<T, ParserError>, + { + let mut values = vec![]; + let index = self.index; + match f(self) { + Ok(v) => values.push(v), + Err(e) => { + // FIXME: this is a workaround because f (e.g. Parser::parse_expr) + // might eat tokens even though it fails.
+ self.index = index; + let peek_token = &self.peek_token().token; + return if Self::is_comma_separated_end(peek_token) { + Ok(values) + } else if matches!(peek_token, Token::Comma) && self.options.trailing_commas { + let _ = self.consume_token(&Token::Comma); + Ok(values) + } else { + Err(e) + }; + } + } + loop { if !self.consume_token(&Token::Comma) { break; - } else if self.options.trailing_commas { - match self.peek_token().token { - Token::Word(kw) - if keywords::RESERVED_FOR_COLUMN_ALIAS - .iter() - .any(|d| kw.keyword == *d) => - { - break; - } - Token::RParen - | Token::SemiColon - | Token::EOF - | Token::RBracket - | Token::RBrace => break, - _ => continue, + } else { + if self.options.trailing_commas + && Self::is_comma_separated_end(&self.peek_token().token) + { + break; } + values.push(f(self)?); } } Ok(values) } + fn is_comma_separated_end(token: &Token) -> bool { + match token { + Token::Word(kw) + if keywords::RESERVED_FOR_COLUMN_ALIAS + .iter() + .any(|d| kw.keyword == *d) => + { + true + } + Token::RParen | Token::SemiColon | Token::EOF | Token::RBracket | Token::RBrace => true, + _ => false, + } + } + /// Run a parser method `f`, reverting back to the current position /// if unsuccessful. 
#[must_use] @@ -8374,4 +8418,90 @@ mod tests { panic!("fail to parse mysql partition selection"); } } + + #[test] + fn test_comma_separated0() { + let sql = "1, 2, 3"; + let ast = Parser::new(&GenericDialect) + .try_with_sql(sql) + .unwrap() + .parse_comma_separated0(Parser::parse_expr); + #[cfg(feature = "bigdecimal")] + assert_eq!( + ast, + Ok(vec![ + Expr::Value(Value::Number(bigdecimal::BigDecimal::from(1), false)), + Expr::Value(Value::Number(bigdecimal::BigDecimal::from(2), false)), + Expr::Value(Value::Number(bigdecimal::BigDecimal::from(3), false)), + ]) + ); + #[cfg(not(feature = "bigdecimal"))] + assert_eq!( + ast, + Ok(vec![ + Expr::Value(Value::Number("1".to_string(), false)), + Expr::Value(Value::Number("2".to_string(), false)), + Expr::Value(Value::Number("3".to_string(), false)), + ]) + ); + + let sql = ""; + let ast = Parser::new(&GenericDialect) + .try_with_sql(sql) + .unwrap() + .parse_comma_separated0(Parser::parse_expr); + assert_eq!(ast, Ok(vec![])); + + let sql = ","; + let ast = Parser::new(&GenericDialect) + .try_with_sql(sql) + .unwrap() + .parse_comma_separated0(Parser::parse_expr); + assert_eq!( + ast, + Err(ParserError::ParserError( + "Expected an expression:, found: , at Line: 1, Column 1".to_string() + )) + ); + + let sql = ","; + let ast = Parser::new(&GenericDialect) + .with_options(ParserOptions::new().with_trailing_commas(true)) + .try_with_sql(sql) + .unwrap() + .parse_comma_separated0(Parser::parse_expr); + assert_eq!(ast, Ok(vec![])); + + let sql = "1,"; + let ast = Parser::new(&GenericDialect) + .try_with_sql(sql) + .unwrap() + .parse_comma_separated0(Parser::parse_expr); + assert_eq!( + ast, + Err(ParserError::ParserError( + "Expected an expression:, found: EOF".to_string() + )) + ); + + let sql = "1,"; + let ast = Parser::new(&GenericDialect) + .with_options(ParserOptions::new().with_trailing_commas(true)) + .try_with_sql(sql) + .unwrap() + .parse_comma_separated0(Parser::parse_expr); + #[cfg(feature = "bigdecimal")] + 
assert_eq!( + ast, + Ok(vec![Expr::Value(Value::Number( + bigdecimal::BigDecimal::from(1), + false + )),]) + ); + #[cfg(not(feature = "bigdecimal"))] + assert_eq!( + ast, + Ok(vec![Expr::Value(Value::Number("1".to_string(), false)),]) + ); + } } diff --git a/tests/sqlparser_sqlite.rs b/tests/sqlparser_sqlite.rs index 2fdd4e3de..0a727ed91 100644 --- a/tests/sqlparser_sqlite.rs +++ b/tests/sqlparser_sqlite.rs @@ -22,6 +22,7 @@ use test_utils::*; use sqlparser::ast::SelectItem::UnnamedExpr; use sqlparser::ast::*; use sqlparser::dialect::{GenericDialect, SQLiteDialect}; +use sqlparser::parser::ParserOptions; use sqlparser::tokenizer::Token; #[test] @@ -386,6 +387,22 @@ fn parse_attach_database() { } } +#[test] +fn parse_where_in_empty_list() { + let sql = "SELECT * FROM t1 WHERE a IN ()"; + let select = sqlite().verified_only_select(sql); + if let Expr::InList { list, .. } = select.selection.as_ref().unwrap() { + assert_eq!(list.len(), 0); + } else { + unreachable!() + } + + sqlite_with_options(ParserOptions::new().with_trailing_commas(true)).one_statement_parses_to( + "SELECT * FROM t1 WHERE a IN (,)", + "SELECT * FROM t1 WHERE a IN ()", + ); +} + fn sqlite() -> TestedDialects { TestedDialects { dialects: vec![Box::new(SQLiteDialect {})], @@ -393,6 +410,13 @@ fn sqlite() -> TestedDialects { } } +fn sqlite_with_options(options: ParserOptions) -> TestedDialects { + TestedDialects { + dialects: vec![Box::new(SQLiteDialect {})], + options: Some(options), + } +} + fn sqlite_and_generic() -> TestedDialects { TestedDialects { // we don't have a separate SQLite dialect, so test only the generic dialect for now