diff --git a/newsfragments/4829.added.md b/newsfragments/4829.added.md new file mode 100644 index 00000000000..849aaaf39a6 --- /dev/null +++ b/newsfragments/4829.added.md @@ -0,0 +1 @@ +`derive(FromPyObject)` allow a `default` attribute to set a default value for extracted fields. The default value is either provided explicitly or fetched via `Default::default()`. \ No newline at end of file diff --git a/pyo3-macros-backend/src/attributes.rs b/pyo3-macros-backend/src/attributes.rs index 6fe75e44302..c5a81e47ec1 100644 --- a/pyo3-macros-backend/src/attributes.rs +++ b/pyo3-macros-backend/src/attributes.rs @@ -45,6 +45,7 @@ pub mod kw { syn::custom_keyword!(unsendable); syn::custom_keyword!(weakref); syn::custom_keyword!(gil_used); + syn::custom_keyword!(default); } fn take_int(read: &mut &str, tracker: &mut usize) -> String { @@ -351,6 +352,8 @@ impl ToTokens for OptionalKeywordAttribute { pub type FromPyWithAttribute = KeywordAttribute>; +pub type DefaultAttribute = OptionalKeywordAttribute; + /// For specifying the path to the pyo3 crate. pub type CrateAttribute = KeywordAttribute>; diff --git a/pyo3-macros-backend/src/frompyobject.rs b/pyo3-macros-backend/src/frompyobject.rs index 565c54da1f3..a2799f002c8 100644 --- a/pyo3-macros-backend/src/frompyobject.rs +++ b/pyo3-macros-backend/src/frompyobject.rs @@ -1,7 +1,9 @@ -use crate::attributes::{self, get_pyo3_options, CrateAttribute, FromPyWithAttribute}; +use crate::attributes::{ + self, get_pyo3_options, CrateAttribute, DefaultAttribute, FromPyWithAttribute, +}; use crate::utils::Ctx; use proc_macro2::TokenStream; -use quote::{format_ident, quote}; +use quote::{format_ident, quote, ToTokens}; use syn::{ ext::IdentExt, parenthesized, @@ -90,6 +92,7 @@ struct NamedStructField<'a> { ident: &'a syn::Ident, getter: Option, from_py_with: Option, + default: Option, } struct TupleStructField { @@ -193,6 +196,7 @@ impl<'a> Container<'a> { ident, getter: attrs.getter, from_py_with: attrs.from_py_with, + default: attrs.default, }) }) .collect::>>()?; @@ -346,18 +350,33 @@ impl<'a> Container<'a> { quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #pyo3_path::intern!(obj.py(), #field_name))) } }; - let extractor = match &field.from_py_with { - None => { - quote!(#pyo3_path::impl_::frompyobject::extract_struct_field(&#getter?, #struct_name, #field_name)?) - } - Some(FromPyWithAttribute { - value: expr_path, .. - }) => { - quote! (#pyo3_path::impl_::frompyobject::extract_struct_field_with(#expr_path as fn(_) -> _, &#getter?, #struct_name, #field_name)?) - } + let extractor = if let Some(FromPyWithAttribute { + value: expr_path, .. + }) = &field.from_py_with + { + quote!(#pyo3_path::impl_::frompyobject::extract_struct_field_with(#expr_path as fn(_) -> _, &value, #struct_name, #field_name)?) + } else { + quote!(#pyo3_path::impl_::frompyobject::extract_struct_field(&value, #struct_name, #field_name)?) + }; + let extracted = if let Some(default) = &field.default { + let default_expr = if let Some(default_expr) = &default.value { + default_expr.to_token_stream() + } else { + quote!(Default::default()) + }; + quote!(if let Ok(value) = #getter { + #extractor + } else { + #default_expr + }) + } else { + quote!({ + let value = #getter?; + #extractor + }) }; - fields.push(quote!(#ident: #extractor)); + fields.push(quote!(#ident: #extracted)); } quote!(::std::result::Result::Ok(#self_ty{#fields})) @@ -458,6 +477,7 @@ impl ContainerOptions { struct FieldPyO3Attributes { getter: Option, from_py_with: Option, + default: Option, } #[derive(Clone, Debug)] @@ -469,6 +489,7 @@ enum FieldGetter { enum FieldPyO3Attribute { Getter(FieldGetter), FromPyWith(FromPyWithAttribute), + Default(DefaultAttribute), } impl Parse for FieldPyO3Attribute { @@ -512,6 +533,8 @@ impl Parse for FieldPyO3Attribute { } } else if lookahead.peek(attributes::kw::from_py_with) { input.parse().map(FieldPyO3Attribute::FromPyWith) + } else if lookahead.peek(attributes::kw::default) { + input.parse().map(FieldPyO3Attribute::Default) } else { Err(lookahead.error()) } @@ -523,6 +546,7 @@ impl FieldPyO3Attributes { fn from_attrs(attrs: &[Attribute]) -> Result { let mut getter = None; let mut from_py_with = None; + let mut default = None; for attr in attrs { if let Some(pyo3_attrs) = get_pyo3_options(attr)? { @@ -542,6 +566,13 @@ impl FieldPyO3Attributes { ); from_py_with = Some(from_py_with_attr); } + FieldPyO3Attribute::Default(default_attr) => { + ensure_spanned!( + default.is_none(), + attr.span() => "`default` may only be provided once" + ); + default = Some(default_attr); + } } } } @@ -550,6 +581,7 @@ impl FieldPyO3Attributes { Ok(FieldPyO3Attributes { getter, from_py_with, + default, }) } } diff --git a/tests/test_frompyobject.rs b/tests/test_frompyobject.rs index 2192caf1f7c..75252032842 100644 --- a/tests/test_frompyobject.rs +++ b/tests/test_frompyobject.rs @@ -686,3 +686,35 @@ fn test_with_keyword_item() { assert_eq!(result, expected); }); } + +#[derive(Debug, FromPyObject, PartialEq, Eq)] +pub struct WithDefaultItem { + #[pyo3(item, default)] + value: Option, +} + +#[test] +fn test_with_default_item() { + Python::with_gil(|py| { + let dict = PyDict::new(py); + let result = dict.extract::().unwrap(); + let expected = WithDefaultItem { value: None }; + assert_eq!(result, expected); + }); +} + +#[derive(Debug, FromPyObject, PartialEq, Eq)] +pub struct WithExplicitDefaultItem { + #[pyo3(item, default = 1)] + value: usize, +} + +#[test] +fn test_with_explicit_default_item() { + Python::with_gil(|py| { + let dict = PyDict::new(py); + let result = dict.extract::().unwrap(); + let expected = WithExplicitDefaultItem { value: 1 }; + assert_eq!(result, expected); + }); +} diff --git a/tests/ui/invalid_frompy_derive.stderr b/tests/ui/invalid_frompy_derive.stderr index 8ed03caafb4..e48176b45c5 100644 --- a/tests/ui/invalid_frompy_derive.stderr +++ b/tests/ui/invalid_frompy_derive.stderr @@ -84,7 +84,7 @@ error: transparent structs and variants can only have 1 field 70 | | }, | |_____^ -error: expected one of: `attribute`, `item`, `from_py_with` +error: expected one of: `attribute`, `item`, `from_py_with`, `default` --> tests/ui/invalid_frompy_derive.rs:76:12 | 76 | #[pyo3(attr)]