diff --git a/extendr-api/src/error.rs b/extendr-api/src/error.rs index f82d649bed..054917c704 100644 --- a/extendr-api/src/error.rs +++ b/extendr-api/src/error.rs @@ -77,6 +77,9 @@ pub enum Error { NamespaceNotFound(Robj), NoGraphicsDevices(Robj), + ExpectedFactor(Robj), + ExpectedScalarFactor(Robj), + InvalidLevels(Robj, Robj), ExpectedExternalPtrType(Robj, String), Other(String), @@ -158,6 +161,21 @@ impl std::fmt::Display for Error { Error::TypeMismatch(_robj) => write!(f, "Type mismatch"), Error::NamespaceNotFound(robj) => write!(f, "Namespace {:?} not found", robj), + // factor conversion errors + Error::ExpectedFactor(robj) => write!(f, "Expected factor, got {:?}", robj), + Error::ExpectedScalarFactor(robj) => { + write!( + f, + "Expected scalar factor, got {:?}, of length {} instead of 1", + robj, + crate::Length::len(robj) + ) + } + Error::InvalidLevels(source_levels, target_levels) => write!( + f, + "Expected levels: {:?}, received levels: {:?}", + target_levels, source_levels + ), Error::ExpectedExternalPtrType(_robj, type_name) => { write!(f, "Incorrect external pointer type {}", type_name) } diff --git a/extendr-macros/src/extendr_enum.rs b/extendr-macros/src/extendr_enum.rs index 623cff2d11..6bff9f2bd5 100644 --- a/extendr-macros/src/extendr_enum.rs +++ b/extendr-macros/src/extendr_enum.rs @@ -46,11 +46,15 @@ pub(crate) fn extendr_enum( let field_name = &ele.ident; //FIXME: sanitize field names, as sometimes they have r# etc. - literal_field_names.push(syn::LitStr::new(field_name.to_string().as_str(), field_name.span())); + literal_field_names.push(syn::LitStr::new( + field_name.to_string().as_str(), + field_name.span(), + )); // field_names.push(format!("{enum_name}::{field_name}")); field_names.push(field_name); - } + let literal_field_names = literal_field_names; + let field_names = field_names; let enum_name_upper = enum_name.to_string().to_uppercase(); let enum_levels_name_strings = format_ident!("__{}_R_LEVELS", enum_name_upper); @@ -66,12 +70,12 @@ pub(crate) fn extendr_enum( //TODO: consider using a secret module to hide this even further? quote!( - + #item_enum - + #[doc(hidden)] const fn is_clone(){} - + #[doc(hidden)] const _: () = is_clone::<#enum_name>(); @@ -109,18 +113,19 @@ pub(crate) fn extendr_enum( type Error = extendr_api::Error; fn try_from(value: #enum_name) -> Result { let rint: Rint = value.into(); - let robj: Robj = rint.try_into()?; + let mut robj: Robj = rint.try_into()?; + // TODO: consider using `single_threaded` here unsafe { #enum_levels_name_strings.with(|strings_enum|{ let strings_enum = once_cell::unsync::Lazy::force(strings_enum); - libR_sys::Rf_setAttrib(robj.get(), libR_sys::R_LevelsSymbol, strings_enum.get()); + libR_sys::Rf_setAttrib(robj.get_mut(), libR_sys::R_LevelsSymbol, strings_enum.get()); }); extendr_api::R_FactorSymbol.with(|factor_class| { let factor_class = once_cell::unsync::Lazy::force(factor_class); // a symbol is permanent, so no need to protect it // printname is CHARSXP, and we need a STRSXP, hence `Rf_ScalarString` // doesn't need protection, because it gets inserted into a protected `SEXP` immediately - libR_sys::Rf_setAttrib(robj.get(), libR_sys::R_ClassSymbol, libR_sys::Rf_ScalarString(libR_sys::PRINTNAME(*factor_class))); + libR_sys::Rf_setAttrib(robj.get_mut(), libR_sys::R_ClassSymbol, libR_sys::Rf_ScalarString(libR_sys::PRINTNAME(*factor_class))); }); } Ok(robj) @@ -130,7 +135,39 @@ pub(crate) fn extendr_enum( impl TryFrom for #enum_name { type Error = extendr_api::Error; fn try_from(robj: Robj) -> Result { - assert!(robj.is_factor()); + if !robj.is_factor() { + return Err(Error::ExpectedFactor(robj)); + } + + let levels = robj.get_attrib(levels_symbol()).unwrap(); + let levels: Strings = levels.try_into().unwrap(); + + // same levels as enum? + let levels_cmp_flag = #enum_levels_name_strings.with(|x|{ + let target_levels = extendr_api::prelude::once_cell::unsync::Lazy::force(x); + + //FIXME: propogate error instead of panic'ing. + if &levels == target_levels { + None + } else { + Some(Error::InvalidLevels(levels.into(), target_levels.into())) + } + }); + if let Some(levels_err) = levels_cmp_flag { + return Err(levels_err); + } + + use extendr_api::AsTypedSlice; + let int_vector: &[Rint] = robj.as_typed_slice().unwrap(); + if int_vector.len() != 1 { + return Err(Error::ExpectedScalarFactor(robj)) + } + + let result: #enum_name = int_vector[0].into(); + + Ok(result) + } + } let levels = robj.get_attrib(levels_symbol()).unwrap(); let levels: Strings = levels.try_into().unwrap();