Skip to content

Commit

Permalink
add NaN value
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelDkhn committed Aug 28, 2024
1 parent 69e0448 commit 96bfa1c
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 15 deletions.
4 changes: 2 additions & 2 deletions packages/data-structures/src/vec.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ impl VecIndex<V, T, +VecTrait<V, T>> of Index<V, usize> {
}

pub struct NullableVec<T> {
items: Felt252Dict<Nullable<T>>,
len: usize,
pub items: Felt252Dict<Nullable<T>>,
pub len: usize,
}

impl DestructNullableVec<T, +Drop<T>> of Destruct<NullableVec<T>> {
Expand Down
1 change: 1 addition & 0 deletions packages/deep-learning/src/ops/binary.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ mod tests {
};
}


#[test]
fn test_tensor_rem() {
// This would be precomputed
Expand Down
1 change: 1 addition & 0 deletions packages/numbers/src/f64.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub const ONE: i64 = 4294967296; // 2 ** 32
pub const HALF: i64 = 2147483648; // 2 ** 31
pub const MAX: i64 = 9223372036854775807; //2**63 - 1
const MIN: i64 = -9223372036854775808; // -2 ** 63
pub const NaN: i64 = 0x4e614e;

// STRUCTS

Expand Down
10 changes: 9 additions & 1 deletion packages/numbers/src/f64/comp.cairo
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use super::{F64, FixedTrait, F64Impl};
use super::{F64, NaN, FixedTrait, F64Impl};

fn max(a: F64, b: F64) -> F64 {
if a.d == NaN || b.d == NaN {
return F64 { d: NaN };
}

if (a >= b) {
return a;
} else {
Expand All @@ -9,6 +13,10 @@ fn max(a: F64, b: F64) -> F64 {
}

fn min(a: F64, b: F64) -> F64 {
if a.d == NaN || b.d == NaN {
return F64 { d: NaN };
}

if (a <= b) {
return a;
} else {
Expand Down
6 changes: 5 additions & 1 deletion packages/numbers/src/f64/erf.cairo
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{F64, F64Impl, ONE};
use super::{F64, NaN, F64Impl, ONE};
use super::lut;

const ERF_COMPUTATIONAL_ACCURACY: i64 = 100;
Expand All @@ -10,6 +10,10 @@ const MAX_ERF_NUMBER: i64 = 15032385536;
const ERF_TRUNCATION_NUMBER: i64 = 8589934592;

pub(crate) fn erf(x: F64) -> F64 {
if x.d == NaN {
return F64 { d: NaN };
}

// Lookup
// 1. if x.mag < 3.5 { lookup table }
// 2. else{ return 1}
Expand Down
89 changes: 79 additions & 10 deletions packages/numbers/src/f64/ops.cairo
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
use core::num::traits::{WideMul, Sqrt};
use super::{F64, FixedTrait, F64Impl, ONE, HALF, lut};

use super::{F64, FixedTrait, F64Impl, ONE, HALF, NaN, lut};

// PUBLIC

pub(crate) fn abs(a: F64) -> F64 {
if a.d == NaN {
return F64 { d: NaN };
}

if a.d <= 0 {
F64 { d: a.d * ONE }
} else {
Expand All @@ -13,10 +16,16 @@ pub(crate) fn abs(a: F64) -> F64 {
}

pub(crate) fn add(a: F64, b: F64) -> F64 {
if a.d == NaN || b.d == NaN {
return F64 { d: NaN };
}
F64 { d: a.d + b.d }
}

pub(crate) fn ceil(a: F64) -> F64 {
if a.d == NaN {
return F64 { d: NaN };
}
let div = Div::div(a.d, ONE);
let rem = Rem::rem(a.d, ONE);

Expand All @@ -28,6 +37,9 @@ pub(crate) fn ceil(a: F64) -> F64 {
}

pub(crate) fn div(a: F64, b: F64) -> F64 {
if a.d == NaN || b.d == NaN {
return F64 { d: NaN };
}
let a_i128 = WideMul::wide_mul(a.d, ONE);
let res_i128 = a_i128 / b.d.into();

Expand All @@ -40,12 +52,18 @@ pub(crate) fn eq(a: @F64, b: @F64) -> bool {

// // Calculates the natural exponent of x: e^x
pub(crate) fn exp(a: F64) -> F64 {
if a.d == NaN {
return F64 { d: NaN };
}
return exp2(FixedTrait::new(6196328018) * a);
}


// Calculates the binary exponent of x: 2^x
pub(crate) fn exp2(a: F64) -> F64 {
if a.d == NaN {
return F64 { d: NaN };
}
if (a.d == 0) {
return FixedTrait::ONE();
}
Expand Down Expand Up @@ -83,6 +101,10 @@ pub(crate) fn exp2_int(exp: i64) -> F64 {
}

pub(crate) fn floor(a: F64) -> F64 {
if a.d == NaN {
return F64 { d: NaN };
}

let div = Div::div(a.d, ONE);
let rem = Rem::rem(a.d, ONE);

Expand Down Expand Up @@ -111,13 +133,23 @@ pub(crate) fn le(a: F64, b: F64) -> bool {
// Calculates the natural logarithm of x: ln(x)
// self must be greater than zero
pub(crate) fn ln(a: F64) -> F64 {
if a.d == NaN {
return F64 { d: NaN };
}

return FixedTrait::new(2977044472) * log2(a); // ln(2) = 0.693...
}

// Calculates the binary logarithm of x: log2(x)
// self must be greather than zero
pub(crate) fn log2(a: F64) -> F64 {
assert(a.d >= 0, 'must be positive');
if a.d == NaN {
return F64 { d: NaN };
}

if a.d < 0 {
return F64 { d: NaN };
}

if (a.d == ONE) {
return FixedTrait::ZERO();
Expand Down Expand Up @@ -149,6 +181,10 @@ pub(crate) fn log2(a: F64) -> F64 {
// Calculates the base 10 log of x: log10(x)
// self must be greater than zero
pub(crate) fn log10(a: F64) -> F64 {
if a.d == NaN {
return F64 { d: NaN };
}

return FixedTrait::new(1292913986) * log2(a); // log10(2) = 0.301...
}

Expand All @@ -157,6 +193,10 @@ pub(crate) fn lt(a: F64, b: F64) -> bool {
}

pub(crate) fn mul(a: F64, b: F64) -> F64 {
if a.d == NaN {
return F64 { d: NaN };
}

let prod_i128 = WideMul::wide_mul(a.d, b.d);

// Re-apply sign
Expand All @@ -168,13 +208,21 @@ pub(crate) fn ne(a: @F64, b: @F64) -> bool {
}

pub(crate) fn neg(a: F64) -> F64 {
if a.d == NaN {
return F64 { d: NaN };
}

a * F64 { d: -ONE }
}

// Calclates the value of x^y and checks for overflow before returning
// self is a Fixed point value
// b is a Fixed point value
pub(crate) fn pow(a: F64, b: F64) -> F64 {
if a.d == NaN {
return F64 { d: NaN };
}

let rem = Rem::rem(b.d, ONE);

// use the more performant integer pow when y is an int
Expand All @@ -188,6 +236,10 @@ pub(crate) fn pow(a: F64, b: F64) -> F64 {

// Calclates the value of a^b and checks for overflow before returning
pub(crate) fn pow_int(a: F64, b: i64, sign: bool) -> F64 {
if a.d == NaN {
return F64 { d: NaN };
}

let mut x = a;
let mut n = b;

Expand Down Expand Up @@ -222,10 +274,18 @@ pub(crate) fn pow_int(a: F64, b: i64, sign: bool) -> F64 {
}

pub(crate) fn rem(a: F64, b: F64) -> F64 {
return a - floor(a / b) * b;
if a.d == NaN || b.d == NaN || b.d == 0 {
return F64 { d: NaN };
}

return F64 { d: a.d - (a.d / b.d) * b.d };
}

pub(crate) fn round(a: F64) -> F64 {
if a.d == NaN {
return F64 { d: NaN };
}

let div = Div::div(a.d, ONE);
let rem = Rem::rem(a.d, ONE);

Expand All @@ -239,7 +299,13 @@ pub(crate) fn round(a: F64) -> F64 {
// Calculates the square root of a FP16x16 point value
// x must be positive
pub(crate) fn sqrt(a: F64) -> F64 {
assert(a.d >= 0, 'must be positive');
if a.d == NaN {
return F64 { d: NaN };
}

if a.d < 0 {
return F64 { d: NaN };
}

let a: u128 = a.d.try_into().unwrap();
let one: u128 = ONE.try_into().unwrap();
Expand All @@ -250,6 +316,10 @@ pub(crate) fn sqrt(a: F64) -> F64 {
}

pub(crate) fn sub(a: F64, b: F64) -> F64 {
if a.d == NaN {
return F64 { d: NaN };
}

return add(a, -b);
}

Expand All @@ -262,8 +332,8 @@ mod tests {
use orion_numbers::f64::helpers::{assert_precise, assert_relative};

use super::{
FixedTrait, ONE, round, floor, sqrt, ceil, lut, exp, exp2, exp2_int, pow, log10, log2, ln,
eq, ne, add, F64, F64Impl
FixedTrait, ONE, NaN, round, floor, sqrt, ceil, lut, exp, exp2, exp2_int, pow, log10, log2,
ln, eq, ne, add, F64, F64Impl
};

#[test]
Expand Down Expand Up @@ -292,10 +362,9 @@ mod tests {
}

#[test]
#[should_panic]
fn test_sqrt_fail() {
fn test_sqrt_nan() {
let a = FixedTrait::new_unscaled(-25);
sqrt(a);
assert(sqrt(a).d == NaN, 'should return NaN');
}

#[test]
Expand Down
6 changes: 5 additions & 1 deletion packages/numbers/src/f64/trig.cairo
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
use super::{F64, F64Impl, lut, helpers::abs_and_sign, HALF, ONE};
use super::{F64, NaN, F64Impl, lut, helpers::abs_and_sign, HALF, ONE};

const TWO_PI: i64 = 26986075409;
const PI: i64 = 13493037705;
const HALF_PI: i64 = 6746518852;


pub(crate) fn sin_fast(a: F64) -> F64 {
if a.d == NaN {
return F64 { d: NaN };
}

let (a_abs, _) = abs_and_sign(a.d);

let a1 = a_abs.try_into().unwrap() % TWO_PI;
Expand Down

0 comments on commit 96bfa1c

Please sign in to comment.