From a4a5794904e7cebf5017fbcc3269e532de0ccd65 Mon Sep 17 00:00:00 2001 From: blaginin Date: Thu, 14 Nov 2024 20:48:02 +0000 Subject: [PATCH] Add `#[recursive]` --- Cargo.toml | 4 +++- derive/src/lib.rs | 1 + src/ast/mod.rs | 1 + src/ast/visitor.rs | 25 +++++++++++++++++++++++++ tests/sqlparser_common.rs | 10 ++++++++++ 5 files changed, 40 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 18b246e04..e5f6efbc1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,7 +38,7 @@ path = "src/lib.rs" [features] default = ["std"] -std = [] +std = ["recursive"] # Enable JSON output in the `cli` example: json_example = ["serde_json", "serde"] visitor = ["sqlparser_derive"] @@ -46,6 +46,8 @@ visitor = ["sqlparser_derive"] [dependencies] bigdecimal = { version = "0.4.1", features = ["serde"], optional = true } log = "0.4" +recursive = { version = "0.1.1", optional = true} + serde = { version = "1.0", features = ["derive"], optional = true } # serde_json is only used in examples/cli, but we have to put it outside # of dev-dependencies because of diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 5ad1607f9..ffa56a533 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -78,6 +78,7 @@ fn derive_visit(input: proc_macro::TokenStream, visit_type: &VisitType) -> proc_ let expanded = quote! { // The generated impl. impl #impl_generics sqlparser::ast::#visit_trait for #name #ty_generics #where_clause { + #[cfg_attr(feature = "std", recursive::recursive)] fn visit( &#modifier self, visitor: &mut V diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 505386fbf..6c77470e3 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -1188,6 +1188,7 @@ impl fmt::Display for CastFormat { } impl fmt::Display for Expr { + #[cfg_attr(feature = "std", recursive::recursive)] fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Expr::Identifier(s) => write!(f, "{s}"), diff --git a/src/ast/visitor.rs b/src/ast/visitor.rs index 418e0a299..2ff05da55 100644 --- a/src/ast/visitor.rs +++ b/src/ast/visitor.rs @@ -884,4 +884,29 @@ mod tests { assert_eq!(actual, expected) } } + + struct QuickVisitor; + + impl Visitor for QuickVisitor { + type Break = (); + } + + #[test] + fn overflow() { + let cond = (0..1000) + .map(|n| format!("X = {}", n)) + .collect::>() + .join(" OR "); + let sql = format!("SELECT x where {0}", cond); + + let dialect = GenericDialect {}; + let tokens = Tokenizer::new(&dialect, sql.as_str()).tokenize().unwrap(); + let s = Parser::new(&dialect) + .with_tokens(tokens) + .parse_statement() + .unwrap(); + + let mut visitor = QuickVisitor {} ; + s.visit(&mut visitor); + } } diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index daf65edf1..14481d477 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -11748,3 +11748,13 @@ fn parse_create_table_select() { ); } } + +#[test] +fn overflow() { + let expr = std::iter::repeat("1").take(1000).collect::>().join(" + "); + let sql = format!("SELECT {}", expr); + + let mut statements = Parser::parse_sql(&GenericDialect {}, sql.as_str()).unwrap(); + let statement = statements.pop().unwrap(); + assert_eq!(statement.to_string(), sql); +}