Skip to content

Commit

Permalink
EVM modexp: solve DOS vectors (#286)
Browse files Browse the repository at this point in the history
* stash prep for Barrett Reduction

* benches lost in rebase

* fix vartime reduction

* some improvement and fixes on reduce_vartime

* Fuse reductions when converting to Montgomery + use window=1 in powMont for small exponents. ~2.7x to 3.3x accel

* modexp: Introduce a no-reduction path for small base+exponent compared to modulus. Fix DOS

* optim for padded exponents

* remove commented out code [skip ci]

* Missing noInline for allocStackArray
  • Loading branch information
mratsim authored Oct 18, 2023
1 parent 34baa74 commit 4ccd8aa
Show file tree
Hide file tree
Showing 16 changed files with 1,116 additions and 154 deletions.
15 changes: 14 additions & 1 deletion benchmarks/bench_gmp_modmul.nim
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import
# Internal
../constantine/math/io/io_bigints,
../constantine/math/arithmetic,
../constantine/math_arbitrary_precision/arithmetic/limbs_divmod_vartime,
../constantine/platforms/abstractions,
../constantine/serialization/codecs,
# Test utilities
Expand Down Expand Up @@ -124,6 +125,13 @@ proc main() =
let stopCTTMod = getMonoTime()
echo "Constantine - ", aBits, " mod ", bBits, " -> ", bBits, " mod: ", float(inNanoseconds((stopCTTmod-startCTTmod)))/float(NumIters), " ns"

let startCTTvartimeMod = getMonoTime()
var q {.noInit.}: BigInt[bBits]
for _ in 0 ..< NumIters:
discard divRem_vartime(q.limbs, rTestMod.limbs, aTest.limbs, bTest.limbs)
let stopCTTvartimeMod = getMonoTime()
echo "Constantine - ", aBits, " mod ", bBits, " (vartime) -> ", bBits, " mod: ", float(inNanoseconds((stopCTTvartimeMod-startCTTvartimeMod)))/float(NumIters), " ns"

echo "----"
# Modular reduction - double-size

Expand All @@ -139,7 +147,12 @@ proc main() =
let stopCTTMod2 = getMonoTime()
echo "Constantine - ", rBits, " mod ", bBits, " -> ", bBits, " mod: ", float(inNanoseconds((stopCTTmod2-startCTTmod2)))/float(NumIters), " ns"

# Constantine
let startCTTvartimeMod2 = getMonoTime()
var q2 {.noInit.}: BigInt[bBits]
for _ in 0 ..< NumIters:
discard divRem_vartime(q2.limbs, rTestMod.limbs, rTest.limbs, bTest.limbs)
let stopCTTvartimeMod2 = getMonoTime()
echo "Constantine - ", rBits, " mod ", bBits, " (vartime) -> ", bBits, " mod: ", float(inNanoseconds((stopCTTvartimeMod2-startCTTvartimeMod2)))/float(NumIters), " ns"

echo ""

Expand Down
10 changes: 4 additions & 6 deletions constantine/math/arithmetic/limbs_extmul.nim
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,8 @@ func prod_comba[rLen, aLen, bLen: static int](r: var Limbs[rLen], a: Limbs[aLen]
u = t
t = Zero

if aLen+bLen < rLen:
for i in aLen+bLen ..< rLen:
r[i] = Zero
for i in aLen+bLen ..< rLen:
r[i] = Zero

func prod*[rLen, aLen, bLen: static int](r{.noalias.}: var Limbs[rLen], a: Limbs[aLen], b: Limbs[bLen]) {.inline.} =
## Multi-precision multiplication
Expand Down Expand Up @@ -156,9 +155,8 @@ func square_Comba[rLen, aLen](
u = t
t = Zero

if aLen*2 < rLen:
for i in aLen*2 ..< rLen:
r[i] = Zero
for i in aLen*2 ..< rLen:
r[i] = Zero

func square_operandScan[rLen, aLen](
r: var Limbs[rLen],
Expand Down
69 changes: 45 additions & 24 deletions constantine/math_arbitrary_precision/arithmetic/bigints_views.nim
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import
./limbs_mod2k,
./limbs_multiprec,
./limbs_extmul,
./limbs_divmod
./limbs_divmod_vartime

# No exceptions allowed
{.push raises: [], checks: off.}
Expand All @@ -41,7 +41,34 @@ import
# Also need to take into account constant-time for RSA
# i.e. countLeadingZeros can only be done on public moduli.

func powOddMod_vartime*(
iterator unpackBE(scalarByte: byte): bool =
## Yields the 8 bits of `scalarByte` as booleans,
## from the most significant bit (bit 7) down to the least significant bit (bit 0).
for i in countdown(7, 0):
yield bool((scalarByte shr i) and 1)

func pow_vartime(
r: var openArray[SecretWord],
a: openArray[SecretWord],
exponent: openArray[byte]) {.tags:[VarTime, Alloca], meter.} =
## r <- a^exponent
##
## Left-to-right square-and-multiply, with no modular reduction.
## The bytes of `exponent` are consumed in order, each from its most
## significant bit down (see `unpackBE`), i.e. the exponent is read as big-endian.
## Variable-time: the control flow depends on the exponent bits.
## NOTE(review): no bounds check is visible here; presumably the caller guarantees
## that a.len <= r.len and that `r` is large enough to hold a^exponent - confirm.

r.setOne()
# True while `r` still holds its initial value 1. This skips the squarings
# for leading zero bits and turns the first multiplication into a plain copy.
var isOne = true

for e in exponent:
for bit in unpackBE(e):
if not isOne:
r.square_vartime(r)
if bit:
if isOne:
# First set bit: r <- a, zero-extended up to r.len
for i in 0 ..< a.len:
r[i] = a[i]
for i in a.len ..< r.len:
r[i] = Zero
isOne = false
else:
r.prod_vartime(r, a)

func powOddMod_vartime(
r: var openArray[SecretWord],
a: openArray[SecretWord],
exponent: openArray[byte],
Expand All @@ -55,33 +82,20 @@ func powOddMod_vartime*(
debug:
doAssert bool(M.isOdd())

let aBits = a.getBits_LE_vartime()
let mBits = M.getBits_LE_vartime()
let eBits = exponent.getBits_BE_vartime()

if eBits == 1:
r.view().reduce(a.view(), aBits, M.view(), mBits)
discard r.reduce_vartime(a, M)
return

let L = wordsRequired(mBits)
let m0ninv = M[0].negInvModWord()
var rMont = allocStackArray(SecretWord, L)

block:
var r2Buf = allocStackArray(SecretWord, L)
template r2: untyped = r2Buf.toOpenArray(0, L-1)
r2.r2_vartime(M.toOpenArray(0, L-1))

# Conversion to Montgomery can be auto-reduced by up to M*R
# if we use redc2xMont (a/R) and montgomery multiplication by R³
# For now, we call explicit reduction as it can handle all sizes.
# TODO: explicit reduction uses constant-time division which is **very** expensive
if a.len != M.len:
let t = allocStackArray(SecretWord, L)
t.LimbsViewMut.reduce(a.view(), aBits, M.view(), mBits)
rMont.LimbsViewMut.getMont(LimbsViewConst t, M.view(), LimbsViewConst r2.view(), m0ninv, mBits)
else:
rMont.LimbsViewMut.getMont(a.view(), M.view(), LimbsViewConst r2.view(), m0ninv, mBits)
var aMont_buf = allocStackArray(SecretWord, L)
template aMont: untyped = aMont_buf.toOpenArray(0, L-1)

aMont.getMont_vartime(a, M)

block:
var oneMontBuf = allocStackArray(SecretWord, L)
Expand All @@ -91,12 +105,11 @@ func powOddMod_vartime*(
let scratchLen = L * ((1 shl window) + 1)
var scratchSpace = allocStackArray(SecretWord, scratchLen)

rMont.LimbsViewMut.powMont_vartime(
aMont_buf.LimbsViewMut.powMont_vartime(
exponent, M.view(), LimbsViewConst oneMontBuf,
m0ninv, LimbsViewMut scratchSpace, scratchLen, mBits)

r.view().fromMont(LimbsViewConst rMont, M.view(), m0ninv, mBits)

r.view().fromMont(LimbsViewConst aMont_buf, M.view(), m0ninv, mBits)

func powMod_vartime*(
r: var openArray[SecretWord],
Expand Down Expand Up @@ -128,6 +141,14 @@ func powMod_vartime*(
r[i] = Zero
return

# No modular reduction needed
# -------------------------------------------------------------------
if eBits < WordBitWidth and
aBits.uint shr (WordBitWidth - eBits) == 0 and # handle overflow of uint128 [0, aBits] << eBits
aBits.uint shl eBits < mBits.uint:
r.pow_vartime(a, exponent)
return

# Odd modulus
# -------------------------------------------------------------------
if M.isOdd().bool:
Expand Down Expand Up @@ -197,5 +218,5 @@ func powMod_vartime*(

var qyBuf = allocStackArray(SecretWord, M.len)
template qy: untyped = qyBuf.toOpenArray(0, M.len-1)
qy.prod(q, y)
qy.prod_vartime(q, y)
discard r.addMP(qy, a1)
24 changes: 12 additions & 12 deletions constantine/math_arbitrary_precision/arithmetic/limbs_divmod.nim
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import
./limbs_fixedprec

# No exceptions allowed
{.push raises: [].}
{.push raises: [], checks: off.}

# ############################################################
#
Expand Down Expand Up @@ -63,7 +63,7 @@ func shlAddMod_estimate(a: LimbsViewMut, aLen: int,
a1 = (a[^1] shl (WordBitWidth-R)) or (a[^2] shr R)
m0 = (M[^1] shl (WordBitWidth-R)) or (M[^2] shr R)

# m0 has its high bit set. (a0, a1)/p0 fits in a limb.
# m0 has its high bit set. (a0, a1)/m0 fits in a limb.
# Get a quotient q, at most we will be 2 iterations off
# from the true quotient
var q, r: SecretWord
Expand All @@ -78,29 +78,29 @@ func shlAddMod_estimate(a: LimbsViewMut, aLen: int,

# Now substract a*2^64 - q*p
var carry = Zero
var over_p = CtTrue # Track if quotient greater than the modulus
var overM = CtTrue # Track if quotient greater than the modulus

for i in 0 ..< MLen:
var qp_lo: SecretWord
var qm_lo: SecretWord

block: # q*p
# q * p + carry (doubleword) carry from previous limb
muladd1(carry, qp_lo, q, M[i], carry)
block: # q*m
# q * m + carry (doubleword) carry from previous limb
muladd1(carry, qm_lo, q, M[i], carry)

block: # a*2^64 - q*p
var borrow: Borrow
subB(borrow, a[i], a[i], qp_lo, Borrow(0))
subB(borrow, a[i], a[i], qm_lo, Borrow(0))
carry += SecretWord(borrow) # Adjust if borrow

over_p = mux(a[i] == M[i], over_p, a[i] > M[i])
overM = mux(a[i] == M[i], overM, a[i] > M[i])

# Fix quotient, the true quotient is either q-1, q or q+1
#
# if carry < q or carry == q and over_p we must do "a -= p"
# if carry > hi (negative result) we must do "a += p"
# if carry < q or carry == q and overM we must do "a -= m"
# if carry > hi (negative result) we must do "a += m"

result.neg = carry > hi
result.tooBig = not(result.neg) and (over_p or (carry < hi))
result.tooBig = not(result.neg) and (overM or (carry < hi))

func shlAddMod(a: LimbsViewMut, aLen: int,
c: SecretWord, M: LimbsViewConst, mBits: int) =
Expand Down
Loading

0 comments on commit 4ccd8aa

Please sign in to comment.