Skip to content

Commit

Permalink
derive(FromPyObject): adds default option
Browse files Browse the repository at this point in the history
Takes an optional stringified expression to set a custom value that is not the one from the Default trait
  • Loading branch information
Tpt committed Dec 30, 2024
1 parent 3965f5f commit a3d6c7d
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 12 deletions.
3 changes: 3 additions & 0 deletions pyo3-macros-backend/src/attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ pub mod kw {
syn::custom_keyword!(unsendable);
syn::custom_keyword!(weakref);
syn::custom_keyword!(gil_used);
syn::custom_keyword!(default);
}

fn take_int(read: &mut &str, tracker: &mut usize) -> String {
Expand Down Expand Up @@ -351,6 +352,8 @@ impl<K: ToTokens, V: ToTokens> ToTokens for OptionalKeywordAttribute<K, V> {

pub type FromPyWithAttribute = KeywordAttribute<kw::from_py_with, LitStrValue<ExprPath>>;

pub type DefaultAttribute = OptionalKeywordAttribute<kw::default, LitStrValue<Expr>>;

/// For specifying the path to the pyo3 crate.
pub type CrateAttribute = KeywordAttribute<Token![crate], LitStrValue<Path>>;

Expand Down
56 changes: 44 additions & 12 deletions pyo3-macros-backend/src/frompyobject.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use crate::attributes::{self, get_pyo3_options, CrateAttribute, FromPyWithAttribute};
use crate::attributes::{
self, get_pyo3_options, CrateAttribute, DefaultAttribute, FromPyWithAttribute,
};
use crate::utils::Ctx;
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use quote::{format_ident, quote, ToTokens};
use syn::{
ext::IdentExt,
parenthesized,
Expand Down Expand Up @@ -90,6 +92,7 @@ struct NamedStructField<'a> {
ident: &'a syn::Ident,
getter: Option<FieldGetter>,
from_py_with: Option<FromPyWithAttribute>,
default: Option<DefaultAttribute>,
}

struct TupleStructField {
Expand Down Expand Up @@ -193,6 +196,7 @@ impl<'a> Container<'a> {
ident,
getter: attrs.getter,
from_py_with: attrs.from_py_with,
default: attrs.default,
})
})
.collect::<Result<Vec<_>>>()?;
Expand Down Expand Up @@ -346,18 +350,33 @@ impl<'a> Container<'a> {
quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #pyo3_path::intern!(obj.py(), #field_name)))
}
};
let extractor = match &field.from_py_with {
None => {
quote!(#pyo3_path::impl_::frompyobject::extract_struct_field(&#getter?, #struct_name, #field_name)?)
}
Some(FromPyWithAttribute {
value: expr_path, ..
}) => {
quote! (#pyo3_path::impl_::frompyobject::extract_struct_field_with(#expr_path as fn(_) -> _, &#getter?, #struct_name, #field_name)?)
}
let extractor = if let Some(FromPyWithAttribute {
value: expr_path, ..
}) = &field.from_py_with
{
quote!(#pyo3_path::impl_::frompyobject::extract_struct_field_with(#expr_path as fn(_) -> _, &value, #struct_name, #field_name)?)
} else {
quote!(#pyo3_path::impl_::frompyobject::extract_struct_field(&value, #struct_name, #field_name)?)
};
let extracted = if let Some(default) = &field.default {
let default_expr = if let Some(default_expr) = &default.value {
default_expr.0.to_token_stream()
} else {
quote!(Default::default())
};
quote!(if let Ok(value) = #getter {
#extractor
} else {
#default_expr
})
} else {
quote!({
let value = #getter?;
#extractor
})
};

fields.push(quote!(#ident: #extractor));
fields.push(quote!(#ident: #extracted));
}

quote!(::std::result::Result::Ok(#self_ty{#fields}))
Expand Down Expand Up @@ -458,6 +477,7 @@ impl ContainerOptions {
struct FieldPyO3Attributes {
getter: Option<FieldGetter>,
from_py_with: Option<FromPyWithAttribute>,
default: Option<DefaultAttribute>,
}

#[derive(Clone, Debug)]
Expand All @@ -469,6 +489,7 @@ enum FieldGetter {
enum FieldPyO3Attribute {
Getter(FieldGetter),
FromPyWith(FromPyWithAttribute),
Default(DefaultAttribute),
}

impl Parse for FieldPyO3Attribute {
Expand Down Expand Up @@ -512,6 +533,8 @@ impl Parse for FieldPyO3Attribute {
}
} else if lookahead.peek(attributes::kw::from_py_with) {
input.parse().map(FieldPyO3Attribute::FromPyWith)
} else if lookahead.peek(attributes::kw::default) {
input.parse().map(FieldPyO3Attribute::Default)
} else {
Err(lookahead.error())
}
Expand All @@ -523,6 +546,7 @@ impl FieldPyO3Attributes {
fn from_attrs(attrs: &[Attribute]) -> Result<Self> {
let mut getter = None;
let mut from_py_with = None;
let mut default = None;

for attr in attrs {
if let Some(pyo3_attrs) = get_pyo3_options(attr)? {
Expand All @@ -542,6 +566,13 @@ impl FieldPyO3Attributes {
);
from_py_with = Some(from_py_with_attr);
}
FieldPyO3Attribute::Default(default_attr) => {
ensure_spanned!(
default.is_none(),
attr.span() => "`default` may only be provided once"
);
default = Some(default_attr);
}
}
}
}
Expand All @@ -550,6 +581,7 @@ impl FieldPyO3Attributes {
Ok(FieldPyO3Attributes {
getter,
from_py_with,
default,
})
}
}
Expand Down
32 changes: 32 additions & 0 deletions tests/test_frompyobject.rs
Original file line number Diff line number Diff line change
Expand Up @@ -686,3 +686,35 @@ fn test_with_keyword_item() {
assert_eq!(result, expected);
});
}

#[derive(Debug, FromPyObject, PartialEq, Eq)]
pub struct WithDefaultItem {
#[pyo3(item, default)]
value: Option<usize>,
}

#[test]
fn test_with_default_item() {
Python::with_gil(|py| {
let dict = PyDict::new(py);
let result = dict.extract::<WithDefaultItem>().unwrap();
let expected = WithDefaultItem { value: None };
assert_eq!(result, expected);
});
}

#[derive(Debug, FromPyObject, PartialEq, Eq)]
pub struct WithExplicitDefaultItem {
#[pyo3(item, default = "1")]
value: usize,
}

#[test]
fn test_with_explicit_default_item() {
Python::with_gil(|py| {
let dict = PyDict::new(py);
let result = dict.extract::<WithExplicitDefaultItem>().unwrap();
let expected = WithExplicitDefaultItem { value: 1 };
assert_eq!(result, expected);
});
}

0 comments on commit a3d6c7d

Please sign in to comment.