From c3c15817075f622aab758278bd8492ab7b013594 Mon Sep 17 00:00:00 2001 From: Thomas Marchand Date: Thu, 17 Oct 2024 18:36:01 +0100 Subject: [PATCH] feat: Lookup tables and type-specific optimizations (#269) --- .../consensus/src/validation/difficulty.cairo | 23 ++- packages/utils/src/bit_shifts.cairo | 154 +++++++++++++----- packages/utils/src/numeric.cairo | 14 +- 3 files changed, 132 insertions(+), 59 deletions(-) diff --git a/packages/consensus/src/validation/difficulty.cairo b/packages/consensus/src/validation/difficulty.cairo index 69fe952c..2beae115 100644 --- a/packages/consensus/src/validation/difficulty.cairo +++ b/packages/consensus/src/validation/difficulty.cairo @@ -4,7 +4,7 @@ //! - https://learnmeabitcoin.com/technical/mining/target/ //! - https://learnmeabitcoin.com/technical/block/bits/ -use utils::{bit_shifts::{shl, shr, fast_pow}}; +use utils::{bit_shifts::{shl, shr_u64, fast_pow}}; /// Maximum difficulty target allowed const MAX_TARGET: u256 = 0x00000000FFFF0000000000000000000000000000000000000000000000000000; @@ -101,28 +101,25 @@ fn bits_to_target(bits: u32) -> Result { } // Calculate the full target value - let mut target: u256 = mantissa.into(); - if exponent == 0 { // Special case: exponent 0 means we use the mantissa as-is - return Result::Ok(target); + return Result::Ok(mantissa.into()); } else if exponent <= 3 { // For exponents 1, 2, and 3, divide by 256^(3 - exponent) i.e right shift let shift = 8 * (3 - exponent); - target = shr(target, shift); + // MAX_TARGET > 2^128 so we can return early + return Result::Ok(shr_u64(mantissa.into(), shift).into()); } else if exponent <= 32 { let shift = 8 * (exponent - 3); - target = shl(target, shift); + let target = shl(mantissa.into(), shift); + // Ensure the target doesn't exceed the maximum allowed value + if target > MAX_TARGET { + return Result::Err("Target exceeds maximum value"); + } + Result::Ok(target) } else { return Result::Err("Target size cannot exceed 32 bytes"); } - - // Ensure the target doesn't exceed the maximum allowed value - if target > MAX_TARGET { - return Result::Err("Target exceeds maximum value"); - } - - Result::Ok(target) } #[cfg(test)] diff --git a/packages/utils/src/bit_shifts.cairo b/packages/utils/src/bit_shifts.cairo index b9b934c2..2fb3482d 100644 --- a/packages/utils/src/bit_shifts.cairo +++ b/packages/utils/src/bit_shifts.cairo @@ -34,38 +34,21 @@ pub fn shl< self * fast_pow(two, shift) } -/// Performs a bitwise right shift on the given value by a specified number of bits. -pub fn shr< - T, - U, - +Zero, - +Zero, - +One, - +One, - +Add, - +Add, - +Sub, - +Div, - +Mul, - +Div, - +Rem, - +Copy, - +Copy, - +Drop, - +Drop, - +PartialOrd, - +PartialEq, - +BitSize, - +Into ->( - self: T, shift: U -) -> T { - if shift > BitSize::::bits().try_into().unwrap() - One::one() { - return Zero::zero(); - } - - let two = One::one() + One::one(); - self / fast_pow(two, shift) +/// Performs a bitwise right shift on a u64 value by a specified number of bits. +/// This specialized version offers optimal performance for u64 types. +/// +/// # Arguments +/// * `self` - The u64 value to be shifted +/// * `shift` - The number of bits to shift right +/// +/// # Returns +/// * The result of the right shift operation +/// +/// # Panics +/// * If `shift` is greater than 63 (via pow2's range check on the lookup table) +#[inline(always)] +pub fn shr_u64(self: u64, shift: u32) -> u64 { + self / pow2(shift) } @@ -113,9 +96,89 @@ pub fn fast_pow< } } + +/// Fast power of 2 using lookup tables +/// Reference: https://github.com/keep-starknet-strange/alexandria/pull/336 +/// +/// # Arguments +/// * `exponent` - The exponent to raise 2 to +/// # Returns +/// * `u64` - The result of 2^exponent +/// # Panics +/// * If `exponent` is greater than 63 (out of the supported range) +pub fn pow2(exponent: u32) -> u64 { + let hardcoded_results: [u64; 64] = [ + 0x1, + 0x2, + 0x4, + 0x8, + 0x10, + 0x20, + 0x40, + 0x80, + 0x100, + 0x200, + 0x400, + 0x800, + 0x1000, + 0x2000, + 0x4000, + 0x8000, + 0x10000, + 0x20000, + 0x40000, + 0x80000, + 0x100000, + 0x200000, + 0x400000, + 0x800000, + 0x1000000, + 0x2000000, + 0x4000000, + 0x8000000, + 0x10000000, + 0x20000000, + 0x40000000, + 0x80000000, + 0x100000000, + 0x200000000, + 0x400000000, + 0x800000000, + 0x1000000000, + 0x2000000000, + 0x4000000000, + 0x8000000000, + 0x10000000000, + 0x20000000000, + 0x40000000000, + 0x80000000000, + 0x100000000000, + 0x200000000000, + 0x400000000000, + 0x800000000000, + 0x1000000000000, + 0x2000000000000, + 0x4000000000000, + 0x8000000000000, + 0x10000000000000, + 0x20000000000000, + 0x40000000000000, + 0x80000000000000, + 0x100000000000000, + 0x200000000000000, + 0x400000000000000, + 0x800000000000000, + 0x1000000000000000, + 0x2000000000000000, + 0x4000000000000000, + 0x8000000000000000, + ]; + *hardcoded_results.span()[exponent] +} + #[cfg(test)] mod tests { - use super::{fast_pow, shl, shr}; + use super::{fast_pow, pow2, shl, shr_u64}; #[test] #[available_gas(1000000000)] @@ -128,6 +191,18 @@ mod tests { assert_eq!(fast_pow(10_u128, 5_u128), 100000, "invalid result"); } + #[test] + #[available_gas(1000000000)] + fn test_pow2() { + assert_eq!(pow2(0), 1, "2^0 should be 1"); + assert_eq!(pow2(1), 2, "2^1 should be 2"); + assert_eq!(pow2(2), 4, "2^2 should be 4"); + assert_eq!(pow2(3), 8, "2^3 should be 8"); + assert_eq!(pow2(10), 1024, "2^10 should be 1024"); + assert_eq!(pow2(63), 0x8000000000000000, "2^63 should be 0x8000000000000000"); + assert_eq!(pow2(63), 0x8000000000000000, "2^64 should be 0x8000000000000000"); + } + #[test] fn test_shl() { let value1: u32 = 3; @@ -142,19 +217,20 @@ mod tests { } #[test] - fn test_shr() { - // Assuming T and U are u32 for simplicity - let x: u32 = 32; + fn test_shr_u64() { + // Expect about 15% steps reduction over previous test, + // should be much higher for bigger shifts + let x: u64 = 32; let shift: u32 = 2; - let result = shr(x, shift); + let result = shr_u64(x, shift); assert_eq!(result, 8); let shift: u32 = 32; - let result = shr(x, shift); + let result = shr_u64(x, shift); assert_eq!(result, 0); let shift: u32 = 0; - let result = shr(x, shift); + let result = shr_u64(x, shift); assert_eq!(result, 32); } } diff --git a/packages/utils/src/numeric.cairo b/packages/utils/src/numeric.cairo index 91f7d970..7e605399 100644 --- a/packages/utils/src/numeric.cairo +++ b/packages/utils/src/numeric.cairo @@ -1,4 +1,4 @@ -use crate::bit_shifts::shr; +use crate::bit_shifts::shr_u64; /// Reverses the byte order of a `u32`. /// @@ -19,12 +19,12 @@ pub fn u64_next_power_of_two(mut n: u64) -> u64 { } n -= 1; - n = n | shr(n, 1_u64); - n = n | shr(n, 2_u64); - n = n | shr(n, 4_u64); - n = n | shr(n, 8_u64); - n = n | shr(n, 16_u64); - n = n | shr(n, 32_u64); + n = n | shr_u64(n, 1_u32); + n = n | shr_u64(n, 2_u32); + n = n | shr_u64(n, 4_u32); + n = n | shr_u64(n, 8_u32); + n = n | shr_u64(n, 16_u32); + n = n | shr_u64(n, 32_u32); n + 1 }