From aacfc29866e185198a638851483669df157b9790 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Thu, 25 Jul 2024 10:50:01 +0300 Subject: [PATCH] write cdf tests + update orion-numbers --- packages/orion-algo/src/algo/cdf.cairo | 92 ++++++++++-- packages/orion-algo/src/algo/linear_fit.cairo | 2 +- .../src/span_math/span_f16x16.cairo | 15 +- .../src/span_math/span_f32x32.cairo | 17 ++- packages/orion-numbers/src/f16x16/core.cairo | 48 ++----- packages/orion-numbers/src/f16x16/erf.cairo | 8 +- .../orion-numbers/src/f16x16/helpers.cairo | 42 +++++- packages/orion-numbers/src/f16x16/lut.cairo | 4 +- packages/orion-numbers/src/f16x16/math.cairo | 98 ++++++------- packages/orion-numbers/src/f16x16/trig.cairo | 132 +++++++++--------- packages/orion-numbers/src/f32x32/core.cairo | 51 ++----- .../orion-numbers/src/f32x32/helpers.cairo | 48 ++++++- packages/orion-numbers/src/f32x32/math.cairo | 36 ++--- packages/orion-numbers/src/lib.cairo | 17 +-- 14 files changed, 368 insertions(+), 242 deletions(-) diff --git a/packages/orion-algo/src/algo/cdf.cairo b/packages/orion-algo/src/algo/cdf.cairo index 81b6cf76b..a23d0cbeb 100644 --- a/packages/orion-algo/src/algo/cdf.cairo +++ b/packages/orion-algo/src/algo/cdf.cairo @@ -3,22 +3,59 @@ use core::array::{SpanTrait, SpanIter}; use orion_algo::span_math::SpanMathTrait; use orion_numbers::FixedTrait; +/// Computes the cumulative distribution function (CDF) for a given set of values using the +/// standard normal distribution formula. This implementation allows for optional location (`loc`) +/// and scale (`scale`) parameters, which default to 0.0 and 1.0 respectively if not provided. +/// +/// # Arguments +/// * `x` - A `Span` containing the data points for which the CDF is to be computed. +/// * `loc` - An optional `Span` representing the location parameter (mean) for each data point. +/// If `Some(Span)` is provided, it must either contain a single value or have the same +/// length as `x`. If `None` is provided, defaults to a Span of a single 0.0 value. +/// * `scale` - An optional `Span` representing the scale parameter (standard deviation) for each +/// data point. If `Some(Span)` is provided, it must either contain a single value +/// or have the same length as `x`. If `None` is provided, defaults to a Span of a single 1.0 value. +/// +/// # Returns +/// A `Span` representing the CDF values corresponding to each entry in `x`. +/// +/// # Panics +/// * The function panics if the lengths of `loc` or `scale` Spans are more than one and not equal to +/// the length of `x`. +/// +/// # Examples +/// Basic usage: +/// +/// ``` +/// let x = array![FixedTrait::new_unscaled(2), FixedTrait::new_unscaled(1), FixedTrait::new_unscaled(0)].span(); +/// let result = cdf(x, None, None); +/// // Expected output: CDF values for a standard normal distribution +/// ``` +/// +/// With location and scale parameters: +/// +/// ``` +/// let x = array![FixedTrait::new_unscaled(2), FixedTrait::new_unscaled(1), FixedTrait::new_unscaled(0)].span(); +/// let loc = array![FixedTrait::new_unscaled(1), FixedTrait::new_unscaled(1), FixedTrait::new_unscaled(1)].span(); +/// let scale = array![FixedTrait::new_unscaled(1), FixedTrait::new_unscaled(1), FixedTrait::new_unscaled(1)].span(); +/// let result = cdf(x, Some(loc), Some(scale)); +/// // Expected output: Adjusted CDF values using specified loc and scale +/// ``` pub fn cdf< T, +FixedTrait, +SpanMathTrait, +Sub, +Div, +Mul, +Drop, +Add, +Copy >( x: Span, loc: Option>, scale: Option> -) //-> Span -{ +) -> Span { // Default loc to 0.0 if not provided let mut loc = match loc { Option::Some(val) => val, - Option::None => array![FixedTrait::ZERO()].span() + Option::None => array![FixedTrait::ZERO].span() }; // Default scale to 1.0 if not provided let mut scale = match scale { Option::Some(val) => val, - Option::None => array![FixedTrait::ONE()].span() + Option::None => array![FixedTrait::ONE].span() }; // single value or same length as x @@ -50,16 +87,51 @@ pub fn cdf< scale_first_val }; - // Calculate: 0.5 * (1.0 + erf((x_val - loc_val) / (scale_val * (2.0f64).sqrt()))) - let calc = FixedTrait::HALF() - * (FixedTrait::ONE() - + ((*x_val - loc_val) - / (scale_val * (FixedTrait::ONE() + FixedTrait::ONE()).sqrt())) - .erf()); + // Calculate: 0.5 * (1.0 + erf((x_val - loc_val) / (scale_val * sqrt(2.0)))) + let sqrt_2 = FixedTrait::sqrt(FixedTrait::TWO); + let x_minus_loc = FixedTrait::sub(*x_val, loc_val); + let scale_times_sqrt_2 = FixedTrait::mul(scale_val, sqrt_2); + let division_result = FixedTrait::div(x_minus_loc, scale_times_sqrt_2); + let erf_result = FixedTrait::erf(division_result); + let one_plus_erf = FixedTrait::add(FixedTrait::ONE, erf_result); + let calc = FixedTrait::mul(FixedTrait::HALF, one_plus_erf); res_data.append(calc); }, Option::None => { break; } } }; + + res_data.span() +} + +#[cfg(test)] +mod tests { + use super::cdf; + use orion_numbers::{f16x16::{core::f16x16, helpers::assert_relative_span}, FixedTrait}; + + #[test] + fn test_cdf_loc_scale_are_none() { + let x: Span = array![FixedTrait::ONE, FixedTrait::HALF, FixedTrait::ZERO].span(); + + let res = cdf(x, Option::None, Option::None); + let expected = array![55138, 45316, 32768].span(); + + assert_relative_span(res, expected, 'res != expected', Option::None); + } + + #[test] + fn test_cdf_loc_scale_are_some() { + let x: Span = array![FixedTrait::ONE, FixedTrait::HALF, FixedTrait::ZERO].span(); + + let loc: Span = array![FixedTrait::HALF, FixedTrait::HALF, FixedTrait::HALF].span(); + + let scale: Span = array![FixedTrait::HALF, FixedTrait::HALF, FixedTrait::HALF] + .span(); + + let res = cdf(x, Option::Some(loc), Option::Some(scale)); + let expected = array![55138, 32768, 10398].span(); + + assert_relative_span(res, expected, 'res != expected', Option::None); + } } diff --git a/packages/orion-algo/src/algo/linear_fit.cairo b/packages/orion-algo/src/algo/linear_fit.cairo index f5c6f185e..e7b04b5db 100644 --- a/packages/orion-algo/src/algo/linear_fit.cairo +++ b/packages/orion-algo/src/algo/linear_fit.cairo @@ -26,7 +26,7 @@ pub fn linear_fit< let sum_xy = x.dot(y); let denominator = n * sum_xx - (sum_x.mul(sum_x)); - if denominator == FixedTrait::ZERO() { + if denominator == FixedTrait::ZERO { panic!("division by zero exception") } diff --git a/packages/orion-algo/src/span_math/span_f16x16.cairo b/packages/orion-algo/src/span_math/span_f16x16.cairo index e1f3d82c6..6c26c836b 100644 --- a/packages/orion-algo/src/span_math/span_f16x16.cairo +++ b/packages/orion-algo/src/span_math/span_f16x16.cairo @@ -1,4 +1,4 @@ -use orion_numbers::{f16x16::core::{f16x16, ONE}, FixedTrait}; +use orion_numbers::{f16x16::core::f16x16, FixedTrait}; use orion_algo::span_math::SpanMathTrait; @@ -34,7 +34,7 @@ fn arange(n: u32) -> Span { let mut i = 0; let mut arr = array![]; while i < n { - arr.append(i.try_into().unwrap() * ONE); + arr.append(i.try_into().unwrap() * FixedTrait::ONE); i += 1; }; @@ -55,7 +55,7 @@ fn dot(a: Span, b: Span) -> f16x16 { fn max(mut a: Span) -> f16x16 { assert(a.len() > 0, 'span cannot be empty'); - let mut max = FixedTrait::MIN(); + let mut max = FixedTrait::MIN; loop { match a.pop_front() { @@ -70,7 +70,7 @@ fn max(mut a: Span) -> f16x16 { fn min(mut a: Span) -> f16x16 { assert(a.len() > 0, 'span cannot be empty'); - let mut min = FixedTrait::MAX(); + let mut min = FixedTrait::MAX; loop { match a.pop_front() { @@ -105,8 +105,9 @@ fn sum(mut a: Span) -> f16x16 { #[cfg(test)] mod tests { - use super::{arange, dot, max, min, prod, sum, ONE}; + use super::{arange, dot, max, min, prod, sum}; use orion_numbers::f16x16::helpers::assert_precise; + use orion_numbers::F16x16Impl; #[test] fn test_arange() { @@ -129,7 +130,7 @@ mod tests { let y = array![0, 131072, 262144, 393216, 524288, 655360].span(); // 0, 2, 4, 6, 8, 10 let result = dot(x, y); - assert_precise(result, (110 * ONE).into(), 'should be equal', Option::None); + assert_precise(result, (110 * F16x16Impl::ONE).into(), 'should be equal', Option::None); } #[test] @@ -165,6 +166,6 @@ mod tests { let result = sum(x); - assert_precise(result, (15 * ONE).into(), 'should be equal', Option::None); + assert_precise(result, (15 * F16x16Impl::ONE).into(), 'should be equal', Option::None); } } diff --git a/packages/orion-algo/src/span_math/span_f32x32.cairo b/packages/orion-algo/src/span_math/span_f32x32.cairo index b0108734c..0131f2afe 100644 --- a/packages/orion-algo/src/span_math/span_f32x32.cairo +++ b/packages/orion-algo/src/span_math/span_f32x32.cairo @@ -1,5 +1,5 @@ use orion_numbers::{FixedTrait}; -use orion_numbers::f32x32::core::{f32x32, ONE}; +use orion_numbers::f32x32::core::f32x32; use orion_algo::span_math::SpanMathTrait; @@ -34,7 +34,7 @@ fn arange(n: u32) -> Span { let mut i = 0; let mut arr = array![]; while i < n { - arr.append(i.try_into().unwrap() * ONE); + arr.append(i.try_into().unwrap() * FixedTrait::ONE); i += 1; }; @@ -55,7 +55,7 @@ fn dot(a: Span, b: Span) -> f32x32 { fn max(mut a: Span) -> f32x32 { assert(a.len() > 0, 'span cannot be empty'); - let mut max = FixedTrait::MIN(); + let mut max = FixedTrait::MIN; loop { match a.pop_front() { @@ -70,7 +70,7 @@ fn max(mut a: Span) -> f32x32 { fn min(mut a: Span) -> f32x32 { assert(a.len() > 0, 'span cannot be empty'); - let mut min = FixedTrait::MAX(); + let mut min = FixedTrait::MAX; loop { match a.pop_front() { @@ -105,8 +105,9 @@ fn sum(mut a: Span) -> f32x32 { #[cfg(test)] mod tests { - use super::{arange, dot, max, min, prod, sum, ONE}; + use super::{arange, dot, max, min, prod, sum}; use orion_numbers::f32x32::helpers::assert_precise; + use orion_numbers::F32x32Impl; #[test] fn test_arange() { @@ -131,7 +132,7 @@ mod tests { .span(); // 0, 2, 4, 6, 8, 10 let result = dot(x, y); - assert_precise(result, (7208960 * ONE).into(), 'should be equal', Option::None); + assert_precise(result, (7208960 * F32x32Impl::ONE).into(), 'should be equal', Option::None); } #[test] @@ -171,6 +172,8 @@ mod tests { let result = sum(x); - assert_precise(result, (98304000 * ONE).into(), 'should be equal', Option::None); + assert_precise( + result, (98304000 * F32x32Impl::ONE).into(), 'should be equal', Option::None + ); } } diff --git a/packages/orion-numbers/src/f16x16/core.cairo b/packages/orion-numbers/src/f16x16/core.cairo index d3ee15e58..a8a26a1a9 100644 --- a/packages/orion-numbers/src/f16x16/core.cairo +++ b/packages/orion-numbers/src/f16x16/core.cairo @@ -3,37 +3,19 @@ use orion_numbers::FixedTrait; pub type f16x16 = i32; -// CONSTANTS -pub const TWO: f16x16 = 131072; // 2 ** 17 -pub const ONE: f16x16 = 65536; // 2 ** 16 -pub const HALF: f16x16 = 32768; // 2 ** 15 -pub const MAX: f16x16 = 2147483647; // 2 ** 31 -1 -pub const MIN: f16x16 = -2147483648; // 2 ** 31 - pub impl F16x16Impl of FixedTrait { - fn ZERO() -> f16x16 { - 0 - } + // CONSTANTS + const ZERO: f16x16 = 0; + const HALF: f16x16 = 32768; // 2 ** 15 + const ONE: f16x16 = 65536; // 2 ** 16 + const TWO: f16x16 = 131072; // 2 ** 17 + const MAX: f16x16 = 2147483647; // 2 ** 31 -1 + const MIN: f16x16 = -2147483648; // 2 ** 31 - fn HALF() -> f16x16 { - HALF - } - - fn ONE() -> f16x16 { - ONE - } - - fn MAX() -> f16x16 { - MAX - } - - fn MIN() -> f16x16 { - MIN - } fn new_unscaled(x: i32) -> f16x16 { - x * ONE + x * Self::ONE } fn new(x: i32) -> f16x16 { @@ -45,7 +27,7 @@ pub impl F16x16Impl of FixedTrait { } fn from_unscaled_felt(x: felt252) -> f16x16 { - return FixedTrait::from_felt(x * ONE.into()); + return FixedTrait::from_felt(x * Self::ONE.into()); } fn abs(self: f16x16) -> f16x16 { @@ -182,27 +164,27 @@ pub impl F16x16Impl of FixedTrait { } fn INF() -> f16x16 { - MAX + Self::MAX } fn POS_INF() -> f16x16 { - MAX + Self::MAX } fn NEG_INF() -> f16x16 { - MIN + Self::MIN } fn is_inf(self: f16x16) -> bool { - self == MAX + self == Self::MAX } fn is_pos_inf(self: f16x16) -> bool { - self == MAX + self == Self::MAX } fn is_neg_inf(self: f16x16) -> bool { - self == MIN + self == Self::MIN } fn erf(self: f16x16) -> f16x16 { diff --git a/packages/orion-numbers/src/f16x16/erf.cairo b/packages/orion-numbers/src/f16x16/erf.cairo index b7a3081e6..cd2ad3f27 100644 --- a/packages/orion-numbers/src/f16x16/erf.cairo +++ b/packages/orion-numbers/src/f16x16/erf.cairo @@ -1,5 +1,5 @@ -use orion_numbers::f16x16::{core::{f16x16, ONE}, lut}; -use orion_numbers::FixedTrait; +use orion_numbers::f16x16::{core::{f16x16}, lut}; +use orion_numbers::{FixedTrait}; const ERF_COMPUTATIONAL_ACCURACY: i32 = 100; const ROUND_CHECK_NUMBER: i32 = 10; @@ -18,7 +18,7 @@ pub fn erf(x: f16x16) -> f16x16 { if x.abs() < MAX_ERF_NUMBER { erf_value = lut::erf_lut(x.abs()); } else { - erf_value = ONE; + erf_value = FixedTrait::ONE; } FixedTrait::mul(erf_value, x.sign()) @@ -27,7 +27,7 @@ pub fn erf(x: f16x16) -> f16x16 { // Tests // -// +// // -------------------------------------------------------------------------------------------------------------- #[cfg(test)] diff --git a/packages/orion-numbers/src/f16x16/helpers.cairo b/packages/orion-numbers/src/f16x16/helpers.cairo index 1fd6c02f8..4b7f580c2 100644 --- a/packages/orion-numbers/src/f16x16/helpers.cairo +++ b/packages/orion-numbers/src/f16x16/helpers.cairo @@ -1,4 +1,5 @@ -use orion_numbers::f16x16::core::{F16x16Impl, f16x16, ONE, HALF}; +use core::traits::PanicDestruct; +use orion_numbers::f16x16::core::{F16x16Impl, f16x16}; const DEFAULT_PRECISION: i32 = 7; // 1e-4 @@ -20,6 +21,26 @@ pub fn assert_precise( } } +pub fn assert_precise_span( + results: Span, expected: Span, msg: felt252, custom_precision: Option +) { + assert(results.len() == expected.len(), 'Arrays must have same length'); + + let mut i: usize = 0; + loop { + if i == results.len() { + break; + } + + let result = *results.at(i); + let expected_val = *expected.at(i); + + assert_precise(result, expected_val, msg, custom_precision); + + i += 1; + } +} + pub fn assert_relative( result: f16x16, expected: felt252, msg: felt252, custom_precision: Option ) { @@ -37,3 +58,22 @@ pub fn assert_relative( } } +pub fn assert_relative_span( + results: Span, expected: Span, msg: felt252, custom_precision: Option +) { + assert(results.len() == expected.len(), 'Arrays must have same length'); + + let mut i: usize = 0; + loop { + if i == results.len() { + break; + } + + let result = *results.at(i); + let expected_val = *expected.at(i); + + assert_relative(result, expected_val, msg, custom_precision); + + i += 1; + } +} diff --git a/packages/orion-numbers/src/f16x16/lut.cairo b/packages/orion-numbers/src/f16x16/lut.cairo index 73cbe732d..620248ec7 100644 --- a/packages/orion-numbers/src/f16x16/lut.cairo +++ b/packages/orion-numbers/src/f16x16/lut.cairo @@ -1,4 +1,4 @@ -use orion_numbers::f16x16::core::ONE; +use orion_numbers::FixedTrait; // Calculates the most significant bit pub fn msb(whole: i32) -> (i32, i32) { @@ -1928,5 +1928,5 @@ pub fn erf_lut(x: i32) -> i32 { } } - ONE + FixedTrait::ONE } diff --git a/packages/orion-numbers/src/f16x16/math.cairo b/packages/orion-numbers/src/f16x16/math.cairo index 31e152ec2..8e107d0b8 100644 --- a/packages/orion-numbers/src/f16x16/math.cairo +++ b/packages/orion-numbers/src/f16x16/math.cairo @@ -2,7 +2,7 @@ use core::option::OptionTrait; use core::traits::TryInto; use core::integer; use core::num::traits::{WideMul, Sqrt}; -use orion_numbers::f16x16::{core::{F16x16Impl, f16x16, ONE, HALF}, lut}; +use orion_numbers::f16x16::{core::{F16x16Impl, f16x16}, lut}; pub fn abs(a: f16x16) -> f16x16 { if a >= 0 { @@ -22,18 +22,18 @@ pub fn sub(a: f16x16, b: f16x16) -> f16x16 { pub fn ceil(a: f16x16) -> f16x16 { //let (div, rem) = DivRem::div_rem(a, ONE.try_into().unwrap()); - let div = Div::div(a, ONE); - let rem = Rem::rem(a, ONE); + let div = Div::div(a, F16x16Impl::ONE); + let rem = Rem::rem(a, F16x16Impl::ONE); if rem == 0 { F16x16Impl::new_unscaled(div) } else { - F16x16Impl::new_unscaled(div) + ONE + F16x16Impl::new_unscaled(div) + F16x16Impl::ONE } } pub fn div(a: f16x16, b: f16x16) -> f16x16 { - let a_i64 = WideMul::wide_mul(a, ONE); + let a_i64 = WideMul::wide_mul(a, F16x16Impl::ONE); let res_i64 = a_i64 / b.into(); // Re-apply sign @@ -48,12 +48,12 @@ pub fn exp(a: f16x16) -> f16x16 { // Calculates the binary exponent of x: 2^x pub fn exp2(a: f16x16) -> f16x16 { if (a == 0) { - return F16x16Impl::ONE(); + return F16x16Impl::ONE; } //let (int_part, frac_part) = DivRem::div_rem(a.abs(), ONE.try_into().unwrap()); - let int_part = Div::div(a.abs(), ONE); - let frac_part = Rem::rem(a.abs(), ONE); + let int_part = Div::div(a.abs(), F16x16Impl::ONE); + let frac_part = Rem::rem(a.abs(), F16x16Impl::ONE); let int_res = F16x16Impl::new_unscaled(lut::exp2(int_part)); let mut res_u = int_res; @@ -67,11 +67,11 @@ pub fn exp2(a: f16x16) -> f16x16 { let r3 = F16x16Impl::mul((r4 + F16x16Impl::new(3638)), frac); let r2 = F16x16Impl::mul((r3 + F16x16Impl::new(15743)), frac); let r1 = F16x16Impl::mul((r2 + F16x16Impl::new(45426)), frac); - res_u = F16x16Impl::mul(res_u, (r1 + F16x16Impl::ONE())); + res_u = F16x16Impl::mul(res_u, (r1 + F16x16Impl::ONE)); } if a < 0 { - F16x16Impl::div(F16x16Impl::ONE(), res_u) + F16x16Impl::div(F16x16Impl::ONE, res_u) } else { res_u } @@ -83,8 +83,8 @@ fn exp2_int(exp: i32) -> f16x16 { pub fn floor(a: f16x16) -> f16x16 { //let (div, rem) = DivRem::div_rem(a, ONE.try_into().unwrap()); - let div = Div::div(a, ONE); - let rem = Rem::rem(a, ONE); + let div = Div::div(a, F16x16Impl::ONE); + let rem = Rem::rem(a, F16x16Impl::ONE); if rem == 0 { a @@ -106,18 +106,18 @@ pub fn ln(a: f16x16) -> f16x16 { pub fn log2(a: f16x16) -> f16x16 { assert(a >= 0, 'must be positive'); - if (a == ONE) { - return F16x16Impl::ZERO(); - } else if (a < ONE) { + if (a == F16x16Impl::ONE) { + return F16x16Impl::ZERO; + } else if (a < F16x16Impl::ONE) { // Compute true inverse binary log if 0 < x < 1 - let div = F16x16Impl::div(F16x16Impl::ONE(), a); + let div = F16x16Impl::div(F16x16Impl::ONE, a); return -log2(div); } - let whole = a / ONE; + let whole = a / F16x16Impl::ONE; let (msb, div) = lut::msb(whole); - if a == div * ONE { + if a == div * F16x16Impl::ONE { F16x16Impl::new_unscaled(msb) } else { let norm = F16x16Impl::div(a, F16x16Impl::new_unscaled(div)); @@ -144,18 +144,18 @@ pub fn mul(a: f16x16, b: f16x16) -> f16x16 { let prod_i64 = WideMul::wide_mul(a, b); // Re-apply sign - F16x16Impl::new((prod_i64 / ONE.into()).try_into().unwrap()) + F16x16Impl::new((prod_i64 / F16x16Impl::ONE.into()).try_into().unwrap()) } // Calclates the value of x^y and checks for overflow before returning // self is a FP16x16 point value // b is a FP16x16 point value pub fn pow(a: f16x16, b: f16x16) -> f16x16 { - let rem = Rem::rem(b, ONE); + let rem = Rem::rem(b, F16x16Impl::ONE); // use the more performant integer pow when y is an int if (rem == 0) { - return pow_int(a, b / ONE); + return pow_int(a, b / F16x16Impl::ONE); } // x^y = exp(y*ln(x)) for x > 0 will error for x < 0 @@ -168,14 +168,14 @@ fn pow_int(a: f16x16, b: i32) -> f16x16 { let mut n = b.abs(); if b < 0 { - x = F16x16Impl::div(ONE, x); + x = F16x16Impl::div(F16x16Impl::ONE, x); } if n == 0 { - return ONE; + return F16x16Impl::ONE; } - let mut y = ONE; + let mut y = F16x16Impl::ONE; let two: i32 = 2; while n > 1 { @@ -196,10 +196,10 @@ fn pow_int(a: f16x16, b: i32) -> f16x16 { pub fn round(a: f16x16) -> f16x16 { //let (div, rem) = DivRem::div_rem(a, ONE.try_into().unwrap()); - let div = Div::div(a, ONE); - let rem = Rem::rem(a, ONE); + let div = Div::div(a, F16x16Impl::ONE); + let rem = Rem::rem(a, F16x16Impl::ONE); - if (HALF <= rem) { + if (F16x16Impl::HALF <= rem) { F16x16Impl::new_unscaled(div + 1) } else { F16x16Impl::new_unscaled(div) @@ -210,9 +210,9 @@ pub fn sign(a: f16x16) -> f16x16 { if a == 0 { F16x16Impl::new(0) } else if a > 0 { - ONE + F16x16Impl::ONE } else { - -ONE + -F16x16Impl::ONE } } @@ -222,7 +222,7 @@ pub fn sqrt(a: f16x16) -> f16x16 { assert(a >= 0, 'must be positive'); let a: u64 = a.try_into().unwrap(); - let one: u64 = ONE.try_into().unwrap(); + let one: u64 = F16x16Impl::ONE.try_into().unwrap(); let root: u32 = Sqrt::sqrt(a * one); @@ -240,20 +240,20 @@ mod tests { use orion_numbers::f16x16::helpers::{assert_precise, assert_relative}; use super::{ - F16x16Impl, ONE, HALF, f16x16, integer, lut, ceil, add, sqrt, floor, exp, exp2, exp2_int, - ln, log2, log10, pow, round, sign + F16x16Impl, f16x16, integer, lut, ceil, add, sqrt, floor, exp, exp2, exp2_int, ln, log2, + log10, pow, round, sign }; #[test] fn test_into() { let a = F16x16Impl::new_unscaled(5); - assert(a == 5 * ONE, 'invalid result'); + assert(a == 5 * F16x16Impl::ONE, 'invalid result'); } #[test] fn test_ceil() { let a = F16x16Impl::new(190054); // 2.9 - assert(ceil(a) == 3 * ONE, 'invalid pos decimal'); + assert(ceil(a) == 3 * F16x16Impl::ONE, 'invalid pos decimal'); } #[test] @@ -276,7 +276,7 @@ mod tests { #[test] fn test_floor() { let a = F16x16Impl::new(190054); // 2.9 - assert(floor(a) == 2 * ONE, 'invalid pos decimal'); + assert(floor(a) == 2 * F16x16Impl::ONE, 'invalid pos decimal'); } #[test] @@ -285,7 +285,7 @@ mod tests { assert(ln(a) == 0, 'invalid ln of 1'); a = F16x16Impl::new(178145); - assert_relative(ln(a), ONE.into(), 'invalid ln of 2.7...', Option::None(())); + assert_relative(ln(a), F16x16Impl::ONE.into(), 'invalid ln of 2.7...', Option::None(())); } #[test] @@ -300,7 +300,7 @@ mod tests { #[test] fn test_log10() { let a = F16x16Impl::new_unscaled(100); - assert_relative(log10(a), 2 * ONE.into(), 'invalid log10', Option::None(())); + assert_relative(log10(a), 2 * F16x16Impl::ONE.into(), 'invalid log10', Option::None(())); } @@ -308,7 +308,7 @@ mod tests { fn test_pow() { let a = F16x16Impl::new_unscaled(3); let b = F16x16Impl::new_unscaled(4); - assert(pow(a, b) == 81 * ONE, 'invalid pos base power'); + assert(pow(a, b) == 81 * F16x16Impl::ONE, 'invalid pos base power'); } #[test] @@ -323,7 +323,7 @@ mod tests { #[test] fn test_round() { let a = F16x16Impl::new(190054); // 2.9 - assert(round(a) == 3 * ONE, 'invalid pos decimal'); + assert(round(a) == 3 * F16x16Impl::ONE, 'invalid pos decimal'); } @@ -333,7 +333,7 @@ mod tests { assert(sqrt(a) == 0, 'invalid zero root'); a = F16x16Impl::new_unscaled(25); - assert(sqrt(a) == 5 * ONE, 'invalid pos root'); + assert(sqrt(a) == 5 * F16x16Impl::ONE, 'invalid pos root'); } #[test] @@ -348,23 +348,23 @@ mod tests { let a = F16x16Impl::new(0); assert(a.sign() == 0, 'invalid sign (0)'); - let a = F16x16Impl::new(-HALF); - assert(a.sign() == -ONE, 'invalid sign (-HALF)'); + let a = F16x16Impl::new(-F16x16Impl::HALF); + assert(a.sign() == -F16x16Impl::ONE, 'invalid sign (-HALF)'); - let a = F16x16Impl::new(HALF); - assert(a.sign() == ONE, 'invalid sign (HALF)'); + let a = F16x16Impl::new(F16x16Impl::HALF); + assert(a.sign() == F16x16Impl::ONE, 'invalid sign (HALF)'); - let a = F16x16Impl::new(-ONE); - assert(a.sign() == -ONE, 'invalid sign (-ONE)'); + let a = F16x16Impl::new(-F16x16Impl::ONE); + assert(a.sign() == -F16x16Impl::ONE, 'invalid sign (-ONE)'); - let a = F16x16Impl::new(ONE); - assert(a.sign() == ONE, 'invalid sign (ONE)'); + let a = F16x16Impl::new(F16x16Impl::ONE); + assert(a.sign() == F16x16Impl::ONE, 'invalid sign (ONE)'); } #[test] fn test_msb() { let a = F16x16Impl::new_unscaled(100); - let (msb, div) = lut::msb(a / ONE); + let (msb, div) = lut::msb(a / F16x16Impl::ONE); assert(msb == 6, 'invalid msb'); assert(div == 64, 'invalid msb ceil'); } diff --git a/packages/orion-numbers/src/f16x16/trig.cairo b/packages/orion-numbers/src/f16x16/trig.cairo index 60606d59c..ddc7ad47c 100644 --- a/packages/orion-numbers/src/f16x16/trig.cairo +++ b/packages/orion-numbers/src/f16x16/trig.cairo @@ -1,5 +1,5 @@ use core::integer; -use orion_numbers::f16x16::{core::{f16x16, ONE, HALF, TWO}, lut}; +use orion_numbers::f16x16::{core::{f16x16, F16x16Impl}, lut}; use orion_numbers::FixedTrait; // CONSTANTS @@ -12,7 +12,7 @@ const HALF_PI: i32 = 102944; // Calculates arccos(a) for -1 <= a <= 1 (fixed point) // arccos(a) = arcsin(sqrt(1 - a^2)) - arctan identity has discontinuity at zero pub fn acos_fast(a: f16x16) -> f16x16 { - let asin_arg = (FixedTrait::ONE() - FixedTrait::mul(a, a)).sqrt(); // will fail if a > 1 + let asin_arg = (F16x16Impl::ONE - FixedTrait::mul(a, a)).sqrt(); // will fail if a > 1 let asin_res = asin_fast(asin_arg); if a < 0 { @@ -26,15 +26,15 @@ pub fn acos_fast(a: f16x16) -> f16x16 { // Calculates arcsin(a) for -1 <= a <= 1 (fixed point) // arcsin(a) = arctan(a / sqrt(1 - a^2)) pub fn asin_fast(a: f16x16) -> f16x16 { - if (a == ONE) { + if (a == F16x16Impl::ONE) { return FixedTrait::new(HALF_PI); } - if (a == -ONE) { + if (a == -F16x16Impl::ONE) { return FixedTrait::new(-HALF_PI); } - let div = (FixedTrait::ONE() - FixedTrait::mul(a, a)).sqrt(); // will fail if a > 1 + let div = (F16x16Impl::ONE - FixedTrait::mul(a, a)).sqrt(); // will fail if a > 1 atan_fast(FixedTrait::div(a, div)) } @@ -47,15 +47,15 @@ pub fn atan_fast(a: f16x16) -> f16x16 { let mut invert = false; // Invert value when a > 1 - if (at > ONE) { - at = FixedTrait::div(FixedTrait::ONE(), at); + if (at > F16x16Impl::ONE) { + at = FixedTrait::div(F16x16Impl::ONE, at); invert = true; } // Account for lack of precision in polynomaial when a > 0.7 if (at > 45875) { let sqrt3_3 = FixedTrait::new(37837); // sqrt(3) / 3 - at = FixedTrait::div(at - sqrt3_3, FixedTrait::ONE() + FixedTrait::mul(at, sqrt3_3)); + at = FixedTrait::div(at - sqrt3_3, F16x16Impl::ONE + FixedTrait::mul(at, sqrt3_3)); shift = true; } @@ -119,44 +119,44 @@ pub fn tan_fast(a: f16x16) -> f16x16 { // Calculates inverse hyperbolic cosine of a (fixed point) pub fn acosh(a: f16x16) -> f16x16 { - let root = (FixedTrait::mul(a, a) - FixedTrait::ONE()).sqrt(); + let root = (FixedTrait::mul(a, a) - FixedTrait::ONE).sqrt(); (a + root).ln() } // Calculates inverse hyperbolic sine of a (fixed point) pub fn asinh(a: f16x16) -> f16x16 { - let root = (FixedTrait::mul(a, a) + FixedTrait::ONE()).sqrt(); + let root = (FixedTrait::mul(a, a) + FixedTrait::ONE).sqrt(); (a + root).ln() } // Calculates inverse hyperbolic tangent of a (fixed point) pub fn atanh(a: f16x16) -> f16x16 { - let one = FixedTrait::ONE(); + let one = FixedTrait::ONE; let ln_arg = FixedTrait::div((one + a), (one - a)); - FixedTrait::div(ln_arg.ln(), FixedTrait::new(TWO)) + FixedTrait::div(ln_arg.ln(), FixedTrait::new(FixedTrait::TWO)) } // Calculates hyperbolic cosine of a (fixed point) pub fn cosh(a: f16x16) -> f16x16 { let ea = a.exp(); - FixedTrait::div((ea + FixedTrait::div(FixedTrait::ONE(), ea)), FixedTrait::new(TWO)) + FixedTrait::div((ea + FixedTrait::div(FixedTrait::ONE, ea)), FixedTrait::new(FixedTrait::TWO)) } // Calculates hyperbolic sine of a (fixed point) pub fn sinh(a: f16x16) -> f16x16 { let ea = a.exp(); - FixedTrait::div((ea - FixedTrait::div(FixedTrait::ONE(), ea)), FixedTrait::new(TWO)) + FixedTrait::div((ea - FixedTrait::div(FixedTrait::ONE, ea)), FixedTrait::new(FixedTrait::TWO)) } // Calculates hyperbolic tangent of a (fixed point) pub fn tanh(a: f16x16) -> f16x16 { let ea = a.exp(); - let ea_i = FixedTrait::div(FixedTrait::ONE(), ea); + let ea_i = FixedTrait::div(FixedTrait::ONE, ea); FixedTrait::div((ea - ea_i), (ea + ea_i)) } @@ -164,7 +164,7 @@ pub fn tanh(a: f16x16) -> f16x16 { // Tests // -// +// // -------------------------------------------------------------------------------------------------------------- #[cfg(test)] @@ -172,34 +172,34 @@ mod tests { use orion_numbers::f16x16::helpers::{assert_precise, assert_relative}; use super::{ - FixedTrait, PI, HALF_PI, ONE, HALF, TWO, acos_fast, atan_fast, asin_fast, cos_fast, - sin_fast, tan_fast, acosh, asinh, atanh, cosh, sinh, tanh + FixedTrait, F16x16Impl, PI, HALF_PI, acos_fast, atan_fast, asin_fast, cos_fast, sin_fast, + tan_fast, acosh, asinh, atanh, cosh, sinh, tanh }; #[test] fn test_acos_fast() { let error = Option::Some(84); // 1e-5 - let a = FixedTrait::ONE(); + let a = FixedTrait::ONE; assert(acos_fast(a).into() == 0, 'invalid one'); - let a = FixedTrait::new(ONE / 2); + let a = FixedTrait::new(FixedTrait::ONE / 2); assert_relative(acos_fast(a), 68629, 'invalid half', error); // 1.3687308642680 - let a = FixedTrait::ZERO(); + let a = FixedTrait::ZERO; assert_relative(acos_fast(a), HALF_PI.into(), 'invalid zero', Option::None(())); // PI / 2 - let a = FixedTrait::new(-ONE / 2); + let a = FixedTrait::new(-FixedTrait::ONE / 2); assert_relative(acos_fast(a), 137258, 'invalid neg half', error); // 2.737461741902 - let a = FixedTrait::new(-ONE); + let a = FixedTrait::new(-FixedTrait::ONE); assert_relative(acos_fast(a), PI.into(), 'invalid neg one', Option::None(())); // PI } #[test] #[should_panic] fn test_acos_fail() { - let a = FixedTrait::new(2 * ONE); + let a = FixedTrait::new(2 * FixedTrait::ONE); acos_fast(a); } @@ -207,25 +207,25 @@ mod tests { fn test_atan_fast() { let error = Option::Some(84); // 1e-5 - let a = FixedTrait::new(2 * ONE); + let a = FixedTrait::new(2 * FixedTrait::ONE); assert_relative(atan_fast(a), 72558, 'invalid two', error); - let a = FixedTrait::ONE(); + let a = FixedTrait::ONE; assert_relative(atan_fast(a), 51472, 'invalid one', error); - let a = FixedTrait::new(ONE / 2); + let a = FixedTrait::new(FixedTrait::ONE / 2); assert_relative(atan_fast(a), 30386, 'invalid half', error); - let a = FixedTrait::ZERO(); + let a = FixedTrait::ZERO; assert(atan_fast(a).into() == 0, 'invalid zero'); - let a = FixedTrait::new(-ONE / 2); + let a = FixedTrait::new(-FixedTrait::ONE / 2); assert_relative(atan_fast(a), -30386, 'invalid neg half', error); - let a = FixedTrait::new(-ONE); + let a = FixedTrait::new(-FixedTrait::ONE); assert_relative(atan_fast(a), -51472, 'invalid neg one', error); - let a = FixedTrait::new(-2 * ONE); + let a = FixedTrait::new(-2 * FixedTrait::ONE); assert_relative(atan_fast(a), -72558, 'invalid neg two', error); } @@ -233,19 +233,19 @@ mod tests { fn test_asin() { let error = Option::Some(84); // 1e-5 - let a = FixedTrait::ONE(); + let a = FixedTrait::ONE; assert_relative(asin_fast(a), HALF_PI.into(), 'invalid one', Option::None(())); // PI / 2 - let a = FixedTrait::new(ONE / 2); + let a = FixedTrait::new(FixedTrait::ONE / 2); assert_relative(asin_fast(a), 34315, 'invalid half', error); - let a = FixedTrait::ZERO(); + let a = FixedTrait::ZERO; assert_precise(asin_fast(a), 0, 'invalid zero', Option::None(())); - let a = FixedTrait::new(-ONE / 2); + let a = FixedTrait::new(-FixedTrait::ONE / 2); assert_relative(asin_fast(a), -34315, 'invalid neg half', error); - let a = FixedTrait::new(-ONE); + let a = FixedTrait::new(-FixedTrait::ONE); assert_relative( asin_fast(a), -HALF_PI.into(), 'invalid neg one', Option::None(()) ); // -PI / 2 @@ -254,7 +254,7 @@ mod tests { #[test] #[should_panic] fn test_asin_fail() { - let a = FixedTrait::new(2 * ONE); + let a = FixedTrait::new(2 * FixedTrait::ONE); asin_fast(a); } @@ -269,7 +269,7 @@ mod tests { assert_precise(cos_fast(a), 46341, 'invalid quarter pi', error); // 0.55242717280199 let a = FixedTrait::new(PI); - assert_precise(cos_fast(a), -1 * ONE.into(), 'invalid pi', error); + assert_precise(cos_fast(a), -1 * F16x16Impl::ONE.into(), 'invalid pi', error); let a = FixedTrait::new(HALF_PI); assert_precise(cos_fast(a), 0, 'invalid neg half pi', Option::None(())); @@ -283,7 +283,7 @@ mod tests { let error = Option::Some(84); // 1e-5 let a = FixedTrait::new(HALF_PI); - assert_precise(sin_fast(a), ONE.into(), 'invalid half pi', error); + assert_precise(sin_fast(a), F16x16Impl::ONE.into(), 'invalid half pi', error); let a = FixedTrait::new(HALF_PI / 2); assert_precise(sin_fast(a), 46341, 'invalid quarter pi', error); // 0.55242717280199 @@ -292,7 +292,9 @@ mod tests { assert(sin_fast(a).into() == 0, 'invalid pi'); let a = FixedTrait::new(-HALF_PI); - assert_precise(sin_fast(a), -ONE.into(), 'invalid neg half pi', error); // 0.78124999999529 + assert_precise( + sin_fast(a), -F16x16Impl::ONE.into(), 'invalid neg half pi', error + ); // 0.78124999999529 let a = FixedTrait::new_unscaled(17); assert_precise(sin_fast(a), -63006, 'invalid 17', error); // -0.75109179053073 @@ -304,7 +306,7 @@ mod tests { #[test] fn test_tan_fast() { let a = FixedTrait::new(HALF_PI / 2); - assert_precise(tan_fast(a), ONE.into(), 'invalid quarter pi', Option::None(())); + assert_precise(tan_fast(a), F16x16Impl::ONE.into(), 'invalid quarter pi', Option::None(())); let a = FixedTrait::new(PI); assert_precise(tan_fast(a), 0, 'invalid pi', Option::None(())); @@ -323,9 +325,9 @@ mod tests { assert_precise(acosh(a), 131072, 'invalid two', Option::None(())); let a = FixedTrait::new(101127); // 1.42428174592510 - assert_precise(acosh(a), ONE.into(), 'invalid one', Option::None(())); + assert_precise(acosh(a), F16x16Impl::ONE.into(), 'invalid one', Option::None(())); - let a = FixedTrait::ONE(); // 1 + let a = FixedTrait::ONE; // 1 assert(acosh(a).into() == 0, 'invalid zero'); } @@ -335,13 +337,13 @@ mod tests { assert_precise(asinh(a), 131072, 'invalid two', Option::None(())); let a = FixedTrait::new(77018); // 1.13687593250230 - assert_precise(asinh(a), ONE.into(), 'invalid one', Option::None(())); + assert_precise(asinh(a), F16x16Impl::ONE.into(), 'invalid one', Option::None(())); - let a = FixedTrait::ZERO(); + let a = F16x16Impl::ZERO; assert(asinh(a).into() == 0, 'invalid zero'); let a = FixedTrait::new(-77018); // -1.13687593250230 - assert_precise(asinh(a), -ONE.into(), 'invalid neg one', Option::None(())); + assert_precise(asinh(a), -F16x16Impl::ONE.into(), 'invalid neg one', Option::None(())); let a = FixedTrait::new(-237690); // -3.48973469357602 assert_precise(asinh(a), -131017, 'invalid neg two', Option::None(())); @@ -352,13 +354,13 @@ mod tests { let a = FixedTrait::new(58982); // 0.9 assert_precise(atanh(a), 96483, 'invalid 0.9', Option::None(())); // 1.36892147623689 - let a = FixedTrait::new(HALF); // 0.5 + let a = FixedTrait::new(F16x16Impl::HALF); // 0.5 assert_precise(atanh(a), 35999, 'invalid half', Option::None(())); // 0.42914542526098 - let a = FixedTrait::ZERO(); + let a = FixedTrait::ZERO; assert(atanh(a).into() == 0, 'invalid zero'); - let a = FixedTrait::new(-HALF); // 0.5 + let a = FixedTrait::new(-F16x16Impl::HALF); // 0.5 assert_precise(atanh(a), -35999, 'invalid neg half', Option::None(())); // 0.42914542526098 let a = FixedTrait::new(-58982); // 0.9 @@ -367,55 +369,55 @@ mod tests { #[test] fn test_cosh() { - let a = FixedTrait::new(TWO); + let a = FixedTrait::new(FixedTrait::TWO); assert_precise(cosh(a), 246550, 'invalid two', Option::None(())); // 3.5954653836066 - let a = FixedTrait::ONE(); + let a = FixedTrait::ONE; assert_precise(cosh(a), 101127, 'invalid one', Option::None(())); // 1.42428174592510 - let a = FixedTrait::ZERO(); - assert_precise(cosh(a), ONE.into(), 'invalid zero', Option::None(())); + let a = FixedTrait::ZERO; + assert_precise(cosh(a), F16x16Impl::ONE.into(), 'invalid zero', Option::None(())); - let a = -FixedTrait::ONE(); + let a = -FixedTrait::ONE; assert_precise(cosh(a), 101127, 'invalid neg one', Option::None(())); // 1.42428174592510 - let a = FixedTrait::new(-TWO); + let a = FixedTrait::new(-FixedTrait::TWO); assert_precise(cosh(a), 246568, 'invalid neg two', Option::None(())); // 3.5954653836066 } #[test] fn test_sinh() { - let a = FixedTrait::new(TWO); + let a = FixedTrait::new(FixedTrait::TWO); assert_precise(sinh(a), 237681, 'invalid two', Option::None(())); // 3.48973469357602 - let a = FixedTrait::ONE(); + let a = FixedTrait::ONE; assert_precise(sinh(a), 77018, 'invalid one', Option::None(())); // 1.13687593250230 - let a = FixedTrait::ZERO(); + let a = FixedTrait::ZERO; assert(sinh(a).into() == 0, 'invalid zero'); - let a = FixedTrait::new(-ONE); + let a = FixedTrait::new(-F16x16Impl::ONE); assert_precise(sinh(a), -77018, 'invalid neg one', Option::None(())); // -1.13687593250230 - let a = FixedTrait::new(-TWO); + let a = FixedTrait::new(-FixedTrait::TWO); assert_precise(sinh(a), -237699, 'invalid neg two', Option::None(())); // -3.48973469357602 } #[test] fn test_tanh() { - let a = FixedTrait::new(TWO); + let a = FixedTrait::new(FixedTrait::TWO); assert_precise(tanh(a), 63179, 'invalid two', Option::None(())); // 0.75314654693321 - let a = FixedTrait::ONE(); + let a = FixedTrait::ONE; assert_precise(tanh(a), 49912, 'invalid one', Option::None(())); // 0.59499543433175 - let a = FixedTrait::ZERO(); + let a = FixedTrait::ZERO; assert(tanh(a).into() == 0, 'invalid zero'); - let a = FixedTrait::new(-ONE); + let a = FixedTrait::new(-F16x16Impl::ONE); assert_precise(tanh(a), -49912, 'invalid neg one', Option::None(())); // -0.59499543433175 - let a = FixedTrait::new(-TWO); + let a = FixedTrait::new(-FixedTrait::TWO); assert_precise(tanh(a), -63179, 'invalid neg two', Option::None(())); // 0.75314654693321 } } diff --git a/packages/orion-numbers/src/f32x32/core.cairo b/packages/orion-numbers/src/f32x32/core.cairo index 70b2e63d6..8fbf8a934 100644 --- a/packages/orion-numbers/src/f32x32/core.cairo +++ b/packages/orion-numbers/src/f32x32/core.cairo @@ -3,37 +3,18 @@ use orion_numbers::FixedTrait; pub type f32x32 = i64; -// CONSTANTS -pub const TWO: f32x32 = 8589934592; // 2 ** 33 -pub const ONE: f32x32 = 4294967296; // 2 ** 32 -pub const HALF: f32x32 = 2147483648; // 2 ** 31 -pub const MAX: f32x32 = 9223372036854775807; // 2 ** 63 -1 -pub const MIN: f32x32 = -9223372036854775808; // -2 ** 63 - pub impl F32x32Impl of FixedTrait { - fn ZERO() -> f32x32 { - 0 - } - - fn HALF() -> f32x32 { - HALF - } - - fn ONE() -> f32x32 { - ONE - } - - fn MAX() -> f32x32 { - MAX - } - - fn MIN() -> f32x32 { - MIN - } + // CONSTANTS + const ZERO: f32x32 = 0; + const HALF: f32x32 = 2147483648; // 2 ** 31 + const ONE: f32x32 = 4294967296; // 2 ** 32 + const TWO: f32x32 = 8589934592; // 2 ** 33 + const MAX: f32x32 = 9223372036854775807; // 2 ** 63 -1 + const MIN: f32x32 = -9223372036854775808; // -2 ** 63 fn new_unscaled(x: i64) -> f32x32 { - x * ONE + x * Self::ONE } fn new(x: i64) -> f32x32 { @@ -45,7 +26,7 @@ pub impl F32x32Impl of FixedTrait { } fn from_unscaled_felt(x: felt252) -> f32x32 { - return FixedTrait::from_felt(x * ONE.into()); + return FixedTrait::from_felt(x * Self::ONE.into()); } fn abs(self: f32x32) -> f32x32 { @@ -170,32 +151,30 @@ pub impl F32x32Impl of FixedTrait { } fn INF() -> f32x32 { - MAX + Self::MAX } fn POS_INF() -> f32x32 { - MAX + Self::MAX } fn NEG_INF() -> f32x32 { - MIN + Self::MIN } fn is_inf(self: f32x32) -> bool { - self == MAX + self == Self::MAX } fn is_pos_inf(self: f32x32) -> bool { - self == MAX + self == Self::MAX } fn is_neg_inf(self: f32x32) -> bool { - self == MIN + self == Self::MIN } fn erf(self: f32x32) -> f32x32 { panic!("not implem yet") } - - } diff --git a/packages/orion-numbers/src/f32x32/helpers.cairo b/packages/orion-numbers/src/f32x32/helpers.cairo index 4ba229ef4..0e8ba87ba 100644 --- a/packages/orion-numbers/src/f32x32/helpers.cairo +++ b/packages/orion-numbers/src/f32x32/helpers.cairo @@ -1,4 +1,4 @@ -use orion_numbers::f32x32::core::{F32x32Impl, f32x32, ONE, HALF}; +use orion_numbers::f32x32::core::{F32x32Impl, f32x32}; const DEFAULT_PRECISION: i64 = 429497; // 1e-4 @@ -20,6 +20,29 @@ pub fn assert_precise( } } +pub fn assert_precise_span( + results: Span, expected: Span, msg: felt252, custom_precision: Option +) { + assert(results.len() == expected.len(), 'Arrays must have same length'); + + println!("results: {:?}", results); + println!("expected: {:?}", expected); + + let mut i: usize = 0; + loop { + if i == results.len() { + break; + } + + let result = *results.at(i); + let expected_val = *expected.at(i); + + assert_precise(result, expected_val, msg, custom_precision); + + i += 1; + } +} + pub fn assert_relative( result: f32x32, expected: felt252, msg: felt252, custom_precision: Option ) { @@ -37,3 +60,26 @@ pub fn assert_relative( } } + +pub fn assert_relative_span( + results: Span, expected: Span, msg: felt252, custom_precision: Option +) { + assert(results.len() == expected.len(), 'Arrays must have same length'); + + println!("results: {:?}", results); + println!("expected: {:?}", expected); + + let mut i: usize = 0; + loop { + if i == results.len() { + break; + } + + let result = *results.at(i); + let expected_val = *expected.at(i); + + assert_relative(result, expected_val, msg, custom_precision); + + i += 1; + } +} diff --git a/packages/orion-numbers/src/f32x32/math.cairo b/packages/orion-numbers/src/f32x32/math.cairo index 2afb1d4df..3da976fd5 100644 --- a/packages/orion-numbers/src/f32x32/math.cairo +++ b/packages/orion-numbers/src/f32x32/math.cairo @@ -1,6 +1,6 @@ use core::integer; use core::num::traits::{WideMul, Sqrt}; -use orion_numbers::f32x32::core::{F32x32Impl, f32x32, ONE, HALF}; +use orion_numbers::f32x32::core::{F32x32Impl, f32x32}; @@ -13,7 +13,7 @@ pub fn abs(a: f32x32) -> f32x32 { } pub fn div(a: f32x32, b: f32x32) -> f32x32 { - let a_i128 = WideMul::wide_mul(a, ONE); + let a_i128 = WideMul::wide_mul(a, F32x32Impl::ONE); let res_i128 = a_i128 / b.into(); // Re-apply sign @@ -24,15 +24,15 @@ pub fn mul(a: f32x32, b: f32x32) -> f32x32 { let prod_i128 = WideMul::wide_mul(a, b); // Re-apply sign - F32x32Impl::new((prod_i128 / ONE.into()).try_into().unwrap()) + F32x32Impl::new((prod_i128 / F32x32Impl::ONE.into()).try_into().unwrap()) } pub fn round(a: f32x32) -> f32x32 { - //let (div, rem) = DivRem::div_rem(a, ONE.try_into().unwrap()); - let div = Div::div(a, ONE); - let rem = Rem::rem(a, ONE); + //let (div, rem) = DivRem::div_rem(a, F32x32Impl::ONE.try_into().unwrap()); + let div = Div::div(a, F32x32Impl::ONE); + let rem = Rem::rem(a, F32x32Impl::ONE); - if (HALF <= rem) { + if (F32x32Impl::HALF <= rem) { F32x32Impl::new_unscaled(div + 1) } else { F32x32Impl::new_unscaled(div) @@ -43,9 +43,9 @@ pub fn sign(a: f32x32) -> f32x32 { if a == 0 { F32x32Impl::new(0) } else if a > 0 { - ONE + F32x32Impl::ONE } else { - -ONE + -F32x32Impl::ONE } } @@ -58,7 +58,7 @@ pub fn sign(a: f32x32) -> f32x32 { mod tests { use orion_numbers::f32x32::helpers::{assert_precise, assert_relative}; - use super::{F32x32Impl, ONE, HALF, f32x32, integer, round, sign}; + use super::{F32x32Impl, f32x32, integer, round, sign}; #[test] @@ -66,16 +66,16 @@ mod tests { let a = F32x32Impl::new(0); assert(a.sign() == 0, 'invalid sign (0)'); - let a = F32x32Impl::new(-HALF); - assert(a.sign() == -ONE, 'invalid sign (-HALF)'); + let a = F32x32Impl::new(-F32x32Impl::HALF); + assert(a.sign() == -F32x32Impl::ONE, 'invalid sign (-HALF)'); - let a = F32x32Impl::new(HALF); - assert(a.sign() == ONE, 'invalid sign (HALF)'); + let a = F32x32Impl::new(F32x32Impl::HALF); + assert(a.sign() == F32x32Impl::ONE, 'invalid sign (HALF)'); - let a = F32x32Impl::new(-ONE); - assert(a.sign() == -ONE, 'invalid sign (-ONE)'); + let a = F32x32Impl::new(-F32x32Impl::ONE); + assert(a.sign() == -F32x32Impl::ONE, 'invalid sign (-ONE)'); - let a = F32x32Impl::new(ONE); - assert(a.sign() == ONE, 'invalid sign (ONE)'); + let a = F32x32Impl::new(F32x32Impl::ONE); + assert(a.sign() == F32x32Impl::ONE, 'invalid sign (ONE)'); } } diff --git a/packages/orion-numbers/src/lib.cairo b/packages/orion-numbers/src/lib.cairo index 38a303fd3..36aed1679 100644 --- a/packages/orion-numbers/src/lib.cairo +++ b/packages/orion-numbers/src/lib.cairo @@ -2,16 +2,17 @@ pub mod f16x16; pub mod f32x32; pub mod core_trait; -use orion_numbers::f16x16::core::F16x16Impl; -use orion_numbers::f32x32::core::F32x32Impl; - +pub use orion_numbers::f16x16::core::F16x16Impl; +pub use orion_numbers::f32x32::core::F32x32Impl; pub trait FixedTrait { - fn ZERO() -> T; - fn HALF() -> T; - fn ONE() -> T; - fn MAX() -> T; - fn MIN() -> T; + const ZERO: T; + const HALF: T; + const ONE: T; + const TWO: T; + const MAX: T; + const MIN: T; + fn new_unscaled(x: T) -> T; fn new(x: T) -> T; fn from_felt(x: felt252) -> T;