From 1701f2aa9424790d9228391cb9178fab4821e118 Mon Sep 17 00:00:00 2001
From: Mossa <cgmossa@gmail.com>
Date: Sun, 18 Feb 2024 14:17:21 +0100
Subject: [PATCH] extendr-macro: Added propogating errors for enum-to-factor
 conversions

---
 extendr-api/src/error.rs           | 18 ++++++++++
 extendr-macros/src/extendr_enum.rs | 55 +++++++++++++++++++++++++-----
 2 files changed, 64 insertions(+), 9 deletions(-)

diff --git a/extendr-api/src/error.rs b/extendr-api/src/error.rs
index f82d649bed..054917c704 100644
--- a/extendr-api/src/error.rs
+++ b/extendr-api/src/error.rs
@@ -77,6 +77,9 @@ pub enum Error {
     NamespaceNotFound(Robj),
     NoGraphicsDevices(Robj),
 
+    ExpectedFactor(Robj),
+    ExpectedScalarFactor(Robj),
+    InvalidLevels(Robj, Robj),
     ExpectedExternalPtrType(Robj, String),
     Other(String),
 
@@ -158,6 +161,21 @@ impl std::fmt::Display for Error {
             Error::TypeMismatch(_robj) => write!(f, "Type mismatch"),
 
             Error::NamespaceNotFound(robj) => write!(f, "Namespace {:?} not found", robj),
+            // factor conversion errors
+            Error::ExpectedFactor(robj) => write!(f, "Expected factor, got {:?}", robj),
+            Error::ExpectedScalarFactor(robj) => {
+                write!(
+                    f,
+                    "Expected scalar factor, got {:?}, of length {} instead of 1",
+                    robj,
+                    crate::Length::len(robj)
+                )
+            }
+            Error::InvalidLevels(source_levels, target_levels) => write!(
+                f,
+                "Expected levels: {:?}, received levels: {:?}",
+                target_levels, source_levels
+            ),
             Error::ExpectedExternalPtrType(_robj, type_name) => {
                 write!(f, "Incorrect external pointer type {}", type_name)
             }
diff --git a/extendr-macros/src/extendr_enum.rs b/extendr-macros/src/extendr_enum.rs
index 623cff2d11..6bff9f2bd5 100644
--- a/extendr-macros/src/extendr_enum.rs
+++ b/extendr-macros/src/extendr_enum.rs
@@ -46,11 +46,15 @@ pub(crate) fn extendr_enum(
 
         let field_name = &ele.ident;
         //FIXME: sanitize field names, as sometimes they have r# etc.
-        literal_field_names.push(syn::LitStr::new(field_name.to_string().as_str(), field_name.span()));
+        literal_field_names.push(syn::LitStr::new(
+            field_name.to_string().as_str(),
+            field_name.span(),
+        ));
         // field_names.push(format!("{enum_name}::{field_name}"));
         field_names.push(field_name);
-        
     }
+    let literal_field_names = literal_field_names;
+    let field_names = field_names;
 
     let enum_name_upper = enum_name.to_string().to_uppercase();
     let enum_levels_name_strings = format_ident!("__{}_R_LEVELS", enum_name_upper);
@@ -66,12 +70,12 @@ pub(crate) fn extendr_enum(
     
     //TODO: consider using a secret module to hide this even further?
     quote!(
-        
+
         #item_enum
-        
+
         #[doc(hidden)]
         const fn is_clone<T: Clone>(){}
-        
+
         #[doc(hidden)]
         const _: () = is_clone::<#enum_name>();
 
@@ -109,18 +113,19 @@ pub(crate) fn extendr_enum(
             type Error = extendr_api::Error;
             fn try_from(value: #enum_name) -> Result<Self> {
                 let rint: Rint = value.into();
-                let robj: Robj = rint.try_into()?;
+                let mut robj: Robj = rint.try_into()?;
+                // TODO: consider using `single_threaded` here
                 unsafe {
                     #enum_levels_name_strings.with(|strings_enum|{
                         let strings_enum = once_cell::unsync::Lazy::force(strings_enum);
-                        libR_sys::Rf_setAttrib(robj.get(), libR_sys::R_LevelsSymbol, strings_enum.get());
+                        libR_sys::Rf_setAttrib(robj.get_mut(), libR_sys::R_LevelsSymbol, strings_enum.get());
                     });
                     extendr_api::R_FactorSymbol.with(|factor_class| {
                         let factor_class = once_cell::unsync::Lazy::force(factor_class);
                         // a symbol is permanent, so no need to protect it
                         // printname is CHARSXP, and we need a STRSXP, hence `Rf_ScalarString`
                         // doesn't need protection, because it gets inserted into a protected `SEXP` immediately
-                        libR_sys::Rf_setAttrib(robj.get(), libR_sys::R_ClassSymbol, libR_sys::Rf_ScalarString(libR_sys::PRINTNAME(*factor_class)));
+                        libR_sys::Rf_setAttrib(robj.get_mut(), libR_sys::R_ClassSymbol, libR_sys::Rf_ScalarString(libR_sys::PRINTNAME(*factor_class)));
                     });
                 }
                 Ok(robj)
@@ -130,7 +135,39 @@ pub(crate) fn extendr_enum(
         impl TryFrom<Robj> for #enum_name {
             type Error = extendr_api::Error;
             fn try_from(robj: Robj) -> Result<Self> {
-                assert!(robj.is_factor());
+                if !robj.is_factor() {
+                    return Err(Error::ExpectedFactor(robj));
+                }
+
+                let levels = robj.get_attrib(levels_symbol()).unwrap();
+                let levels: Strings = levels.try_into().unwrap();
+
+                // same levels as enum?
+                let levels_cmp_flag = #enum_levels_name_strings.with(|x|{
+                    let target_levels = extendr_api::prelude::once_cell::unsync::Lazy::force(x);
+                    
+                    //FIXME: propogate error instead of panic'ing.
+                    if &levels == target_levels {
+                        None
+                    } else {
+                        Some(Error::InvalidLevels(levels.into(), target_levels.into()))
+                    }
+                });
+                if let Some(levels_err) = levels_cmp_flag {
+                    return Err(levels_err);
+                }
+
+                use extendr_api::AsTypedSlice;
+                let int_vector: &[Rint] = robj.as_typed_slice().unwrap();
+                if int_vector.len() != 1 {
+                    return Err(Error::ExpectedScalarFactor(robj))
+                }
+
+                let result: #enum_name = int_vector[0].into();
+
+                Ok(result)
+            }
+        }
 
                 let levels = robj.get_attrib(levels_symbol()).unwrap();
                 let levels: Strings = levels.try_into().unwrap();