diff --git a/derive/README.md b/derive/README.md index aa70e7c71..b5ccc69e0 100644 --- a/derive/README.md +++ b/derive/README.md @@ -151,6 +151,55 @@ visitor.post_visit_expr() visitor.post_visit_expr() ``` +If the field is a `Option` and add `#[with = "visit_xxx"]` to the field, the generated code +will try to access the field only if it is `Some`: + +```rust +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub struct ShowStatementIn { + pub clause: ShowStatementInClause, + pub parent_type: Option, + #[cfg_attr(feature = "visitor", visit(with = "visit_relation"))] + pub parent_name: Option, +} +``` + +This will generate + +```rust +impl sqlparser::ast::Visit for ShowStatementIn { + fn visit( + &self, + visitor: &mut V, + ) -> ::std::ops::ControlFlow { + sqlparser::ast::Visit::visit(&self.clause, visitor)?; + sqlparser::ast::Visit::visit(&self.parent_type, visitor)?; + if let Some(value) = &self.parent_name { + visitor.pre_visit_relation(value)?; + sqlparser::ast::Visit::visit(value, visitor)?; + visitor.post_visit_relation(value)?; + } + ::std::ops::ControlFlow::Continue(()) + } +} + +impl sqlparser::ast::VisitMut for ShowStatementIn { + fn visit( + &mut self, + visitor: &mut V, + ) -> ::std::ops::ControlFlow { + sqlparser::ast::VisitMut::visit(&mut self.clause, visitor)?; + sqlparser::ast::VisitMut::visit(&mut self.parent_type, visitor)?; + if let Some(value) = &mut self.parent_name { + visitor.pre_visit_relation(value)?; + sqlparser::ast::VisitMut::visit(value, visitor)?; + visitor.post_visit_relation(value)?; + } + ::std::ops::ControlFlow::Continue(()) + } +} +``` + ## Releasing This crate's release is not automated. Instead it is released manually as needed diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 5ad1607f9..dd4d37b41 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -18,11 +18,8 @@ use proc_macro2::TokenStream; use quote::{format_ident, quote, quote_spanned, ToTokens}; use syn::spanned::Spanned; -use syn::{ - parse::{Parse, ParseStream}, - parse_macro_input, parse_quote, Attribute, Data, DeriveInput, Fields, GenericParam, Generics, - Ident, Index, LitStr, Meta, Token, -}; +use syn::{parse::{Parse, ParseStream}, parse_macro_input, parse_quote, Attribute, Data, DeriveInput, Fields, GenericParam, Generics, Ident, Index, LitStr, Meta, Token, Type, TypePath}; +use syn::{Path, PathArguments}; /// Implementation of `[#derive(Visit)]` #[proc_macro_derive(VisitMut, attributes(visit))] @@ -182,9 +179,21 @@ fn visit_children( Fields::Named(fields) => { let recurse = fields.named.iter().map(|f| { let name = &f.ident; + let is_option = is_option(&f.ty); let attributes = Attributes::parse(&f.attrs); - let (pre_visit, post_visit) = attributes.visit(quote!(&#modifier self.#name)); - quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(&#modifier self.#name, visitor)?; #post_visit) + if is_option && attributes.with.is_some() { + let (pre_visit, post_visit) = attributes.visit(quote!(value)); + quote_spanned!(f.span() => + if let Some(value) = &#modifier self.#name { + #pre_visit sqlparser::ast::#visit_trait::visit(value, visitor)?; #post_visit + } + ) + } else { + let (pre_visit, post_visit) = attributes.visit(quote!(&#modifier self.#name)); + quote_spanned!(f.span() => + #pre_visit sqlparser::ast::#visit_trait::visit(&#modifier self.#name, visitor)?; #post_visit + ) + } }); quote! { #(#recurse)* @@ -256,3 +265,16 @@ fn visit_children( Data::Union(_) => unimplemented!(), } } + +fn is_option(ty: &Type) -> bool { + if let Type::Path(TypePath { path: Path { segments, .. }, .. }) = ty { + if let Some(segment) = segments.last() { + if segment.ident == "Option" { + if let PathArguments::AngleBracketed(args) = &segment.arguments { + return args.args.len() == 1; + } + } + } + } + false +} diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 386e42fb3..19da04c62 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -7653,6 +7653,7 @@ impl fmt::Display for ShowStatementInParentType { pub struct ShowStatementIn { pub clause: ShowStatementInClause, pub parent_type: Option, + #[cfg_attr(feature = "visitor", visit(with = "visit_relation"))] pub parent_name: Option, } diff --git a/src/ast/visitor.rs b/src/ast/visitor.rs index 418e0a299..eacd268a4 100644 --- a/src/ast/visitor.rs +++ b/src/ast/visitor.rs @@ -876,7 +876,16 @@ mod tests { "POST: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID", "POST: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID", ] - ) + ), + ( + "SHOW COLUMNS FROM t1", + vec![ + "PRE: STATEMENT: SHOW COLUMNS FROM t1", + "PRE: RELATION: t1", + "POST: RELATION: t1", + "POST: STATEMENT: SHOW COLUMNS FROM t1", + ], + ), ]; for (sql, expected) in tests { let actual = do_visit(sql);