diff --git a/packages/orion-numbers/.tool-versions b/packages/orion-numbers/.tool-versions index cfbafa68b..8f7adfd7b 100644 --- a/packages/orion-numbers/.tool-versions +++ b/packages/orion-numbers/.tool-versions @@ -1 +1 @@ -scarb nightly-2024-06-10 \ No newline at end of file +scarb 2.6.5 \ No newline at end of file diff --git a/packages/orion-numbers/src/f16x16.cairo b/packages/orion-numbers/src/f16x16.cairo index 17935277e..84b9046f0 100644 --- a/packages/orion-numbers/src/f16x16.cairo +++ b/packages/orion-numbers/src/f16x16.cairo @@ -3,4 +3,5 @@ pub mod math; pub mod trig; pub mod erf; pub mod helpers; -pub mod lut; \ No newline at end of file +pub mod lut; +pub mod core_trait; \ No newline at end of file diff --git a/packages/orion-numbers/src/f16x16/core.cairo b/packages/orion-numbers/src/f16x16/core.cairo index 2dd3027b2..5f7e03278 100644 --- a/packages/orion-numbers/src/f16x16/core.cairo +++ b/packages/orion-numbers/src/f16x16/core.cairo @@ -47,7 +47,7 @@ pub impl f16x16Impl of FixedTrait { } fn from_unscaled_felt(x: felt252) -> f16x16 { - return Self::from_felt(x * ONE.into()); + return FixedTrait::from_felt(x * ONE.into()); } fn abs(self: f16x16) -> f16x16 { diff --git a/packages/orion-numbers/src/f16x16/core_trait.cairo b/packages/orion-numbers/src/f16x16/core_trait.cairo new file mode 100644 index 000000000..06a47fee3 --- /dev/null +++ b/packages/orion-numbers/src/f16x16/core_trait.cairo @@ -0,0 +1,94 @@ +pub impl I32Div of Div { + fn div(lhs: i32, rhs: i32) -> i32 { + assert(rhs != 0, 'divisor cannot be 0'); + + let mut lhs_positive = lhs; + let mut rhs_positive = rhs; + + if lhs < 0 { + lhs_positive = lhs * -1; + } + if rhs < 0 { + rhs_positive = rhs * -1; + } + + let lhs_u32: u32 = lhs_positive.try_into().unwrap(); + let rhs_u32: u32 = rhs_positive.try_into().unwrap(); + + let mut result = lhs_u32 / rhs_u32; + let felt_result: felt252 = result.into(); + let signed_int_result: i32 = felt_result.try_into().unwrap(); + + // avoids mul overflow for f16x16 + if sign_i32(lhs) * rhs < 0 { + signed_int_result * -1 + } else { + signed_int_result + } + } +} + +pub impl I32Rem of Rem { + fn rem(lhs: i32, rhs: i32) -> i32 { + let div = Div::div(lhs, rhs); + lhs - rhs * div + } +} + + +pub impl I64Div of Div { + fn div(lhs: i64, rhs: i64) -> i64 { + assert(rhs != 0, 'divisor cannot be 0'); + + let mut lhs_positive = lhs; + let mut rhs_positive = rhs; + + if lhs < 0 { + lhs_positive = lhs * -1; + } + if rhs < 0 { + rhs_positive = rhs * -1; + } + + let lhs_u64: u64 = lhs_positive.try_into().unwrap(); + let rhs_u64: u64 = rhs_positive.try_into().unwrap(); + + let mut result = lhs_u64 / rhs_u64; + let felt_result: felt252 = result.into(); + let signed_int_result: i64 = felt_result.try_into().unwrap(); + + // avoids mul overflow for f16x16 + if sign_i64(lhs) * rhs < 0 { + signed_int_result * -1 + } else { + signed_int_result + } + } +} + +pub impl I64Rem of Rem { + fn rem(lhs: i64, rhs: i64) -> i64 { + let div = Div::div(lhs, rhs); + lhs - rhs * div + } +} + +pub fn sign_i32(a: i32) -> i32 { + if a == 0 { + 0 + } else if a > 0 { + 1 + } else { + -1 + } +} + +pub fn sign_i64(a: i64) -> i64 { + if a == 0 { + 0 + } else if a > 0 { + 1 + } else { + -1 + } +} diff --git a/packages/orion-numbers/src/f16x16/helpers.cairo b/packages/orion-numbers/src/f16x16/helpers.cairo index b4d6e13ec..f2e26fd5e 100644 --- a/packages/orion-numbers/src/f16x16/helpers.cairo +++ b/packages/orion-numbers/src/f16x16/helpers.cairo @@ -1,5 +1,7 @@ use orion_numbers::f16x16::core::{FixedTrait, f16x16, ONE, HALF}; +use orion_numbers::f16x16::core_trait::I32Div; + const DEFAULT_PRECISION: i32 = 7; // 1e-4 // To use `DEFAULT_PRECISION`, final arg is: `Option::None(())`. @@ -15,7 +17,7 @@ pub fn assert_precise( let diff = (result - FixedTrait::from_felt(expected)); if (diff > precision) { - println!("{}", result); + //println!("{}", result); assert(diff <= precision, msg); } } @@ -32,7 +34,7 @@ pub fn assert_relative( let rel_diff = diff / result; if (rel_diff > precision) { - println!("{}", result); + //println!("{}", result); assert(rel_diff <= precision, msg); } } diff --git a/packages/orion-numbers/src/f16x16/lut.cairo b/packages/orion-numbers/src/f16x16/lut.cairo index 73cbe732d..a834d4b5a 100644 --- a/packages/orion-numbers/src/f16x16/lut.cairo +++ b/packages/orion-numbers/src/f16x16/lut.cairo @@ -1,5 +1,7 @@ use orion_numbers::f16x16::core::ONE; +use orion_numbers::f16x16::core_trait::I32Div; + // Calculates the most significant bit pub fn msb(whole: i32) -> (i32, i32) { if whole < 256 { diff --git a/packages/orion-numbers/src/f16x16/math.cairo b/packages/orion-numbers/src/f16x16/math.cairo index 4861698ae..bc885f298 100644 --- a/packages/orion-numbers/src/f16x16/math.cairo +++ b/packages/orion-numbers/src/f16x16/math.cairo @@ -5,6 +5,8 @@ use core::integer; use orion_numbers::f16x16::{core::{FixedTrait, f16x16, ONE, HALF}, lut}; +use orion_numbers::f16x16::core_trait::{I32Rem, I32Div, I64Div}; //I32TryIntoNonZero, I32DivRem + pub fn abs(a: f16x16) -> f16x16 { if a >= 0 { @@ -23,7 +25,9 @@ pub fn sub(a: f16x16, b: f16x16) -> f16x16 { } pub fn ceil(a: f16x16) -> f16x16 { - let (div, rem) = DivRem::div_rem(a, ONE.try_into().unwrap()); + //let (div, rem) = DivRem::div_rem(a, ONE.try_into().unwrap()); + let div = Div::div(a, ONE); + let rem = Rem::rem(a, ONE); if rem == 0 { FixedTrait::new_unscaled(div) @@ -51,7 +55,10 @@ pub fn exp2(a: f16x16) -> f16x16 { return FixedTrait::ONE(); } - let (int_part, frac_part) = DivRem::div_rem(a.abs(), ONE.try_into().unwrap()); + //let (int_part, frac_part) = DivRem::div_rem(a.abs(), ONE.try_into().unwrap()); + let int_part = Div::div(a.abs(), ONE); + let frac_part = Rem::rem(a.abs(), ONE); + let int_res = FixedTrait::new_unscaled(lut::exp2(int_part)); let mut res_u = int_res; @@ -79,7 +86,9 @@ fn exp2_int(exp: i32) -> f16x16 { } pub fn floor(a: f16x16) -> f16x16 { - let (div, rem) = DivRem::div_rem(a, ONE.try_into().unwrap()); + //let (div, rem) = DivRem::div_rem(a, ONE.try_into().unwrap()); + let div = Div::div(a, ONE); + let rem = Rem::rem(a, ONE); if rem == 0 { a @@ -146,7 +155,7 @@ pub fn mul(a: f16x16, b: f16x16) -> f16x16 { // self is a FP16x16 point value // b is a FP16x16 point value pub fn pow(a: f16x16, b: f16x16) -> f16x16 { - let (_, rem) = DivRem::div_rem(b, ONE.try_into().unwrap()); + let rem = Rem::rem(b, ONE); // use the more performant integer pow when y is an int if (rem == 0) { @@ -174,7 +183,9 @@ fn pow_int(a: f16x16, b: i32) -> f16x16 { let two: i32 = 2; while n > 1 { - let (div, rem) = DivRem::div_rem(n, two.try_into().unwrap()); + //let (div, rem) = DivRem::div_rem(n, two.try_into().unwrap()); + let div = Div::div(n, two); + let rem = Rem::rem(n, two); if rem == 1 { y = FixedTrait::mul(x, y); @@ -188,7 +199,9 @@ fn pow_int(a: f16x16, b: i32) -> f16x16 { } pub fn round(a: f16x16) -> f16x16 { - let (div, rem) = DivRem::div_rem(a, ONE.try_into().unwrap()); + //let (div, rem) = DivRem::div_rem(a, ONE.try_into().unwrap()); + let div = Div::div(a, ONE); + let rem = Rem::rem(a, ONE); if (HALF <= rem) { FixedTrait::new_unscaled(div + 1) @@ -233,6 +246,8 @@ mod tests { ln, log2, log10, pow, round, sign }; + use orion_numbers::f16x16::core_trait::{I32Rem, I32Div}; + #[test] fn test_into() { let a = FixedTrait::new_unscaled(5); @@ -246,7 +261,7 @@ mod tests { } #[test] - #[available_gas(1000000)] + #[available_gas(10000000)] fn test_exp() { let a = FixedTrait::new_unscaled(2); assert_relative(exp(a), 484249, 'invalid exp of 2', Option::None(())); // 7.389056098793725 diff --git a/packages/orion-numbers/src/f16x16/trig.cairo b/packages/orion-numbers/src/f16x16/trig.cairo index c0f8a53d4..551218ff7 100644 --- a/packages/orion-numbers/src/f16x16/trig.cairo +++ b/packages/orion-numbers/src/f16x16/trig.cairo @@ -4,6 +4,8 @@ use core::integer; use orion_numbers::f16x16::core::{FixedTrait, f16x16, ONE, HALF, TWO}; use orion_numbers::f16x16::lut; +use orion_numbers::f16x16::core_trait::{I32Div, I32Rem}; + // CONSTANTS const TWO_PI: i32 = 411775; const PI: i32 = 205887; @@ -86,7 +88,10 @@ pub fn cos_fast(a: f16x16) -> f16x16 { pub fn sin_fast(a: f16x16) -> f16x16 { let a1 = a.abs() % TWO_PI; - let (whole_rem, mut partial_rem) = DivRem::div_rem(a1, PI.try_into().unwrap()); + //let (whole_rem, mut partial_rem) = DivRem::div_rem(a1, PI.try_into().unwrap()); + let whole_rem = Div::div(a1, PI); + let mut partial_rem = Rem::rem(a1, PI); + let partial_sign = whole_rem == 1; if partial_rem >= HALF_PI { @@ -175,6 +180,7 @@ mod tests { sin_fast, tan_fast, acosh, asinh, atanh, cosh, sinh, tanh }; + use orion_numbers::f16x16::core_trait::I32Div; #[test] #[available_gas(8000000)]