Skip to content

Commit

Permalink
Handling u256 inverse for n == 1. (#4478)
Browse files Browse the repository at this point in the history
  • Loading branch information
orizi authored Nov 26, 2023
1 parent ea445c7 commit 1932d1d
Show file tree
Hide file tree
Showing 8 changed files with 60 additions and 17 deletions.
2 changes: 2 additions & 0 deletions corelib/src/math.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ fn inv_mod<
}

/// Returns `1 / b (mod n)`, or None if `b` is not invertible modulo `n`.
/// All `b`s will be considered not invertible for `n == 1`.
/// Additionally returns several `U128MulGuarantee`s that are required for validating the
/// calculation.
extern fn u256_guarantee_inv_mod_n(
Expand All @@ -98,6 +99,7 @@ extern fn u256_guarantee_inv_mod_n(
> implicits(RangeCheck) nopanic;

/// Returns the inverse of `a` modulo `n`, or None if `a` is not invertible modulo `n`.
/// All `b`s will be considered not invertible for `n == 1`.
#[inline(always)]
fn u256_inv_mod(a: u256, n: NonZero<u256>) -> Option<NonZero<u256>> {
match u256_guarantee_inv_mod_n(a, n) {
Expand Down
18 changes: 14 additions & 4 deletions corelib/src/test/math_test.cairo
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use core::option::OptionTrait;
use core::math;

/// Helper for making a non-zero value.
Expand Down Expand Up @@ -83,10 +84,19 @@ fn test_u256_div_mod_n() {
}

#[test]
fn test_u256_inv_mod_no_inverse() {
assert(math::u256_inv_mod(3, nz(6)).is_none(), 'inv_mod(3, 6)');
assert(math::u256_inv_mod(4, nz(6)).is_none(), 'inv_mod(4, 6)');
assert(math::u256_inv_mod(8, nz(4)).is_none(), 'inv_mod(8, 4)');
fn test_u256_inv_mod() {
assert(math::u256_inv_mod(5, nz(24)).unwrap().into() == 5_u256, 'inv_mov(5, 24) != 5');
assert(math::u256_inv_mod(29, nz(24)).unwrap().into() == 5_u256, 'inv_mov(29, 24) != 5');
assert(math::u256_inv_mod(1, nz(24)).unwrap().into() == 1_u256, 'inv_mov(1, 24) != 1');
assert(math::u256_inv_mod(1, nz(5)).unwrap().into() == 1_u256, 'inv_mov(1, 5) != 1');
assert(math::u256_inv_mod(8, nz(24)).is_none(), 'inv_mov(8, 24) != None');
assert(math::u256_inv_mod(1, nz(1)).is_none(), 'inv_mov(1, 1) != None');
assert(math::u256_inv_mod(7, nz(1)).is_none(), 'inv_mov(7, 1) != None');
assert(math::u256_inv_mod(0, nz(1)).is_none(), 'inv_mov(0, 1) != None');
assert(math::u256_inv_mod(0, nz(7)).is_none(), 'inv_mov(0, 7) != None');
assert(math::u256_inv_mod(3, nz(6)).is_none(), 'inv_mod(3, 6) != None');
assert(math::u256_inv_mod(4, nz(6)).is_none(), 'inv_mod(4, 6) != None');
assert(math::u256_inv_mod(8, nz(4)).is_none(), 'inv_mod(8, 4) != None');
assert(
math::u256_inv_mod(
0xea9195982bd472e30e5146ad7cb0acd954cbc75032a298ac73234b6b05e28cc1,
Expand Down
13 changes: 12 additions & 1 deletion crates/cairo-lang-casm/src/hints/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,10 @@ pub enum CoreHint {
/// `g * s = b`
/// `g * t = n`
///
/// The case `n == 1` is considered "no-inverse" (special case).
/// In this case: Returns `g == 1`, `s == b` and `t == 1`.
/// All no-inverse requirements are satisfied, except for `g > 1`.
///
/// In all cases - `name`0 is the least significant limb.
#[codec(index = 27)]
U256InvModN {
Expand Down Expand Up @@ -732,7 +736,14 @@ impl PythonicHint for CoreHint {
n = {n0} + ({n1} << 128)
(_, r, g) = igcdex(n, b)
if g != 1:
if n == 1:
memory{g0_or_no_inv} = 1
memory{g1_option} = 0
memory{s_or_r0} = {b0}
memory{s_or_r1} = {b1}
memory{t_or_k0} = 1
memory{t_or_k1} = 0
elif g != 1:
if g % 2 == 0:
g = 2
s = b // g
Expand Down
11 changes: 9 additions & 2 deletions crates/cairo-lang-runner/src/casm_run/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1970,10 +1970,17 @@ pub fn execute_core_hint(
let b1 = get_val(vm, b1)?.to_bigint();
let n0 = get_val(vm, n0)?.to_bigint();
let n1 = get_val(vm, n1)?.to_bigint();
let b: BigInt = b0 + b1.shl(128);
let b: BigInt = b0.clone() + b1.clone().shl(128);
let n: BigInt = n0 + n1.shl(128);
let ExtendedGcd { gcd: mut g, x: _, y: mut r } = n.extended_gcd(&b);
if g != 1.into() {
if n == 1.into() {
insert_value_to_cellref!(vm, s_or_r0, Felt252::from(b0))?;
insert_value_to_cellref!(vm, s_or_r1, Felt252::from(b1))?;
insert_value_to_cellref!(vm, t_or_k0, Felt252::from(1))?;
insert_value_to_cellref!(vm, t_or_k1, Felt252::from(0))?;
insert_value_to_cellref!(vm, g0_or_no_inv, Felt252::from(1))?;
insert_value_to_cellref!(vm, g1_option, Felt252::from(0))?;
} else if g != 1.into() {
// This makes sure `g0_or_no_inv` is always non-zero in the no inverse case.
if g.is_even() {
g = 2u32.into();
Expand Down
2 changes: 1 addition & 1 deletion crates/cairo-lang-sierra-gas/src/core_libfunc_cost_base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,7 @@ fn u256_libfunc_cost(libfunc: &Uint256Concrete) -> Vec<ConstCost> {
Uint256Concrete::SquareRoot(_) => vec![ConstCost { steps: 30, holes: 0, range_checks: 7 }],
Uint256Concrete::InvModN(_) => vec![
ConstCost { steps: 40, holes: 0, range_checks: 9 },
ConstCost { steps: 23, holes: 0, range_checks: 7 },
ConstCost { steps: 25, holes: 0, range_checks: 7 },
],
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -489,15 +489,17 @@ fn build_u256_inv_mod_n(
assert s1 = *(range_check++);
assert t0 = *(range_check++);
assert t1 = *(range_check++);
// Validating `g > 1`.
// Validate that `g > 1` or `g = n = 1`.
tempvar g0_minus_1;
jump GIsValid if g1 != 0;
assert g0_minus_1 = g0 - one;
jump GIsValid if g0_minus_1 != 0;
fail;
// Handle the case where `g = 1`, which is only valid if `n = 1`.
assert n1 = zero;
assert n0 = one;
GIsValid:

// Validating `g * s = b` and `g * t = n`.
// Validate `g * s = b` and `g * t = n`.

// Only calculate the upper word, since we already know the lower word is `b0`.
let g0s0_low = b0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,10 @@ impl NoGenericArgsGenericLibfunc for Uint256SquareRootLibfunc {
}
}

// Inverse Modulo N.
/// Inverse Modulo N.
///
/// Libfunc for calculating the inverse of a number modulo N.
/// If `N == 1`, the value is not considered invertible.
#[derive(Default)]
pub struct Uint256InvModNLibfunc;
impl NoGenericArgsGenericLibfunc for Uint256InvModNLibfunc {
Expand Down
18 changes: 13 additions & 5 deletions tests/e2e_test_data/libfuncs/u256
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,14 @@ b = memory[fp + -6] + (memory[fp + -5] << 128)
n = memory[fp + -4] + (memory[fp + -3] << 128)

(_, r, g) = igcdex(n, b)
if g != 1:
if n == 1:
memory[ap + 0] = 1
memory[ap + 1] = 0
memory[ap + 2] = memory[fp + -6]
memory[ap + 3] = memory[fp + -5]
memory[ap + 4] = 1
memory[ap + 5] = 0
elif g != 1:
if g % 2 == 0:
g = 2
s = b // g
Expand Down Expand Up @@ -371,17 +378,18 @@ jmp rel 3;
[ap + 0] = [ap + -2] + 340282366920938463463374607431768178688, ap++;
[ap + -1] = [[fp + -7] + 8];
[ap + -30] = [ap + -22] + [ap + -3];
jmp rel 47;
jmp rel 49;
[ap + -1] = [[fp + -7] + 0], ap++;
[ap + -1] = [[fp + -7] + 1], ap++;
[ap + -1] = [[fp + -7] + 2], ap++;
[ap + -1] = [[fp + -7] + 3], ap++;
[ap + -1] = [[fp + -7] + 4], ap++;
[ap + -1] = [[fp + -7] + 5];
jmp rel 8 if [ap + -5] != 0, ap++;
jmp rel 10 if [ap + -5] != 0, ap++;
[ap + -7] = [ap + -1] + 1;
jmp rel 4 if [ap + -1] != 0;
[fp + -1] = [fp + -1] + 1;
jmp rel 6 if [ap + -1] != 0;
[fp + -3] = 0;
[fp + -4] = 1;
%{ (memory[ap + 0], memory[fp + -6]) = divmod(memory[ap + -7] * memory[ap + -5], 2**128) %}
%{ (memory[ap + 1], memory[fp + -4]) = divmod(memory[ap + -7] * memory[ap + -3], 2**128) %}
jmp rel 12 if [ap + -6] != 0, ap++;
Expand Down

0 comments on commit 1932d1d

Please sign in to comment.