diff --git a/rustler/src/serde/atoms.rs b/rustler/src/serde/atoms.rs index 27ac7346..9ed55f7c 100644 --- a/rustler/src/serde/atoms.rs +++ b/rustler/src/serde/atoms.rs @@ -3,9 +3,6 @@ use crate::serde::Error; use crate::{types::atom::Atom, Encoder, Env, Term}; -pub static OK: &str = "Ok"; -pub static ERROR: &str = "Err"; - atoms! { nil, ok, @@ -23,16 +20,10 @@ atoms! { * Attempts to create an atom term from the provided string (if the atom already exists in the atom table). If not, returns a string term. */ pub fn str_to_term<'a>(env: &Env<'a>, string: &str) -> Result, Error> { - if string == "Ok" { - Ok(ok().encode(*env)) - } else if string == "Err" { - Ok(error().encode(*env)) - } else { - match Atom::try_from_bytes(*env, string.as_bytes()) { - Ok(Some(term)) => Ok(term.encode(*env)), - Ok(None) => Err(Error::InvalidStringable), - _ => Err(Error::InvalidStringable), - } + match Atom::try_from_bytes(*env, string.as_bytes()) { + Ok(Some(term)) => Ok(term.encode(*env)), + Ok(None) => Ok(string.encode(*env)), + _ => Err(Error::InvalidStringable), } } @@ -40,11 +31,7 @@ pub fn str_to_term<'a>(env: &Env<'a>, string: &str) -> Result, Error> { * Attempts to create a `String` from the term. */ pub fn term_to_string(term: &Term) -> Result { - if ok().eq(term) { - Ok(OK.to_string()) - } else if error().eq(term) { - Ok(ERROR.to_string()) - } else if term.is_atom() { + if term.is_atom() { term.atom_to_string().or(Err(Error::InvalidAtom)) } else { Err(Error::InvalidStringable) diff --git a/rustler/src/serde/de.rs b/rustler/src/serde/de.rs index 04662eda..309a1999 100644 --- a/rustler/src/serde/de.rs +++ b/rustler/src/serde/de.rs @@ -728,8 +728,13 @@ impl<'de, 'a: 'de> de::Deserializer<'de> for VariantNameDeserializer<'a> { { match self.variant.get_type() { TermType::Atom => { - let string = - atoms::term_to_string(&self.variant).or(Err(Error::InvalidVariantName))?; + let string = atoms::term_to_string(&self.variant) + .map(|s| match s.as_str() { + "ok" => "Ok".to_string(), + "error" => "Err".to_string(), + _ => s, + }) + .or(Err(Error::InvalidVariantName))?; visitor.visit_string(string) } TermType::Binary => visitor.visit_string(util::term_to_str(&self.variant)?), diff --git a/rustler/src/serde/ser.rs b/rustler/src/serde/ser.rs index 5a2c71f7..6071c5a1 100644 --- a/rustler/src/serde/ser.rs +++ b/rustler/src/serde/ser.rs @@ -212,7 +212,7 @@ impl<'a> ser::Serializer for Serializer<'a> { /// `enum Result { Ok(u8), Err(_) }` into `{:ok, u8}` or `{:err, _}`. fn serialize_newtype_variant( self, - _name: &'static str, + name: &'static str, _variant_index: u32, variant: &'static str, value: &T, @@ -220,9 +220,9 @@ impl<'a> ser::Serializer for Serializer<'a> { where T: ?Sized + ser::Serialize, { - match variant { - "Ok" => self.serialize_newtype_struct("ok", value), - "Err" => self.serialize_newtype_struct("error", value), + match (name, variant) { + ("Result", "Ok") => self.serialize_newtype_struct("ok", value), + ("Result", "Err") => self.serialize_newtype_struct("error", value), _ => self.serialize_newtype_struct(variant, value), } } diff --git a/rustler_tests/test/serde_rustler_tests_test.exs b/rustler_tests/test/serde_rustler_tests_test.exs index 267ccc8c..0fd6bf09 100644 --- a/rustler_tests/test/serde_rustler_tests_test.exs +++ b/rustler_tests/test/serde_rustler_tests_test.exs @@ -253,14 +253,14 @@ defmodule SerdeRustlerTests.NifTest do test "newtype variant (Result::Ok(T), or {:ok, T})", ctx do test_case = {:ok, 255} - transcoded = ["Ok", 255] + transcoded = ["ok", 255] run_tests("newtype variant (ok tuple)", test_case, Helpers.skip(ctx, :transcode)) Helpers.run_transcode("newtype variant (ok tuple)", test_case, transcoded) end test "newtype variant (Result::Err(T), or {:error, T}", ctx do test_case = {:error, "error reason"} - transcoded = ["Err", "error reason"] + transcoded = ["error", "error reason"] run_tests("newtype variant (error tuple)", test_case, Helpers.skip(ctx, :transcode)) Helpers.run_transcode("newtype variant (error tuple)", test_case, transcoded) end