Skip to content

Commit

Permalink
extendr-macros: Accept argument alias with mut in front (extendr#752
Browse files Browse the repository at this point in the history
)

* extendr-macros: Use `try_from` instead of `FromRobj`.

* extendr-macros: fix the lack of `mut`-alias in `#[extendr]` macro

* use `if let Some` instead of `match

* refactor: moved the code to clear `mut` from `Ident` to its own function
  • Loading branch information
CGMossa authored May 9, 2024
1 parent c215039 commit 6923b8e
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 41 deletions.
1 change: 0 additions & 1 deletion extendr-api/src/wrapper/externalptr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
//! a class in R, you must decorate it with a class-attribute manually.
//!
use super::*;
use std::any::Any;
use std::fmt::Debug;

/// Wrapper for creating R objects containing any Rust object.
Expand Down
6 changes: 3 additions & 3 deletions extendr-macros/src/extendr_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub fn extendr_module(item: TokenStream) -> TokenStream {
implnames,
usenames,
} = module;
let modname = modname.unwrap();
let modname = modname.expect("cannot include unnamed modules");
let modname_string = modname.to_string();
let module_init_name = format_ident!("R_init_{}_extendr", modname);

Expand Down Expand Up @@ -102,10 +102,10 @@ pub fn extendr_module(item: TokenStream) -> TokenStream {
use extendr_api::robj::*;
use extendr_api::GetSexp;
let robj = Robj::from_sexp(use_symbols_sexp);
let use_symbols: bool = <bool>::from_robj(&robj).unwrap();
let use_symbols: bool = <bool>::try_from(&robj).unwrap();

let robj = Robj::from_sexp(package_name_sexp);
let package_name: &str = <&str>::from_robj(&robj).unwrap();
let package_name: &str = <&str>::try_from(&robj).unwrap();

extendr_api::Robj::from(
#module_metadata_name()
Expand Down
31 changes: 27 additions & 4 deletions extendr-macros/src/wrappers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,10 @@ pub fn translate_formal(input: &FnArg, self_ty: Option<&syn::Type>) -> syn::Resu
match input {
// function argument.
FnArg::Typed(ref pattype) => {
let pat = &pattype.pat.as_ref();
Ok(parse_quote! { #pat : extendr_api::SEXP })
let pat = pattype.pat.as_ref();
// ensure that `mut` in args are ignored in the wrapper
let pat_ident = translate_only_alias(pat)?;
Ok(parse_quote! { #pat_ident: extendr_api::SEXP })
}
// &self / &mut self
FnArg::Receiver(ref receiver) => {
Expand All @@ -374,14 +376,33 @@ pub fn translate_formal(input: &FnArg, self_ty: Option<&syn::Type>) -> syn::Resu
}
}

/// Returns only the alias from a function argument.
///
/// For example `mut x: Vec<i32>`, the alias is `x`, but the `mut` would still
/// be present if only the `Ident` of `PatType` was used.
fn translate_only_alias(pat: &syn::Pat) -> Result<&Ident, syn::Error> {
Ok(match pat {
syn::Pat::Ident(ref pat_ident) => &pat_ident.ident,
_ => {
return Err(syn::Error::new_spanned(
pat,
"failed to translate name of argument",
));
}
})
}

// Generate code to make a metadata::Arg.
fn translate_meta_arg(input: &mut FnArg, self_ty: Option<&syn::Type>) -> syn::Result<Expr> {
match input {
// function argument.
FnArg::Typed(ref mut pattype) => {
let pat = pattype.pat.as_ref();
let ty = pattype.ty.as_ref();
let name_string = quote! { #pat }.to_string();
// here the argument name is extracted, without the `mut` keyword,
// ensuring the generated r-wrappers, can use these argument names
let pat_ident = translate_only_alias(pat)?;
let name_string = quote! { #pat_ident }.to_string();
let type_string = type_name(ty);
let default = if let Some(default) = get_named_lit(&mut pattype.attrs, "default") {
quote!(Some(#default))
Expand Down Expand Up @@ -433,7 +454,9 @@ fn translate_to_robj(input: &FnArg) -> syn::Result<syn::Stmt> {
let pat = &pattype.pat.as_ref();
if let syn::Pat::Ident(ref ident) = pat {
let varname = format_ident!("_{}_robj", ident.ident);
Ok(parse_quote! { let #varname = extendr_api::robj::Robj::from_sexp(#pat); })
let ident = &ident.ident;
// TODO: these do not need protection, as they come from R
Ok(parse_quote! { let #varname = extendr_api::robj::Robj::from_sexp(#ident); })
} else {
Err(syn::Error::new_spanned(
input,
Expand Down
2 changes: 1 addition & 1 deletion tests/extendrtests/src/rust/src/altrep.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ fn new_usize(robj: Integers) -> Altrep {

#[cfg(not(use_r_altlist))]
#[extendr]
fn new_usize(robj: Integers) -> Robj {
fn new_usize(_robj: Integers) -> Robj {
extendr_api::nil_value()
}

Expand Down
9 changes: 3 additions & 6 deletions tests/extendrtests/src/rust/src/attributes.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use extendr_api::prelude::*;

#[extendr]
fn dbls_named(x: Doubles) -> Doubles {
let mut x = x;
fn dbls_named(mut x: Doubles) -> Doubles {
x.set_attrib(
"names",
x.iter()
Expand All @@ -15,8 +14,7 @@ fn dbls_named(x: Doubles) -> Doubles {
}

#[extendr]
fn strings_named(x: Strings) -> Strings {
let mut x = x;
fn strings_named(mut x: Strings) -> Strings {
x.set_attrib(
"names",
x.iter()
Expand All @@ -28,8 +26,7 @@ fn strings_named(x: Strings) -> Strings {
}

#[extendr]
fn list_named(x: List, nms: Strings) -> List {
let mut x = x;
fn list_named(mut x: List, nms: Strings) -> List {
let _ = x.set_attrib("names", nms);
x
}
Expand Down
49 changes: 23 additions & 26 deletions tests/extendrtests/tests/testthat/_snaps/macro-snapshot.md
Original file line number Diff line number Diff line change
Expand Up @@ -509,9 +509,9 @@
use extendr_api::robj::*;
use extendr_api::GetSexp;
let robj = Robj::from_sexp(use_symbols_sexp);
let use_symbols: bool = <bool>::from_robj(&robj).unwrap();
let use_symbols: bool = <bool>::try_from(&robj).unwrap();
let robj = Robj::from_sexp(package_name_sexp);
let package_name: &str = <&str>::from_robj(&robj).unwrap();
let package_name: &str = <&str>::try_from(&robj).unwrap();
extendr_api::Robj::from(
get_altrep_metadata()
.make_r_wrappers(use_symbols, package_name)
Expand All @@ -528,8 +528,7 @@
}
mod attributes {
use extendr_api::prelude::*;
fn dbls_named(x: Doubles) -> Doubles {
let mut x = x;
fn dbls_named(mut x: Doubles) -> Doubles {
x.set_attrib(
"names",
x.iter().map(|xi| xi.inner().to_string()).collect::<Vec<_>>(),
Expand Down Expand Up @@ -623,8 +622,7 @@
hidden: false,
})
}
fn strings_named(x: Strings) -> Strings {
let mut x = x;
fn strings_named(mut x: Strings) -> Strings {
x.set_attrib(
"names",
x.iter().map(|xi| xi.as_str().to_string()).collect::<Vec<_>>(),
Expand Down Expand Up @@ -718,8 +716,7 @@
hidden: false,
})
}
fn list_named(x: List, nms: Strings) -> List {
let mut x = x;
fn list_named(mut x: List, nms: Strings) -> List {
let _ = x.set_attrib("names", nms);
x
}
Expand Down Expand Up @@ -887,9 +884,9 @@
use extendr_api::robj::*;
use extendr_api::GetSexp;
let robj = Robj::from_sexp(use_symbols_sexp);
let use_symbols: bool = <bool>::from_robj(&robj).unwrap();
let use_symbols: bool = <bool>::try_from(&robj).unwrap();
let robj = Robj::from_sexp(package_name_sexp);
let package_name: &str = <&str>::from_robj(&robj).unwrap();
let package_name: &str = <&str>::try_from(&robj).unwrap();
extendr_api::Robj::from(
get_attributes_metadata()
.make_r_wrappers(use_symbols, package_name)
Expand Down Expand Up @@ -1215,9 +1212,9 @@
use extendr_api::robj::*;
use extendr_api::GetSexp;
let robj = Robj::from_sexp(use_symbols_sexp);
let use_symbols: bool = <bool>::from_robj(&robj).unwrap();
let use_symbols: bool = <bool>::try_from(&robj).unwrap();
let robj = Robj::from_sexp(package_name_sexp);
let package_name: &str = <&str>::from_robj(&robj).unwrap();
let package_name: &str = <&str>::try_from(&robj).unwrap();
extendr_api::Robj::from(
get_dataframe_metadata()
.make_r_wrappers(use_symbols, package_name)
Expand Down Expand Up @@ -2107,9 +2104,9 @@
use extendr_api::robj::*;
use extendr_api::GetSexp;
let robj = Robj::from_sexp(use_symbols_sexp);
let use_symbols: bool = <bool>::from_robj(&robj).unwrap();
let use_symbols: bool = <bool>::try_from(&robj).unwrap();
let robj = Robj::from_sexp(package_name_sexp);
let package_name: &str = <&str>::from_robj(&robj).unwrap();
let package_name: &str = <&str>::try_from(&robj).unwrap();
extendr_api::Robj::from(
get_memory_leaks_metadata()
.make_r_wrappers(use_symbols, package_name)
Expand Down Expand Up @@ -2280,9 +2277,9 @@
use extendr_api::robj::*;
use extendr_api::GetSexp;
let robj = Robj::from_sexp(use_symbols_sexp);
let use_symbols: bool = <bool>::from_robj(&robj).unwrap();
let use_symbols: bool = <bool>::try_from(&robj).unwrap();
let robj = Robj::from_sexp(package_name_sexp);
let package_name: &str = <&str>::from_robj(&robj).unwrap();
let package_name: &str = <&str>::try_from(&robj).unwrap();
extendr_api::Robj::from(
get_optional_either_metadata()
.make_r_wrappers(use_symbols, package_name)
Expand Down Expand Up @@ -2995,9 +2992,9 @@
use extendr_api::robj::*;
use extendr_api::GetSexp;
let robj = Robj::from_sexp(use_symbols_sexp);
let use_symbols: bool = <bool>::from_robj(&robj).unwrap();
let use_symbols: bool = <bool>::try_from(&robj).unwrap();
let robj = Robj::from_sexp(package_name_sexp);
let package_name: &str = <&str>::from_robj(&robj).unwrap();
let package_name: &str = <&str>::try_from(&robj).unwrap();
extendr_api::Robj::from(
get_optional_faer_metadata()
.make_r_wrappers(use_symbols, package_name)
Expand Down Expand Up @@ -3261,9 +3258,9 @@
use extendr_api::robj::*;
use extendr_api::GetSexp;
let robj = Robj::from_sexp(use_symbols_sexp);
let use_symbols: bool = <bool>::from_robj(&robj).unwrap();
let use_symbols: bool = <bool>::try_from(&robj).unwrap();
let robj = Robj::from_sexp(package_name_sexp);
let package_name: &str = <&str>::from_robj(&robj).unwrap();
let package_name: &str = <&str>::try_from(&robj).unwrap();
extendr_api::Robj::from(
get_optional_ndarray_metadata()
.make_r_wrappers(use_symbols, package_name)
Expand Down Expand Up @@ -3609,9 +3606,9 @@
use extendr_api::robj::*;
use extendr_api::GetSexp;
let robj = Robj::from_sexp(use_symbols_sexp);
let use_symbols: bool = <bool>::from_robj(&robj).unwrap();
let use_symbols: bool = <bool>::try_from(&robj).unwrap();
let robj = Robj::from_sexp(package_name_sexp);
let package_name: &str = <&str>::from_robj(&robj).unwrap();
let package_name: &str = <&str>::try_from(&robj).unwrap();
extendr_api::Robj::from(
get_raw_identifiers_metadata()
.make_r_wrappers(use_symbols, package_name)
Expand Down Expand Up @@ -5496,9 +5493,9 @@
use extendr_api::robj::*;
use extendr_api::GetSexp;
let robj = Robj::from_sexp(use_symbols_sexp);
let use_symbols: bool = <bool>::from_robj(&robj).unwrap();
let use_symbols: bool = <bool>::try_from(&robj).unwrap();
let robj = Robj::from_sexp(package_name_sexp);
let package_name: &str = <&str>::from_robj(&robj).unwrap();
let package_name: &str = <&str>::try_from(&robj).unwrap();
extendr_api::Robj::from(
get_submodule_metadata()
.make_r_wrappers(use_symbols, package_name)
Expand Down Expand Up @@ -10237,9 +10234,9 @@
use extendr_api::robj::*;
use extendr_api::GetSexp;
let robj = Robj::from_sexp(use_symbols_sexp);
let use_symbols: bool = <bool>::from_robj(&robj).unwrap();
let use_symbols: bool = <bool>::try_from(&robj).unwrap();
let robj = Robj::from_sexp(package_name_sexp);
let package_name: &str = <&str>::from_robj(&robj).unwrap();
let package_name: &str = <&str>::try_from(&robj).unwrap();
extendr_api::Robj::from(
get_extendrtests_metadata()
.make_r_wrappers(use_symbols, package_name)
Expand Down

0 comments on commit 6923b8e

Please sign in to comment.