Skip to content

Commit

Permalink
Support relation visitor to visit the Option field (#1556)
Browse files Browse the repository at this point in the history
  • Loading branch information
goldmedal authored Nov 29, 2024
1 parent 6291afb commit 92c6e7f
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 8 deletions.
49 changes: 49 additions & 0 deletions derive/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,55 @@ visitor.post_visit_expr(<is null operand>)
visitor.post_visit_expr(<is null expr>)
```

If the field is a `Option` and add `#[with = "visit_xxx"]` to the field, the generated code
will try to access the field only if it is `Some`:

```rust
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct ShowStatementIn {
pub clause: ShowStatementInClause,
pub parent_type: Option<ShowStatementInParentType>,
#[cfg_attr(feature = "visitor", visit(with = "visit_relation"))]
pub parent_name: Option<ObjectName>,
}
```

This will generate

```rust
impl sqlparser::ast::Visit for ShowStatementIn {
fn visit<V: sqlparser::ast::Visitor>(
&self,
visitor: &mut V,
) -> ::std::ops::ControlFlow<V::Break> {
sqlparser::ast::Visit::visit(&self.clause, visitor)?;
sqlparser::ast::Visit::visit(&self.parent_type, visitor)?;
if let Some(value) = &self.parent_name {
visitor.pre_visit_relation(value)?;
sqlparser::ast::Visit::visit(value, visitor)?;
visitor.post_visit_relation(value)?;
}
::std::ops::ControlFlow::Continue(())
}
}

impl sqlparser::ast::VisitMut for ShowStatementIn {
fn visit<V: sqlparser::ast::VisitorMut>(
&mut self,
visitor: &mut V,
) -> ::std::ops::ControlFlow<V::Break> {
sqlparser::ast::VisitMut::visit(&mut self.clause, visitor)?;
sqlparser::ast::VisitMut::visit(&mut self.parent_type, visitor)?;
if let Some(value) = &mut self.parent_name {
visitor.pre_visit_relation(value)?;
sqlparser::ast::VisitMut::visit(value, visitor)?;
visitor.post_visit_relation(value)?;
}
::std::ops::ControlFlow::Continue(())
}
}
```

## Releasing

This crate's release is not automated. Instead it is released manually as needed
Expand Down
36 changes: 29 additions & 7 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,8 @@
use proc_macro2::TokenStream;
use quote::{format_ident, quote, quote_spanned, ToTokens};
use syn::spanned::Spanned;
use syn::{
parse::{Parse, ParseStream},
parse_macro_input, parse_quote, Attribute, Data, DeriveInput, Fields, GenericParam, Generics,
Ident, Index, LitStr, Meta, Token,
};
use syn::{parse::{Parse, ParseStream}, parse_macro_input, parse_quote, Attribute, Data, DeriveInput, Fields, GenericParam, Generics, Ident, Index, LitStr, Meta, Token, Type, TypePath};
use syn::{Path, PathArguments};

/// Implementation of `[#derive(Visit)]`
#[proc_macro_derive(VisitMut, attributes(visit))]
Expand Down Expand Up @@ -182,9 +179,21 @@ fn visit_children(
Fields::Named(fields) => {
let recurse = fields.named.iter().map(|f| {
let name = &f.ident;
let is_option = is_option(&f.ty);
let attributes = Attributes::parse(&f.attrs);
let (pre_visit, post_visit) = attributes.visit(quote!(&#modifier self.#name));
quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(&#modifier self.#name, visitor)?; #post_visit)
if is_option && attributes.with.is_some() {
let (pre_visit, post_visit) = attributes.visit(quote!(value));
quote_spanned!(f.span() =>
if let Some(value) = &#modifier self.#name {
#pre_visit sqlparser::ast::#visit_trait::visit(value, visitor)?; #post_visit
}
)
} else {
let (pre_visit, post_visit) = attributes.visit(quote!(&#modifier self.#name));
quote_spanned!(f.span() =>
#pre_visit sqlparser::ast::#visit_trait::visit(&#modifier self.#name, visitor)?; #post_visit
)
}
});
quote! {
#(#recurse)*
Expand Down Expand Up @@ -256,3 +265,16 @@ fn visit_children(
Data::Union(_) => unimplemented!(),
}
}

fn is_option(ty: &Type) -> bool {
if let Type::Path(TypePath { path: Path { segments, .. }, .. }) = ty {
if let Some(segment) = segments.last() {
if segment.ident == "Option" {
if let PathArguments::AngleBracketed(args) = &segment.arguments {
return args.args.len() == 1;
}
}
}
}
false
}
1 change: 1 addition & 0 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7653,6 +7653,7 @@ impl fmt::Display for ShowStatementInParentType {
pub struct ShowStatementIn {
pub clause: ShowStatementInClause,
pub parent_type: Option<ShowStatementInParentType>,
#[cfg_attr(feature = "visitor", visit(with = "visit_relation"))]
pub parent_name: Option<ObjectName>,
}

Expand Down
11 changes: 10 additions & 1 deletion src/ast/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -876,7 +876,16 @@ mod tests {
"POST: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
"POST: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
]
)
),
(
"SHOW COLUMNS FROM t1",
vec![
"PRE: STATEMENT: SHOW COLUMNS FROM t1",
"PRE: RELATION: t1",
"POST: RELATION: t1",
"POST: STATEMENT: SHOW COLUMNS FROM t1",
],
),
];
for (sql, expected) in tests {
let actual = do_visit(sql);
Expand Down

0 comments on commit 92c6e7f

Please sign in to comment.