diff --git a/packages/data-structures/src/vec.cairo b/packages/data-structures/src/vec.cairo index 6e43b2d54..4dcbb99d9 100644 --- a/packages/data-structures/src/vec.cairo +++ b/packages/data-structures/src/vec.cairo @@ -58,8 +58,8 @@ impl VecIndex> of Index { } pub struct NullableVec { - items: Felt252Dict>, - len: usize, + pub items: Felt252Dict>, + pub len: usize, } impl DestructNullableVec> of Destruct> { diff --git a/packages/deep-learning/src/ops/binary.cairo b/packages/deep-learning/src/ops/binary.cairo index ef8ed3839..6fe1c612d 100644 --- a/packages/deep-learning/src/ops/binary.cairo +++ b/packages/deep-learning/src/ops/binary.cairo @@ -174,6 +174,7 @@ mod tests { }; } + #[test] fn test_tensor_rem() { // This would be precomputed diff --git a/packages/numbers/src/f64.cairo b/packages/numbers/src/f64.cairo index 205afe640..3dec4cf08 100644 --- a/packages/numbers/src/f64.cairo +++ b/packages/numbers/src/f64.cairo @@ -14,6 +14,7 @@ pub const ONE: i64 = 4294967296; // 2 ** 32 pub const HALF: i64 = 2147483648; // 2 ** 31 pub const MAX: i64 = 9223372036854775807; //2**63 - 1 const MIN: i64 = -9223372036854775808; // -2 ** 63 +pub const NaN: i64 = 0x4e614e; // STRUCTS diff --git a/packages/numbers/src/f64/comp.cairo b/packages/numbers/src/f64/comp.cairo index e8b2575eb..b993bc9a6 100644 --- a/packages/numbers/src/f64/comp.cairo +++ b/packages/numbers/src/f64/comp.cairo @@ -1,6 +1,10 @@ -use super::{F64, FixedTrait, F64Impl}; +use super::{F64, NaN, FixedTrait, F64Impl}; fn max(a: F64, b: F64) -> F64 { + if a.d == NaN || b.d == NaN { + return F64 { d: NaN }; + } + if (a >= b) { return a; } else { @@ -9,6 +13,10 @@ fn max(a: F64, b: F64) -> F64 { } fn min(a: F64, b: F64) -> F64 { + if a.d == NaN || b.d == NaN { + return F64 { d: NaN }; + } + if (a <= b) { return a; } else { diff --git a/packages/numbers/src/f64/erf.cairo b/packages/numbers/src/f64/erf.cairo index ee99f0f7c..b113a1048 100644 --- a/packages/numbers/src/f64/erf.cairo +++ b/packages/numbers/src/f64/erf.cairo @@ -1,4 +1,4 @@ -use super::{F64, F64Impl, ONE}; +use super::{F64, NaN, F64Impl, ONE}; use super::lut; const ERF_COMPUTATIONAL_ACCURACY: i64 = 100; @@ -10,6 +10,10 @@ const MAX_ERF_NUMBER: i64 = 15032385536; const ERF_TRUNCATION_NUMBER: i64 = 8589934592; pub(crate) fn erf(x: F64) -> F64 { + if x.d == NaN { + return F64 { d: NaN }; + } + // Lookup // 1. if x.mag < 3.5 { lookup table } // 2. else{ return 1} diff --git a/packages/numbers/src/f64/ops.cairo b/packages/numbers/src/f64/ops.cairo index 39380f5c7..1ec747ef0 100644 --- a/packages/numbers/src/f64/ops.cairo +++ b/packages/numbers/src/f64/ops.cairo @@ -1,10 +1,13 @@ use core::num::traits::{WideMul, Sqrt}; -use super::{F64, FixedTrait, F64Impl, ONE, HALF, lut}; - +use super::{F64, FixedTrait, F64Impl, ONE, HALF, NaN, lut}; // PUBLIC pub(crate) fn abs(a: F64) -> F64 { + if a.d == NaN { + return F64 { d: NaN }; + } + if a.d <= 0 { F64 { d: a.d * ONE } } else { @@ -13,10 +16,16 @@ pub(crate) fn abs(a: F64) -> F64 { } pub(crate) fn add(a: F64, b: F64) -> F64 { + if a.d == NaN || b.d == NaN { + return F64 { d: NaN }; + } F64 { d: a.d + b.d } } pub(crate) fn ceil(a: F64) -> F64 { + if a.d == NaN { + return F64 { d: NaN }; + } let div = Div::div(a.d, ONE); let rem = Rem::rem(a.d, ONE); @@ -28,6 +37,9 @@ pub(crate) fn ceil(a: F64) -> F64 { } pub(crate) fn div(a: F64, b: F64) -> F64 { + if a.d == NaN || b.d == NaN { + return F64 { d: NaN }; + } let a_i128 = WideMul::wide_mul(a.d, ONE); let res_i128 = a_i128 / b.d.into(); @@ -40,12 +52,18 @@ pub(crate) fn eq(a: @F64, b: @F64) -> bool { // // Calculates the natural exponent of x: e^x pub(crate) fn exp(a: F64) -> F64 { + if a.d == NaN { + return F64 { d: NaN }; + } return exp2(FixedTrait::new(6196328018) * a); } // Calculates the binary exponent of x: 2^x pub(crate) fn exp2(a: F64) -> F64 { + if a.d == NaN { + return F64 { d: NaN }; + } if (a.d == 0) { return FixedTrait::ONE(); } @@ -83,6 +101,10 @@ pub(crate) fn exp2_int(exp: i64) -> F64 { } pub(crate) fn floor(a: F64) -> F64 { + if a.d == NaN { + return F64 { d: NaN }; + } + let div = Div::div(a.d, ONE); let rem = Rem::rem(a.d, ONE); @@ -111,13 +133,23 @@ pub(crate) fn le(a: F64, b: F64) -> bool { // Calculates the natural logarithm of x: ln(x) // self must be greater than zero pub(crate) fn ln(a: F64) -> F64 { + if a.d == NaN { + return F64 { d: NaN }; + } + return FixedTrait::new(2977044472) * log2(a); // ln(2) = 0.693... } // Calculates the binary logarithm of x: log2(x) // self must be greather than zero pub(crate) fn log2(a: F64) -> F64 { - assert(a.d >= 0, 'must be positive'); + if a.d == NaN { + return F64 { d: NaN }; + } + + if a.d < 0 { + return F64 { d: NaN }; + } if (a.d == ONE) { return FixedTrait::ZERO(); @@ -149,6 +181,10 @@ pub(crate) fn log2(a: F64) -> F64 { // Calculates the base 10 log of x: log10(x) // self must be greater than zero pub(crate) fn log10(a: F64) -> F64 { + if a.d == NaN { + return F64 { d: NaN }; + } + return FixedTrait::new(1292913986) * log2(a); // log10(2) = 0.301... } @@ -157,6 +193,10 @@ pub(crate) fn lt(a: F64, b: F64) -> bool { } pub(crate) fn mul(a: F64, b: F64) -> F64 { + if a.d == NaN { + return F64 { d: NaN }; + } + let prod_i128 = WideMul::wide_mul(a.d, b.d); // Re-apply sign @@ -168,6 +208,10 @@ pub(crate) fn ne(a: @F64, b: @F64) -> bool { } pub(crate) fn neg(a: F64) -> F64 { + if a.d == NaN { + return F64 { d: NaN }; + } + a * F64 { d: -ONE } } @@ -175,6 +219,10 @@ pub(crate) fn neg(a: F64) -> F64 { // self is a Fixed point value // b is a Fixed point value pub(crate) fn pow(a: F64, b: F64) -> F64 { + if a.d == NaN { + return F64 { d: NaN }; + } + let rem = Rem::rem(b.d, ONE); // use the more performant integer pow when y is an int @@ -188,6 +236,10 @@ pub(crate) fn pow(a: F64, b: F64) -> F64 { // Calclates the value of a^b and checks for overflow before returning pub(crate) fn pow_int(a: F64, b: i64, sign: bool) -> F64 { + if a.d == NaN { + return F64 { d: NaN }; + } + let mut x = a; let mut n = b; @@ -222,10 +274,18 @@ pub(crate) fn pow_int(a: F64, b: i64, sign: bool) -> F64 { } pub(crate) fn rem(a: F64, b: F64) -> F64 { - return a - floor(a / b) * b; + if a.d == NaN || b.d == NaN || b.d == 0 { + return F64 { d: NaN }; + } + + return F64 { d: a.d - (a.d / b.d) * b.d }; } pub(crate) fn round(a: F64) -> F64 { + if a.d == NaN { + return F64 { d: NaN }; + } + let div = Div::div(a.d, ONE); let rem = Rem::rem(a.d, ONE); @@ -239,7 +299,13 @@ pub(crate) fn round(a: F64) -> F64 { // Calculates the square root of a FP16x16 point value // x must be positive pub(crate) fn sqrt(a: F64) -> F64 { - assert(a.d >= 0, 'must be positive'); + if a.d == NaN { + return F64 { d: NaN }; + } + + if a.d < 0 { + return F64 { d: NaN }; + } let a: u128 = a.d.try_into().unwrap(); let one: u128 = ONE.try_into().unwrap(); @@ -250,6 +316,10 @@ pub(crate) fn sqrt(a: F64) -> F64 { } pub(crate) fn sub(a: F64, b: F64) -> F64 { + if a.d == NaN { + return F64 { d: NaN }; + } + return add(a, -b); } @@ -262,8 +332,8 @@ mod tests { use orion_numbers::f64::helpers::{assert_precise, assert_relative}; use super::{ - FixedTrait, ONE, round, floor, sqrt, ceil, lut, exp, exp2, exp2_int, pow, log10, log2, ln, - eq, ne, add, F64, F64Impl + FixedTrait, ONE, NaN, round, floor, sqrt, ceil, lut, exp, exp2, exp2_int, pow, log10, log2, + ln, eq, ne, add, F64, F64Impl }; #[test] @@ -292,10 +362,9 @@ mod tests { } #[test] - #[should_panic] - fn test_sqrt_fail() { + fn test_sqrt_nan() { let a = FixedTrait::new_unscaled(-25); - sqrt(a); + assert(sqrt(a).d == NaN, 'should return NaN'); } #[test] diff --git a/packages/numbers/src/f64/trig.cairo b/packages/numbers/src/f64/trig.cairo index 72018e862..85a94eb7d 100644 --- a/packages/numbers/src/f64/trig.cairo +++ b/packages/numbers/src/f64/trig.cairo @@ -1,4 +1,4 @@ -use super::{F64, F64Impl, lut, helpers::abs_and_sign, HALF, ONE}; +use super::{F64, NaN, F64Impl, lut, helpers::abs_and_sign, HALF, ONE}; const TWO_PI: i64 = 26986075409; const PI: i64 = 13493037705; @@ -6,6 +6,10 @@ const HALF_PI: i64 = 6746518852; pub(crate) fn sin_fast(a: F64) -> F64 { + if a.d == NaN { + return F64 { d: NaN }; + } + let (a_abs, _) = abs_and_sign(a.d); let a1 = a_abs.try_into().unwrap() % TWO_PI;