Skip to content

Commit

Permalink
Allow floats to decode from integer terms (#641)
Browse files Browse the repository at this point in the history
  • Loading branch information
benhaney authored Jul 25, 2024
1 parent f758190 commit 49d4773
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 31 deletions.
93 changes: 62 additions & 31 deletions rustler/src/types/primitive.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,43 @@
use crate::types::atom;
use crate::{Decoder, Encoder, Env, Error, NifResult, Term};

macro_rules! impl_number_transcoder {
($dec_type:ty, $nif_type:ty, $encode_fun:ident, $decode_fun:ident) => {
macro_rules! erl_make {
($self:expr, $env:ident, $encode_fun:ident, $type:ty) => {
#[allow(clippy::cast_lossless)]
unsafe {
Term::new(
$env,
rustler_sys::$encode_fun($env.as_c_arg(), $self as $type),
)
}
};
}

macro_rules! erl_get {
($decode_fun:ident, $term:ident, $dest:ident) => {
unsafe {
rustler_sys::$decode_fun($term.get_env().as_c_arg(), $term.as_c_arg(), &mut $dest)
}
};
}

macro_rules! impl_number_encoder {
($dec_type:ty, $nif_type:ty, $encode_fun:ident) => {
impl Encoder for $dec_type {
fn encode<'a>(&self, env: Env<'a>) -> Term<'a> {
#[allow(clippy::cast_lossless)]
unsafe {
Term::new(
env,
rustler_sys::$encode_fun(env.as_c_arg(), *self as $nif_type),
)
}
erl_make!(*self, env, $encode_fun, $nif_type)
}
}
};
}

macro_rules! impl_number_decoder {
($dec_type:ty, $nif_type:ty, $decode_fun:ident) => {
impl<'a> Decoder<'a> for $dec_type {
fn decode(term: Term) -> NifResult<$dec_type> {
#![allow(unused_unsafe)]
let mut res: $nif_type = Default::default();
if unsafe {
rustler_sys::$decode_fun(term.get_env().as_c_arg(), term.as_c_arg(), &mut res)
} == 0
{
if erl_get!($decode_fun, term, res) == 0 {
return Err(Error::BadArg);
}
Ok(res as $dec_type)
Expand All @@ -30,12 +46,19 @@ macro_rules! impl_number_transcoder {
};
}

macro_rules! impl_number_transcoder {
($dec_type:ty, $nif_type:ty, $encode_fun:ident, $decode_fun:ident) => {
impl_number_encoder!($dec_type, $nif_type, $encode_fun);
impl_number_decoder!($dec_type, $nif_type, $decode_fun);
};
}

// Base number types
impl_number_transcoder!(i32, i32, enif_make_int, enif_get_int);
impl_number_transcoder!(u32, u32, enif_make_uint, enif_get_uint);
impl_number_transcoder!(i64, i64, enif_make_int64, enif_get_int64);
impl_number_transcoder!(u64, u64, enif_make_uint64, enif_get_uint64);
impl_number_transcoder!(f64, f64, enif_make_double, enif_get_double);
impl_number_encoder!(f64, f64, enif_make_double);

// Casted number types
impl_number_transcoder!(i8, i32, enif_make_int, enif_get_int);
Expand All @@ -44,25 +67,18 @@ impl_number_transcoder!(i16, i32, enif_make_int, enif_get_int);
impl_number_transcoder!(u16, u32, enif_make_uint, enif_get_uint);
impl_number_transcoder!(usize, u64, enif_make_uint64, enif_get_uint64);
impl_number_transcoder!(isize, i64, enif_make_int64, enif_get_int64);
impl_number_encoder!(f32, f64, enif_make_double);

impl Encoder for bool {
fn encode<'a>(&self, env: Env<'a>) -> Term<'a> {
if *self {
atom::true_().to_term(env)
} else {
atom::false_().to_term(env)
// Manual Decoder impls for floats so they can fall back to decoding from integer terms
impl<'a> Decoder<'a> for f64 {
fn decode(term: Term) -> NifResult<f64> {
#![allow(unused_unsafe)]
let mut res: f64 = Default::default();
if erl_get!(enif_get_double, term, res) == 0 {
let res_fallback: i64 = term.decode()?;
return Ok(res_fallback as f64);
}
}
}
impl<'a> Decoder<'a> for bool {
fn decode(term: Term<'a>) -> NifResult<bool> {
atom::decode_bool(term)
}
}

impl Encoder for f32 {
fn encode<'a>(&self, env: Env<'a>) -> Term<'a> {
f64::from(*self).encode(env)
Ok(res)
}
}

Expand All @@ -78,3 +94,18 @@ impl<'a> Decoder<'a> for f32 {
}
}
}

impl Encoder for bool {
fn encode<'a>(&self, env: Env<'a>) -> Term<'a> {
if *self {
atom::true_().to_term(env)
} else {
atom::false_().to_term(env)
}
}
}
impl<'a> Decoder<'a> for bool {
fn decode(term: Term<'a>) -> NifResult<bool> {
atom::decode_bool(term)
}
}
1 change: 1 addition & 0 deletions rustler_tests/lib/rustler_test.ex
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ defmodule RustlerTest do

def add_u32(_, _), do: err()
def add_i32(_, _), do: err()
def add_floats(_, _), do: err()
def echo_u8(_), do: err()
def echo_u128(_), do: err()
def echo_i128(_), do: err()
Expand Down
5 changes: 5 additions & 0 deletions rustler_tests/native/rustler_test/src/test_primitives.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ pub fn add_i32(a: i32, b: i32) -> i32 {
a + b
}

#[rustler::nif]
pub fn add_floats(a: f32, b: f64) -> f64 {
(a as f64) + b
}

#[rustler::nif]
pub fn echo_u8(n: u8) -> u8 {
n
Expand Down
4 changes: 4 additions & 0 deletions rustler_tests/test/primitives_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ defmodule RustlerTest.PrimitivesTest do
assert 3 == RustlerTest.add_i32(6, -3)
assert -3 == RustlerTest.add_i32(3, -6)
assert 3 == RustlerTest.echo_u8(3)
assert 2.0 == RustlerTest.add_floats(3.0, -1.0)
assert 2.0 == RustlerTest.add_floats(3, -1)
assert 2.0 == RustlerTest.add_floats(3.0, -1)
assert 2.0 == RustlerTest.add_floats(3, -1.0)
end

test "number decoding should fail on invalid terms" do
Expand Down

0 comments on commit 49d4773

Please sign in to comment.