1
- use proc_macro2:: { Literal , TokenStream } ;
2
- use syn:: spanned:: Spanned ;
1
+ use proc_macro2:: { Literal , Span , TokenStream } ;
2
+ use syn:: { parenthesized , parse :: ParseStream , spanned:: Spanned , Token } ;
3
3
use synstructure:: BindStyle ;
4
4
5
5
use crate :: hygiene:: Hygiene ;
6
6
7
7
pub ( crate ) fn update_derive ( input : syn:: DeriveInput ) -> syn:: Result < TokenStream > {
8
8
let hygiene = Hygiene :: from2 ( & input) ;
9
9
10
- if let syn:: Data :: Union ( _ ) = & input. data {
10
+ if let syn:: Data :: Union ( u ) = & input. data {
11
11
return Err ( syn:: Error :: new_spanned (
12
- & input . ident ,
12
+ u . union_token ,
13
13
"`derive(Update)` does not support `union`" ,
14
14
) ) ;
15
15
}
@@ -27,6 +27,24 @@ pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result<TokenStream>
27
27
. variants ( )
28
28
. iter ( )
29
29
. map ( |variant| {
30
+ let err = variant
31
+ . ast ( )
32
+ . attrs
33
+ . iter ( )
34
+ . filter ( |attr| attr. path ( ) . is_ident ( "update" ) )
35
+ . map ( |attr| {
36
+ syn:: Error :: new (
37
+ attr. path ( ) . span ( ) ,
38
+ "unexpected attribute `#[update]` on variant" ,
39
+ )
40
+ } )
41
+ . reduce ( |mut acc, err| {
42
+ acc. combine ( err) ;
43
+ acc
44
+ } ) ;
45
+ if let Some ( err) = err {
46
+ return Err ( err) ;
47
+ }
30
48
let variant_pat = variant. pat ( ) ;
31
49
32
50
// First check that the `new_value` has same variant.
@@ -35,7 +53,7 @@ pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result<TokenStream>
35
53
. bindings ( )
36
54
. iter ( )
37
55
. fold ( quote ! ( ) , |tokens, binding| quote ! ( #tokens #binding, ) ) ;
38
- let make_new_value = quote_spanned ! { variant . ast ( ) . ident . span ( ) =>
56
+ let make_new_value = quote ! {
39
57
let #new_value = if let #variant_pat = #new_value {
40
58
( #make_tuple)
41
59
} else {
@@ -47,40 +65,78 @@ pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result<TokenStream>
47
65
// For each field, invoke `maybe_update` recursively to update its value.
48
66
// Or the results together (using `|`, not `||`, to avoid shortcircuiting)
49
67
// to get the final return value.
50
- let update_fields = variant. bindings ( ) . iter ( ) . enumerate ( ) . fold (
51
- quote ! ( false ) ,
52
- |tokens, ( index, binding) | {
53
- let field_ty = & binding. ast ( ) . ty ;
54
- let field_index = Literal :: usize_unsuffixed ( index) ;
55
-
56
- let field_span = binding
57
- . ast ( )
58
- . ident
59
- . as_ref ( )
60
- . map ( Spanned :: span)
61
- . unwrap_or ( binding. ast ( ) . span ( ) ) ;
62
-
63
- let update_field = quote_spanned ! { field_span=>
64
- salsa:: plumbing:: UpdateDispatch :: <#field_ty>:: maybe_update(
65
- #binding,
66
- #new_value. #field_index,
67
- )
68
- } ;
68
+ let mut update_fields = quote ! ( false ) ;
69
+ for ( index, binding) in variant. bindings ( ) . iter ( ) . enumerate ( ) {
70
+ let mut attrs = binding
71
+ . ast ( )
72
+ . attrs
73
+ . iter ( )
74
+ . filter ( |attr| attr. path ( ) . is_ident ( "update" ) ) ;
75
+ let attr = attrs. next ( ) ;
76
+ if let Some ( attr) = attrs. next ( ) {
77
+ return Err ( syn:: Error :: new (
78
+ attr. path ( ) . span ( ) ,
79
+ "multiple #[update(with)] attributes on field" ,
80
+ ) ) ;
81
+ }
69
82
70
- quote ! {
71
- #tokens | unsafe { #update_field }
83
+ let field_ty = & binding. ast ( ) . ty ;
84
+ let field_index = Literal :: usize_unsuffixed ( index) ;
85
+
86
+ let ( maybe_update, unsafe_token) = match attr {
87
+ Some ( attr) => {
88
+ mod kw {
89
+ syn:: custom_keyword!( with) ;
90
+ }
91
+ attr. parse_args_with ( |parser : ParseStream | {
92
+ let mut content;
93
+
94
+ let unsafe_token = parser. parse :: < Token ! [ unsafe ] > ( ) ?;
95
+ parenthesized ! ( content in parser) ;
96
+ content. parse :: < kw:: with > ( ) ?;
97
+ parenthesized ! ( content in content) ;
98
+ let r = content. parse :: < syn:: ExprPath > ( ) ?;
99
+ Ok ( (
100
+ quote ! { ( #r as unsafe fn ( * mut #field_ty, #field_ty) -> bool ) } ,
101
+ unsafe_token,
102
+ ) )
103
+ } ) ?
72
104
}
73
- } ,
74
- ) ;
105
+ None => {
106
+ let field_span = binding
107
+ . ast ( )
108
+ . ident
109
+ . as_ref ( )
110
+ . map ( Spanned :: span)
111
+ . unwrap_or ( binding. ast ( ) . span ( ) ) ;
112
+ (
113
+ quote_spanned ! ( field_span=>
114
+ salsa:: plumbing:: UpdateDispatch :: <#field_ty>:: maybe_update
115
+ ) ,
116
+ Token ! [ unsafe ] ( Span :: call_site ( ) ) ,
117
+ )
118
+ }
119
+ } ;
120
+ let update_field = quote ! {
121
+ #maybe_update(
122
+ #binding,
123
+ #new_value. #field_index,
124
+ )
125
+ } ;
126
+
127
+ update_fields = quote ! {
128
+ #update_fields | #unsafe_token { #update_field }
129
+ } ;
130
+ }
75
131
76
- quote ! (
132
+ Ok ( quote ! (
77
133
#variant_pat => {
78
134
#make_new_value
79
135
#update_fields
80
136
}
81
- )
137
+ ) )
82
138
} )
83
- . collect ( ) ;
139
+ . collect :: < syn :: Result < _ > > ( ) ? ;
84
140
85
141
let ident = & input. ident ;
86
142
let ( impl_generics, ty_generics, where_clause) = input. generics . split_for_impl ( ) ;
0 commit comments