From 64c232ee121a277115d34f4dc74f049ec1f1bb0c Mon Sep 17 00:00:00 2001 From: lovasoa Date: Wed, 18 Oct 2023 18:48:42 +0200 Subject: [PATCH] support FILTER clause in window functions fixes #1006 --- README.md | 2 +- src/ast/mod.rs | 9 +++++++++ src/ast/visitor.rs | 2 +- src/dialect/sqlite.rs | 4 ++++ src/parser/mod.rs | 14 ++++++++++++++ tests/sqlparser_bigquery.rs | 1 + tests/sqlparser_clickhouse.rs | 4 ++++ tests/sqlparser_common.rs | 19 +++++++++++++++++++ tests/sqlparser_hive.rs | 1 + tests/sqlparser_mssql.rs | 1 + tests/sqlparser_mysql.rs | 6 ++++++ tests/sqlparser_postgres.rs | 6 ++++++ tests/sqlparser_redshift.rs | 1 + tests/sqlparser_snowflake.rs | 1 + tests/sqlparser_sqlite.rs | 25 +++++++++++++++++++++++++ 15 files changed, 94 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 454ea6c29..e987c2a21 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ println!("AST: {:?}", ast); This outputs ```rust -AST: [Query(Query { ctes: [], body: Select(Select { distinct: false, projection: [UnnamedExpr(Identifier("a")), UnnamedExpr(Identifier("b")), UnnamedExpr(Value(Long(123))), UnnamedExpr(Function(Function { name: ObjectName(["myfunc"]), args: [Identifier("b")], over: None, distinct: false }))], from: [TableWithJoins { relation: Table { name: ObjectName(["table_1"]), alias: None, args: [], with_hints: [] }, joins: [] }], selection: Some(BinaryOp { left: BinaryOp { left: Identifier("a"), op: Gt, right: Identifier("b") }, op: And, right: BinaryOp { left: Identifier("b"), op: Lt, right: Value(Long(100)) } }), group_by: [], having: None }), order_by: [OrderByExpr { expr: Identifier("a"), asc: Some(false) }, OrderByExpr { expr: Identifier("b"), asc: None }], limit: None, offset: None, fetch: None })] +AST: [Query(Query { ctes: [], body: Select(Select { distinct: false, projection: [UnnamedExpr(Identifier("a")), UnnamedExpr(Identifier("b")), UnnamedExpr(Value(Long(123))), UnnamedExpr(Function(Function { name: ObjectName(["myfunc"]), args: [Identifier("b")], filter: None, over: None, distinct: false }))], from: [TableWithJoins { relation: Table { name: ObjectName(["table_1"]), alias: None, args: [], with_hints: [] }, joins: [] }], selection: Some(BinaryOp { left: BinaryOp { left: Identifier("a"), op: Gt, right: Identifier("b") }, op: And, right: BinaryOp { left: Identifier("b"), op: Lt, right: Value(Long(100)) } }), group_by: [], having: None }), order_by: [OrderByExpr { expr: Identifier("a"), asc: Some(false) }, OrderByExpr { expr: Identifier("b"), asc: None }], limit: None, offset: None, fetch: None })] ``` diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 87f7ebb37..cab7616b8 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -991,8 +991,11 @@ impl Display for WindowType { #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] pub struct WindowSpec { + // OVER (PARTITION BY ...) pub partition_by: Vec, + // OVER (ORDER BY ...) pub order_by: Vec, + // OVER (window frame) pub window_frame: Option, } @@ -3650,6 +3653,8 @@ impl fmt::Display for CloseCursor { pub struct Function { pub name: ObjectName, pub args: Vec, + // e.g. `x > 5` in `COUNT(x) FILTER (WHERE x > 5)` + pub filter: Option>, pub over: Option, // aggregate functions may specify eg `COUNT(DISTINCT x)` pub distinct: bool, @@ -3698,6 +3703,10 @@ impl fmt::Display for Function { display_comma_separated(&self.order_by), )?; + if let Some(filter_cond) = &self.filter { + write!(f, " FILTER (WHERE {filter_cond})")?; + } + if let Some(o) = &self.over { write!(f, " OVER {o}")?; } diff --git a/src/ast/visitor.rs b/src/ast/visitor.rs index 09cb20a0c..4e025f962 100644 --- a/src/ast/visitor.rs +++ b/src/ast/visitor.rs @@ -506,7 +506,7 @@ where /// *expr = Expr::Function(Function { /// name: ObjectName(vec![Ident::new("f")]), /// args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(old_expr))], -/// over: None, distinct: false, special: false, order_by: vec![], +/// filter: None, over: None, distinct: false, special: false, order_by: vec![], /// }); /// } /// ControlFlow::<()>::Continue(()) diff --git a/src/dialect/sqlite.rs b/src/dialect/sqlite.rs index fa21224f6..37c7c7fa7 100644 --- a/src/dialect/sqlite.rs +++ b/src/dialect/sqlite.rs @@ -35,6 +35,10 @@ impl Dialect for SQLiteDialect { || ('\u{007f}'..='\u{ffff}').contains(&ch) } + fn supports_filter_during_aggregation(&self) -> bool { + true + } + fn is_identifier_part(&self, ch: char) -> bool { self.is_identifier_start(ch) || ch.is_ascii_digit() } diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 95f1f8edc..8f5d7a757 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -772,6 +772,7 @@ impl<'a> Parser<'a> { Ok(Expr::Function(Function { name: ObjectName(vec![w.to_ident()]), args: vec![], + filter: None, over: None, distinct: false, special: true, @@ -957,6 +958,17 @@ impl<'a> Parser<'a> { self.expect_token(&Token::LParen)?; let distinct = self.parse_all_or_distinct()?.is_some(); let (args, order_by) = self.parse_optional_args_with_orderby()?; + let filter = if self.dialect.supports_filter_during_aggregation() + && self.parse_keyword(Keyword::FILTER) + && self.consume_token(&Token::LParen) + && self.parse_keyword(Keyword::WHERE) + { + let filter = Some(Box::new(self.parse_expr()?)); + self.expect_token(&Token::RParen)?; + filter + } else { + None + }; let over = if self.parse_keyword(Keyword::OVER) { if self.consume_token(&Token::LParen) { let window_spec = self.parse_window_spec()?; @@ -970,6 +982,7 @@ impl<'a> Parser<'a> { Ok(Expr::Function(Function { name, args, + filter, over, distinct, special: false, @@ -987,6 +1000,7 @@ impl<'a> Parser<'a> { Ok(Expr::Function(Function { name, args, + filter: None, over: None, distinct: false, special, diff --git a/tests/sqlparser_bigquery.rs b/tests/sqlparser_bigquery.rs index 7a9a8d1c4..abdd595cb 100644 --- a/tests/sqlparser_bigquery.rs +++ b/tests/sqlparser_bigquery.rs @@ -533,6 +533,7 @@ fn parse_map_access_offset() { args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value( number("0") ))),], + filter: None, over: None, distinct: false, special: false, diff --git a/tests/sqlparser_clickhouse.rs b/tests/sqlparser_clickhouse.rs index 9efe4a368..10f816eeb 100644 --- a/tests/sqlparser_clickhouse.rs +++ b/tests/sqlparser_clickhouse.rs @@ -50,6 +50,7 @@ fn parse_map_access_expr() { Value::SingleQuotedString("endpoint".to_string()) ))), ], + filter: None, over: None, distinct: false, special: false, @@ -89,6 +90,7 @@ fn parse_map_access_expr() { Value::SingleQuotedString("app".to_string()) ))), ], + filter: None, over: None, distinct: false, special: false, @@ -138,6 +140,7 @@ fn parse_array_fn() { FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Identifier(Ident::new("x1")))), FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Identifier(Ident::new("x2")))), ], + filter: None, over: None, distinct: false, special: false, @@ -196,6 +199,7 @@ fn parse_delimited_identifiers() { &Expr::Function(Function { name: ObjectName(vec![Ident::with_quote('"', "myfun")]), args: vec![], + filter: None, over: None, distinct: false, special: false, diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 1511aa76e..b06464e08 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -875,6 +875,7 @@ fn parse_select_count_wildcard() { &Expr::Function(Function { name: ObjectName(vec![Ident::new("COUNT")]), args: vec![FunctionArg::Unnamed(FunctionArgExpr::Wildcard)], + filter: None, over: None, distinct: false, special: false, @@ -895,6 +896,7 @@ fn parse_select_count_distinct() { op: UnaryOperator::Plus, expr: Box::new(Expr::Identifier(Ident::new("x"))), }))], + filter: None, over: None, distinct: true, special: false, @@ -1862,6 +1864,7 @@ fn parse_select_having() { left: Box::new(Expr::Function(Function { name: ObjectName(vec![Ident::new("COUNT")]), args: vec![FunctionArg::Unnamed(FunctionArgExpr::Wildcard)], + filter: None, over: None, distinct: false, special: false, @@ -1887,6 +1890,7 @@ fn parse_select_qualify() { left: Box::new(Expr::Function(Function { name: ObjectName(vec![Ident::new("ROW_NUMBER")]), args: vec![], + filter: None, over: Some(WindowType::WindowSpec(WindowSpec { partition_by: vec![Expr::Identifier(Ident::new("p"))], order_by: vec![OrderByExpr { @@ -3332,6 +3336,7 @@ fn parse_scalar_function_in_projection() { args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr( Expr::Identifier(Ident::new("id")) ))], + filter: None, over: None, distinct: false, special: false, @@ -3451,6 +3456,7 @@ fn parse_named_argument_function() { ))), }, ], + filter: None, over: None, distinct: false, special: false, @@ -3482,6 +3488,7 @@ fn parse_window_functions() { &Expr::Function(Function { name: ObjectName(vec![Ident::new("row_number")]), args: vec![], + filter: None, over: Some(WindowType::WindowSpec(WindowSpec { partition_by: vec![], order_by: vec![OrderByExpr { @@ -3525,6 +3532,7 @@ fn test_parse_named_window() { quote_style: None, }), ))], + filter: None, over: Some(WindowType::NamedWindow(Ident { value: "window1".to_string(), quote_style: None, @@ -3550,6 +3558,7 @@ fn test_parse_named_window() { quote_style: None, }), ))], + filter: None, over: Some(WindowType::NamedWindow(Ident { value: "window2".to_string(), quote_style: None, @@ -4019,6 +4028,7 @@ fn parse_at_timezone() { quote_style: None, }]), args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(zero.clone()))], + filter: None, over: None, distinct: false, special: false, @@ -4046,6 +4056,7 @@ fn parse_at_timezone() { quote_style: None, },],), args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(zero))], + filter: None, over: None, distinct: false, special: false, @@ -4057,6 +4068,7 @@ fn parse_at_timezone() { Value::SingleQuotedString("%Y-%m-%dT%H".to_string()), ),),), ], + filter: None, over: None, distinct: false, special: false, @@ -4215,6 +4227,7 @@ fn parse_table_function() { args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value( Value::SingleQuotedString("1".to_owned()), )))], + filter: None, over: None, distinct: false, special: false, @@ -4366,6 +4379,7 @@ fn parse_unnest_in_from_clause() { FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(number("2")))), FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(number("3")))), ], + filter: None, over: None, distinct: false, special: false, @@ -4395,6 +4409,7 @@ fn parse_unnest_in_from_clause() { FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(number("2")))), FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(number("3")))), ], + filter: None, over: None, distinct: false, special: false, @@ -4406,6 +4421,7 @@ fn parse_unnest_in_from_clause() { FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(number("5")))), FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(number("6")))), ], + filter: None, over: None, distinct: false, special: false, @@ -6878,6 +6894,7 @@ fn parse_time_functions() { let select_localtime_func_call_ast = Function { name: ObjectName(vec![Ident::new(func_name)]), args: vec![], + filter: None, over: None, distinct: false, special: false, @@ -7364,6 +7381,7 @@ fn parse_pivot_table() { args: (vec![FunctionArg::Unnamed(FunctionArgExpr::Expr( Expr::CompoundIdentifier(vec![Ident::new("a"), Ident::new("amount"),]) ))]), + filter: None, over: None, distinct: false, special: false, @@ -7513,6 +7531,7 @@ fn parse_pivot_unpivot_table() { args: (vec![FunctionArg::Unnamed(FunctionArgExpr::Expr( Expr::Identifier(Ident::new("population")) ))]), + filter: None, over: None, distinct: false, special: false, diff --git a/tests/sqlparser_hive.rs b/tests/sqlparser_hive.rs index 6ca47e12c..6f3a8f994 100644 --- a/tests/sqlparser_hive.rs +++ b/tests/sqlparser_hive.rs @@ -346,6 +346,7 @@ fn parse_delimited_identifiers() { &Expr::Function(Function { name: ObjectName(vec![Ident::with_quote('"', "myfun")]), args: vec![], + filter: None, over: None, distinct: false, special: false, diff --git a/tests/sqlparser_mssql.rs b/tests/sqlparser_mssql.rs index f9eb4d8fb..ebadf95f2 100644 --- a/tests/sqlparser_mssql.rs +++ b/tests/sqlparser_mssql.rs @@ -334,6 +334,7 @@ fn parse_delimited_identifiers() { &Expr::Function(Function { name: ObjectName(vec![Ident::with_quote('"', "myfun")]), args: vec![], + filter: None, over: None, distinct: false, special: false, diff --git a/tests/sqlparser_mysql.rs b/tests/sqlparser_mysql.rs index 80b9dcfd8..180104717 100644 --- a/tests/sqlparser_mysql.rs +++ b/tests/sqlparser_mysql.rs @@ -1071,6 +1071,7 @@ fn parse_insert_with_on_duplicate_update() { args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr( Expr::Identifier(Ident::new("description")) ))], + filter: None, over: None, distinct: false, special: false, @@ -1084,6 +1085,7 @@ fn parse_insert_with_on_duplicate_update() { args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr( Expr::Identifier(Ident::new("perm_create")) ))], + filter: None, over: None, distinct: false, special: false, @@ -1097,6 +1099,7 @@ fn parse_insert_with_on_duplicate_update() { args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr( Expr::Identifier(Ident::new("perm_read")) ))], + filter: None, over: None, distinct: false, special: false, @@ -1110,6 +1113,7 @@ fn parse_insert_with_on_duplicate_update() { args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr( Expr::Identifier(Ident::new("perm_update")) ))], + filter: None, over: None, distinct: false, special: false, @@ -1123,6 +1127,7 @@ fn parse_insert_with_on_duplicate_update() { args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr( Expr::Identifier(Ident::new("perm_delete")) ))], + filter: None, over: None, distinct: false, special: false, @@ -1500,6 +1505,7 @@ fn parse_table_colum_option_on_update() { option: ColumnOption::OnUpdate(Expr::Function(Function { name: ObjectName(vec![Ident::new("CURRENT_TIMESTAMP")]), args: vec![], + filter: None, over: None, distinct: false, special: false, diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index fe336bda7..95fac10c3 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -2274,6 +2274,7 @@ fn test_composite_value() { named: true } )))], + filter: None, over: None, distinct: false, special: false, @@ -2435,6 +2436,7 @@ fn parse_current_functions() { &Expr::Function(Function { name: ObjectName(vec![Ident::new("CURRENT_CATALOG")]), args: vec![], + filter: None, over: None, distinct: false, special: true, @@ -2446,6 +2448,7 @@ fn parse_current_functions() { &Expr::Function(Function { name: ObjectName(vec![Ident::new("CURRENT_USER")]), args: vec![], + filter: None, over: None, distinct: false, special: true, @@ -2457,6 +2460,7 @@ fn parse_current_functions() { &Expr::Function(Function { name: ObjectName(vec![Ident::new("SESSION_USER")]), args: vec![], + filter: None, over: None, distinct: false, special: true, @@ -2468,6 +2472,7 @@ fn parse_current_functions() { &Expr::Function(Function { name: ObjectName(vec![Ident::new("USER")]), args: vec![], + filter: None, over: None, distinct: false, special: true, @@ -2918,6 +2923,7 @@ fn parse_delimited_identifiers() { &Expr::Function(Function { name: ObjectName(vec![Ident::with_quote('"', "myfun")]), args: vec![], + filter: None, over: None, distinct: false, special: false, diff --git a/tests/sqlparser_redshift.rs b/tests/sqlparser_redshift.rs index 5ae539b3c..6238d1eca 100644 --- a/tests/sqlparser_redshift.rs +++ b/tests/sqlparser_redshift.rs @@ -137,6 +137,7 @@ fn parse_delimited_identifiers() { &Expr::Function(Function { name: ObjectName(vec![Ident::with_quote('"', "myfun")]), args: vec![], + filter: None, over: None, distinct: false, special: false, diff --git a/tests/sqlparser_snowflake.rs b/tests/sqlparser_snowflake.rs index e92656d0b..ab3304c14 100644 --- a/tests/sqlparser_snowflake.rs +++ b/tests/sqlparser_snowflake.rs @@ -247,6 +247,7 @@ fn parse_delimited_identifiers() { &Expr::Function(Function { name: ObjectName(vec![Ident::with_quote('"', "myfun")]), args: vec![], + filter: None, over: None, distinct: false, special: false, diff --git a/tests/sqlparser_sqlite.rs b/tests/sqlparser_sqlite.rs index 39a82cc8b..1b86fd701 100644 --- a/tests/sqlparser_sqlite.rs +++ b/tests/sqlparser_sqlite.rs @@ -290,6 +290,31 @@ fn parse_create_table_with_strict() { } } +#[test] +fn parse_window_function_with_filter() { + let sql = "SELECT max(x) FILTER (WHERE y) OVER () FROM t"; + let select = sqlite().verified_only_select(sql); + assert_eq!(select.to_string(), sql); + assert_eq!( + select.projection, + vec![SelectItem::UnnamedExpr(Expr::Function(Function { + name: ObjectName(vec![Ident::new("max")]), + args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr( + Expr::Identifier(Ident::new("x")) + ))], + over: Some(WindowType::WindowSpec(WindowSpec { + partition_by: vec![], + order_by: vec![], + window_frame: None, + })), + filter: Some(Box::new(Expr::Identifier(Ident::new("y")))), + distinct: false, + special: false, + order_by: vec![] + }))] + ); +} + #[test] fn parse_attach_database() { let sql = "ATTACH DATABASE 'test.db' AS test";