Skip to content

Commit

Permalink
extendr-macro: Added propogating errors for
Browse files Browse the repository at this point in the history
enum-to-factor conversions
  • Loading branch information
CGMossa committed Feb 18, 2024
1 parent b400d39 commit 1701f2a
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 9 deletions.
18 changes: 18 additions & 0 deletions extendr-api/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ pub enum Error {
NamespaceNotFound(Robj),
NoGraphicsDevices(Robj),

ExpectedFactor(Robj),
ExpectedScalarFactor(Robj),
InvalidLevels(Robj, Robj),
ExpectedExternalPtrType(Robj, String),
Other(String),

Expand Down Expand Up @@ -158,6 +161,21 @@ impl std::fmt::Display for Error {
Error::TypeMismatch(_robj) => write!(f, "Type mismatch"),

Error::NamespaceNotFound(robj) => write!(f, "Namespace {:?} not found", robj),
// factor conversion errors
Error::ExpectedFactor(robj) => write!(f, "Expected factor, got {:?}", robj),
Error::ExpectedScalarFactor(robj) => {
write!(
f,
"Expected scalar factor, got {:?}, of length {} instead of 1",
robj,
crate::Length::len(robj)
)
}
Error::InvalidLevels(source_levels, target_levels) => write!(
f,
"Expected levels: {:?}, received levels: {:?}",
target_levels, source_levels
),
Error::ExpectedExternalPtrType(_robj, type_name) => {
write!(f, "Incorrect external pointer type {}", type_name)
}
Expand Down
55 changes: 46 additions & 9 deletions extendr-macros/src/extendr_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,15 @@ pub(crate) fn extendr_enum(

let field_name = &ele.ident;
//FIXME: sanitize field names, as sometimes they have r# etc.
literal_field_names.push(syn::LitStr::new(field_name.to_string().as_str(), field_name.span()));
literal_field_names.push(syn::LitStr::new(
field_name.to_string().as_str(),
field_name.span(),
));
// field_names.push(format!("{enum_name}::{field_name}"));
field_names.push(field_name);

}
let literal_field_names = literal_field_names;
let field_names = field_names;

let enum_name_upper = enum_name.to_string().to_uppercase();
let enum_levels_name_strings = format_ident!("__{}_R_LEVELS", enum_name_upper);
Expand All @@ -66,12 +70,12 @@ pub(crate) fn extendr_enum(

//TODO: consider using a secret module to hide this even further?
quote!(

#item_enum

#[doc(hidden)]
const fn is_clone<T: Clone>(){}

#[doc(hidden)]
const _: () = is_clone::<#enum_name>();

Expand Down Expand Up @@ -109,18 +113,19 @@ pub(crate) fn extendr_enum(
type Error = extendr_api::Error;
fn try_from(value: #enum_name) -> Result<Self> {
let rint: Rint = value.into();
let robj: Robj = rint.try_into()?;
let mut robj: Robj = rint.try_into()?;
// TODO: consider using `single_threaded` here
unsafe {
#enum_levels_name_strings.with(|strings_enum|{
let strings_enum = once_cell::unsync::Lazy::force(strings_enum);
libR_sys::Rf_setAttrib(robj.get(), libR_sys::R_LevelsSymbol, strings_enum.get());
libR_sys::Rf_setAttrib(robj.get_mut(), libR_sys::R_LevelsSymbol, strings_enum.get());
});
extendr_api::R_FactorSymbol.with(|factor_class| {
let factor_class = once_cell::unsync::Lazy::force(factor_class);
// a symbol is permanent, so no need to protect it
// printname is CHARSXP, and we need a STRSXP, hence `Rf_ScalarString`
// doesn't need protection, because it gets inserted into a protected `SEXP` immediately
libR_sys::Rf_setAttrib(robj.get(), libR_sys::R_ClassSymbol, libR_sys::Rf_ScalarString(libR_sys::PRINTNAME(*factor_class)));
libR_sys::Rf_setAttrib(robj.get_mut(), libR_sys::R_ClassSymbol, libR_sys::Rf_ScalarString(libR_sys::PRINTNAME(*factor_class)));
});
}
Ok(robj)
Expand All @@ -130,7 +135,39 @@ pub(crate) fn extendr_enum(
impl TryFrom<Robj> for #enum_name {
type Error = extendr_api::Error;
fn try_from(robj: Robj) -> Result<Self> {
assert!(robj.is_factor());
if !robj.is_factor() {
return Err(Error::ExpectedFactor(robj));
}

let levels = robj.get_attrib(levels_symbol()).unwrap();
let levels: Strings = levels.try_into().unwrap();

// same levels as enum?
let levels_cmp_flag = #enum_levels_name_strings.with(|x|{
let target_levels = extendr_api::prelude::once_cell::unsync::Lazy::force(x);

//FIXME: propogate error instead of panic'ing.
if &levels == target_levels {
None
} else {
Some(Error::InvalidLevels(levels.into(), target_levels.into()))
}
});
if let Some(levels_err) = levels_cmp_flag {
return Err(levels_err);
}

use extendr_api::AsTypedSlice;
let int_vector: &[Rint] = robj.as_typed_slice().unwrap();
if int_vector.len() != 1 {
return Err(Error::ExpectedScalarFactor(robj))
}

let result: #enum_name = int_vector[0].into();

Ok(result)
}
}

let levels = robj.get_attrib(levels_symbol()).unwrap();
let levels: Strings = levels.try_into().unwrap();
Expand Down

0 comments on commit 1701f2a

Please sign in to comment.