Skip to content

Commit 1dc4815

Browse files
committed
Update derive field overwrite support
1 parent 9d2a978 commit 1dc4815

File tree

5 files changed

+177
-33
lines changed

5 files changed

+177
-33
lines changed

Diff for: components/salsa-macros/src/lib.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,12 @@ pub fn tracked(args: TokenStream, input: TokenStream) -> TokenStream {
8080
tracked::tracked(args, input)
8181
}
8282

83-
#[proc_macro_derive(Update)]
83+
#[proc_macro_derive(Update, attributes(update))]
8484
pub fn update(input: TokenStream) -> TokenStream {
8585
let item = parse_macro_input!(input as syn::DeriveInput);
8686
match update::update_derive(item) {
8787
Ok(tokens) => tokens.into(),
88-
Err(error) => token_stream_with_error(input, error),
88+
Err(error) => error.into_compile_error().into(),
8989
}
9090
}
9191

Diff for: components/salsa-macros/src/update.rs

+87-31
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
1-
use proc_macro2::{Literal, TokenStream};
2-
use syn::spanned::Spanned;
1+
use proc_macro2::{Literal, Span, TokenStream};
2+
use syn::{parenthesized, parse::ParseStream, spanned::Spanned, Token};
33
use synstructure::BindStyle;
44

55
use crate::hygiene::Hygiene;
66

77
pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result<TokenStream> {
88
let hygiene = Hygiene::from2(&input);
99

10-
if let syn::Data::Union(_) = &input.data {
10+
if let syn::Data::Union(u) = &input.data {
1111
return Err(syn::Error::new_spanned(
12-
&input.ident,
12+
u.union_token,
1313
"`derive(Update)` does not support `union`",
1414
));
1515
}
@@ -27,6 +27,24 @@ pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result<TokenStream>
2727
.variants()
2828
.iter()
2929
.map(|variant| {
30+
let err = variant
31+
.ast()
32+
.attrs
33+
.iter()
34+
.filter(|attr| attr.path().is_ident("update"))
35+
.map(|attr| {
36+
syn::Error::new(
37+
attr.path().span(),
38+
"unexpected attribute `#[update]` on variant",
39+
)
40+
})
41+
.reduce(|mut acc, err| {
42+
acc.combine(err);
43+
acc
44+
});
45+
if let Some(err) = err {
46+
return Err(err);
47+
}
3048
let variant_pat = variant.pat();
3149

3250
// First check that the `new_value` has same variant.
@@ -35,7 +53,7 @@ pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result<TokenStream>
3553
.bindings()
3654
.iter()
3755
.fold(quote!(), |tokens, binding| quote!(#tokens #binding,));
38-
let make_new_value = quote_spanned! {variant.ast().ident.span()=>
56+
let make_new_value = quote! {
3957
let #new_value = if let #variant_pat = #new_value {
4058
(#make_tuple)
4159
} else {
@@ -47,40 +65,78 @@ pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result<TokenStream>
4765
// For each field, invoke `maybe_update` recursively to update its value.
4866
// Or the results together (using `|`, not `||`, to avoid shortcircuiting)
4967
// to get the final return value.
50-
let update_fields = variant.bindings().iter().enumerate().fold(
51-
quote!(false),
52-
|tokens, (index, binding)| {
53-
let field_ty = &binding.ast().ty;
54-
let field_index = Literal::usize_unsuffixed(index);
55-
56-
let field_span = binding
57-
.ast()
58-
.ident
59-
.as_ref()
60-
.map(Spanned::span)
61-
.unwrap_or(binding.ast().span());
62-
63-
let update_field = quote_spanned! {field_span=>
64-
salsa::plumbing::UpdateDispatch::<#field_ty>::maybe_update(
65-
#binding,
66-
#new_value.#field_index,
67-
)
68-
};
68+
let mut update_fields = quote!(false);
69+
for (index, binding) in variant.bindings().iter().enumerate() {
70+
let mut attrs = binding
71+
.ast()
72+
.attrs
73+
.iter()
74+
.filter(|attr| attr.path().is_ident("update"));
75+
let attr = attrs.next();
76+
if let Some(attr) = attrs.next() {
77+
return Err(syn::Error::new(
78+
attr.path().span(),
79+
"multiple #[update(with)] attributes on field",
80+
));
81+
}
6982

70-
quote! {
71-
#tokens | unsafe { #update_field }
83+
let field_ty = &binding.ast().ty;
84+
let field_index = Literal::usize_unsuffixed(index);
85+
86+
let (maybe_update, unsafe_token) = match attr {
87+
Some(attr) => {
88+
mod kw {
89+
syn::custom_keyword!(with);
90+
}
91+
attr.parse_args_with(|parser: ParseStream| {
92+
let mut content;
93+
94+
let unsafe_token = parser.parse::<Token![unsafe]>()?;
95+
parenthesized!(content in parser);
96+
content.parse::<kw::with>()?;
97+
parenthesized!(content in content);
98+
let r = content.parse::<syn::ExprPath>()?;
99+
Ok((
100+
quote! { (#r as unsafe fn(*mut #field_ty, #field_ty) -> bool) },
101+
unsafe_token,
102+
))
103+
})?
72104
}
73-
},
74-
);
105+
None => {
106+
let field_span = binding
107+
.ast()
108+
.ident
109+
.as_ref()
110+
.map(Spanned::span)
111+
.unwrap_or(binding.ast().span());
112+
(
113+
quote_spanned!(field_span=>
114+
salsa::plumbing::UpdateDispatch::<#field_ty>::maybe_update
115+
),
116+
Token![unsafe](Span::call_site()),
117+
)
118+
}
119+
};
120+
let update_field = quote! {
121+
#maybe_update(
122+
#binding,
123+
#new_value.#field_index,
124+
)
125+
};
126+
127+
update_fields = quote! {
128+
#update_fields | #unsafe_token { #update_field }
129+
};
130+
}
75131

76-
quote!(
132+
Ok(quote!(
77133
#variant_pat => {
78134
#make_new_value
79135
#update_fields
80136
}
81-
)
137+
))
82138
})
83-
.collect();
139+
.collect::<syn::Result<_>>()?;
84140

85141
let ident = &input.ident;
86142
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

Diff for: tests/compile-fail/invalid_update.rs

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#[derive(salsa::Update)]
2+
union U {
3+
field: i32,
4+
}
5+
6+
#[derive(salsa::Update)]
7+
struct S {
8+
recursive: Box<S>,
9+
}
10+
11+
#[derive(salsa::Update)]
12+
struct S2 {
13+
#[update(with(unsafe(my_wrong_update)))]
14+
recursive: i32,
15+
#[update(with(missing_unsafe))]
16+
recursive: i32,
17+
}
18+
19+
fn my_wrong_update() {}
20+
fn missing_unsafe(_: *mut i32, _: i32) -> bool {
21+
true
22+
}
23+
24+
fn main() {}

Diff for: tests/compile-fail/invalid_update.stderr

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
error: `derive(Update)` does not support `union`
2+
--> tests/compile-fail/invalid_update.rs:2:1
3+
|
4+
2 | union U {
5+
| ^^^^^
6+
7+
error: expected `unsafe`
8+
--> tests/compile-fail/invalid_update.rs:13:14
9+
|
10+
13 | #[update(with(unsafe(my_wrong_update)))]
11+
| ^^^^
12+
13+
error[E0124]: field `recursive` is already declared
14+
--> tests/compile-fail/invalid_update.rs:16:5
15+
|
16+
14 | recursive: i32,
17+
| -------------- `recursive` first declared here
18+
15 | #[update(with(missing_unsafe))]
19+
16 | recursive: i32,
20+
| ^^^^^^^^^^^^^^ field already declared

Diff for: tests/derive_update.rs

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
//! Test that the `Update` derive works as expected
2+
3+
#[derive(salsa::Update)]
4+
struct MyInput {
5+
field: &'static str,
6+
}
7+
8+
#[derive(salsa::Update)]
9+
struct MyInput2 {
10+
#[update(unsafe(with(custom_update)))]
11+
field: &'static str,
12+
}
13+
14+
unsafe fn custom_update(dest: *mut &'static str, _data: &'static str) -> bool {
15+
unsafe { *dest = "ill-behaved for testing purposes" };
16+
true
17+
}
18+
19+
#[test]
20+
fn derived() {
21+
let mut m = MyInput { field: "foo" };
22+
assert_eq!(m.field, "foo");
23+
assert!(unsafe { salsa::Update::maybe_update(&mut m, MyInput { field: "bar" }) });
24+
assert_eq!(m.field, "bar");
25+
assert!(!unsafe { salsa::Update::maybe_update(&mut m, MyInput { field: "bar" }) });
26+
assert_eq!(m.field, "bar");
27+
}
28+
29+
#[test]
30+
fn derived_with() {
31+
let mut m = MyInput2 { field: "foo" };
32+
assert_eq!(m.field, "foo");
33+
assert!(unsafe { salsa::Update::maybe_update(&mut m, MyInput2 { field: "bar" }) });
34+
assert_eq!(m.field, "ill-behaved for testing purposes");
35+
assert!(unsafe {
36+
salsa::Update::maybe_update(
37+
&mut m,
38+
MyInput2 {
39+
field: "ill-behaved for testing purposes",
40+
},
41+
)
42+
});
43+
assert_eq!(m.field, "ill-behaved for testing purposes");
44+
}

0 commit comments

Comments
 (0)