feat: Lookup tables and type-specific optimizations (#269)
Th0rgal authored Oct 17, 2024
1 parent 9ae2c17 commit c3c1581
Showing 3 changed files with 132 additions and 59 deletions.
23 changes: 10 additions & 13 deletions packages/consensus/src/validation/difficulty.cairo
@@ -4,7 +4,7 @@
//! - https://learnmeabitcoin.com/technical/mining/target/
//! - https://learnmeabitcoin.com/technical/block/bits/

use utils::{bit_shifts::{shl, shr, fast_pow}};
use utils::{bit_shifts::{shl, shr_u64, fast_pow}};

/// Maximum difficulty target allowed
const MAX_TARGET: u256 = 0x00000000FFFF0000000000000000000000000000000000000000000000000000;
@@ -101,28 +101,25 @@ fn bits_to_target(bits: u32) -> Result<u256, ByteArray> {
}

// Calculate the full target value
let mut target: u256 = mantissa.into();

if exponent == 0 {
// Special case: exponent 0 means we use the mantissa as-is
return Result::Ok(target);
return Result::Ok(mantissa.into());
} else if exponent <= 3 {
// For exponents 1, 2, and 3, divide by 256^(3 - exponent), i.e. right shift
let shift = 8 * (3 - exponent);
target = shr(target, shift);
// MAX_TARGET > 2^128 so we can return early
return Result::Ok(shr_u64(mantissa.into(), shift).into());
} else if exponent <= 32 {
let shift = 8 * (exponent - 3);
target = shl(target, shift);
let target = shl(mantissa.into(), shift);
// Ensure the target doesn't exceed the maximum allowed value
if target > MAX_TARGET {
return Result::Err("Target exceeds maximum value");
}
Result::Ok(target)
} else {
return Result::Err("Target size cannot exceed 32 bytes");
}

// Ensure the target doesn't exceed the maximum allowed value
if target > MAX_TARGET {
return Result::Err("Target exceeds maximum value");
}

Result::Ok(target)
}

#[cfg(test)]
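For reference, the compact "bits" encoding stores a one-byte exponent and a three-byte mantissa, and the decoded target is mantissa * 256^(exponent - 3). A small worked example in the style of the existing tests (hypothetical, not part of this commit; it assumes it sits in the test module of difficulty.cairo with bits_to_target and MAX_TARGET in scope):

#[test]
fn test_bits_to_target_worked_examples() {
    // exponent = 0x1d (29), mantissa = 0x00ffff:
    // target = 0xffff * 256^26, which is exactly MAX_TARGET (the genesis-block target)
    assert_eq!(bits_to_target(0x1d00ffff).unwrap(), MAX_TARGET, "genesis target");
    // exponent = 0x02, mantissa = 0x123456: right shift by 8 * (3 - 2) = 8 bits
    assert_eq!(bits_to_target(0x02123456).unwrap(), 0x1234_u256, "low exponent");
}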
154 changes: 115 additions & 39 deletions packages/utils/src/bit_shifts.cairo
@@ -34,38 +34,21 @@ pub fn shl<
self * fast_pow(two, shift)
}

/// Performs a bitwise right shift on the given value by a specified number of bits.
pub fn shr<
T,
U,
+Zero<T>,
+Zero<U>,
+One<T>,
+One<U>,
+Add<T>,
+Add<U>,
+Sub<U>,
+Div<T>,
+Mul<T>,
+Div<U>,
+Rem<U>,
+Copy<T>,
+Copy<U>,
+Drop<T>,
+Drop<U>,
+PartialOrd<U>,
+PartialEq<U>,
+BitSize<T>,
+Into<usize, U>
>(
self: T, shift: U
) -> T {
if shift > BitSize::<T>::bits().try_into().unwrap() - One::one() {
return Zero::zero();
}

let two = One::one() + One::one();
self / fast_pow(two, shift)
/// Performs a bitwise right shift on a u64 value by a specified number of bits.
/// This specialized version offers optimal performance for u64 types.
///
/// # Arguments
/// * `self` - The u64 value to be shifted
/// * `shift` - The number of bits to shift right
///
/// # Returns
/// * The result of the right shift operation
///
/// # Panics
/// * If `shift` is greater than 63 (via pow2's range check on the lookup table)
#[inline(always)]
pub fn shr_u64(self: u64, shift: u32) -> u64 {
self / pow2(shift)
}
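A couple of quick sanity checks on the semantics described above (illustrative only, in the style of this package's tests, not part of this commit):

#[test]
fn test_shr_u64_sanity() {
    // dividing by 2^8 drops the low byte
    assert_eq!(shr_u64(0xFF00, 8), 0xFF);
    // shifting by 63 pushes the only set bit out
    assert_eq!(shr_u64(1, 63), 0);
}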


@@ -113,9 +96,89 @@ pub fn fast_pow<
}
}


/// Fast power of 2 using lookup tables
/// Reference: https://github.com/keep-starknet-strange/alexandria/pull/336
///
/// # Arguments
/// * `exponent` - The exponent to raise 2 to
/// # Returns
/// * `u64` - The result of 2^exponent
/// # Panics
/// * If `exponent` is greater than 63 (out of the supported range)
pub fn pow2(exponent: u32) -> u64 {
let hardcoded_results: [u64; 64] = [
0x1,
0x2,
0x4,
0x8,
0x10,
0x20,
0x40,
0x80,
0x100,
0x200,
0x400,
0x800,
0x1000,
0x2000,
0x4000,
0x8000,
0x10000,
0x20000,
0x40000,
0x80000,
0x100000,
0x200000,
0x400000,
0x800000,
0x1000000,
0x2000000,
0x4000000,
0x8000000,
0x10000000,
0x20000000,
0x40000000,
0x80000000,
0x100000000,
0x200000000,
0x400000000,
0x800000000,
0x1000000000,
0x2000000000,
0x4000000000,
0x8000000000,
0x10000000000,
0x20000000000,
0x40000000000,
0x80000000000,
0x100000000000,
0x200000000000,
0x400000000000,
0x800000000000,
0x1000000000000,
0x2000000000000,
0x4000000000000,
0x8000000000000,
0x10000000000000,
0x20000000000000,
0x40000000000000,
0x80000000000000,
0x100000000000000,
0x200000000000000,
0x400000000000000,
0x800000000000000,
0x1000000000000000,
0x2000000000000000,
0x4000000000000000,
0x8000000000000000,
];
*hardcoded_results.span()[exponent]
}

#[cfg(test)]
mod tests {
use super::{fast_pow, shl, shr};
use super::{fast_pow, pow2, shl, shr_u64};

#[test]
#[available_gas(1000000000)]
@@ -128,6 +191,18 @@ mod tests {
assert_eq!(fast_pow(10_u128, 5_u128), 100000, "invalid result");
}

#[test]
#[available_gas(1000000000)]
fn test_pow2() {
assert_eq!(pow2(0), 1, "2^0 should be 1");
assert_eq!(pow2(1), 2, "2^1 should be 2");
assert_eq!(pow2(2), 4, "2^2 should be 4");
assert_eq!(pow2(3), 8, "2^3 should be 8");
assert_eq!(pow2(10), 1024, "2^10 should be 1024");
assert_eq!(pow2(63), 0x8000000000000000, "2^63 should be 0x8000000000000000");
}

#[test]
fn test_shl() {
let value1: u32 = 3;
@@ -142,19 +217,20 @@
}

#[test]
fn test_shr() {
// Assuming T and U are u32 for simplicity
let x: u32 = 32;
fn test_shr_u64() {
// Expect roughly a 15% reduction in steps compared to the previous generic shr test;
// the savings should be much larger for bigger shifts
let x: u64 = 32;
let shift: u32 = 2;
let result = shr(x, shift);
let result = shr_u64(x, shift);
assert_eq!(result, 8);

let shift: u32 = 32;
let result = shr(x, shift);
let result = shr_u64(x, shift);
assert_eq!(result, 0);

let shift: u32 = 0;
let result = shr(x, shift);
let result = shr_u64(x, shift);
assert_eq!(result, 32);
}
}
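For context on the step savings mentioned in test_shr_u64: the generic shr computed 2^shift with fast_pow, essentially a square-and-multiply loop executed on every call, while pow2 replaces that loop with a single span lookup. A rough sketch of the kind of loop being avoided (a hypothetical pow2_loop specialized to u64, not code from this repository):

fn pow2_loop(mut exponent: u32) -> u64 {
    // square-and-multiply: cost grows with the bit length of the exponent
    let mut result: u64 = 1;
    let mut base: u64 = 2;
    loop {
        if exponent % 2 == 1 {
            result = result * base;
        }
        exponent = exponent / 2;
        if exponent == 0 {
            break result;
        }
        // safe up to exponent 63: the final squaring, which would overflow u64,
        // is skipped by the break above
        base = base * base;
    }
}

A 63-bit shift walks this loop six times, while small shifts take one or two iterations, which is consistent with the comment that the savings grow with the shift amount.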
14 changes: 7 additions & 7 deletions packages/utils/src/numeric.cairo
@@ -1,4 +1,4 @@
use crate::bit_shifts::shr;
use crate::bit_shifts::shr_u64;

/// Reverses the byte order of a `u32`.
///
@@ -19,12 +19,12 @@ pub fn u64_next_power_of_two(mut n: u64) -> u64 {
}

n -= 1;
n = n | shr(n, 1_u64);
n = n | shr(n, 2_u64);
n = n | shr(n, 4_u64);
n = n | shr(n, 8_u64);
n = n | shr(n, 16_u64);
n = n | shr(n, 32_u64);
n = n | shr_u64(n, 1_u32);
n = n | shr_u64(n, 2_u32);
n = n | shr_u64(n, 4_u32);
n = n | shr_u64(n, 8_u32);
n = n | shr_u64(n, 16_u32);
n = n | shr_u64(n, 32_u32);

n + 1
}
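The OR cascade above smears the highest set bit of n - 1 into every lower position, so the value becomes all ones below its leading bit and the final + 1 lands exactly on the next power of two. A short illustrative test (hypothetical, mirroring the package's test style):

#[test]
fn test_u64_next_power_of_two_examples() {
    // 37 - 1 = 0b100100; the shifted ORs smear it to 0b111111 = 63; 63 + 1 = 64
    assert_eq!(u64_next_power_of_two(37), 64);
    // exact powers of two are returned unchanged
    assert_eq!(u64_next_power_of_two(64), 64);
}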

1 comment on commit c3c1581

@feltroidprime
Contributor


match pattern is actually better than const array + span + index
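For readers unfamiliar with the suggestion: Cairo can match on consecutive integer values starting from zero, so the table could in principle be written as a 64-arm match instead of a fixed-size array indexed through a span. A scaled-down sketch of that idea (only 8 arms, hypothetical pow2_match; whether it actually outperforms the array + span + index version would have to be measured):

fn pow2_match(exponent: u32) -> u64 {
    match exponent {
        0 => 0x1,
        1 => 0x2,
        2 => 0x4,
        3 => 0x8,
        4 => 0x10,
        5 => 0x20,
        6 => 0x40,
        7 => 0x80,
        // out-of-range exponents panic, like the span lookup does
        _ => core::panic_with_felt252('pow2: exponent too large'),
    }
}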
