Skip to content

Commit

Permalink
Nvidia MSM proof of concept (serial) (#480)
Browse files Browse the repository at this point in the history
* wrap `execCudaImpl` macro logic in a block

Otherwise we run into problems if we have two execs in the same scope.

* add more EC Jac operations to helper templates

* do not quit on failure in NvidiaAssembler destructor

A failure in the check from the destructor almost certainly means that
we destroyed early, due to an exception. We don't want to hide the
exception, hence we don't quit.

* add CurveDescriptor fields for LLVM type for Fr, scalars for MSM

* [LLVM] add `isPointerTy` helper to determine if type is a pointer

* [tests] add sanity test for adding neutral EC element to EC sum

* store EC order bit width in CurveDescriptor

* make `store` for `ValueRef` safer by checking for pointer-ness

Also adds `storePtr` if user really wants to store a pointer

* forbid `=copy` on Array, likely *not* what user wants

Easy to introduce bugs by thinking one stores, when in fact one just
copies the reference.

* allow access read/write of `Array` using `ValueRef`

* add `FieldScalar`, `FieldScalarArray`, `EcAffArray`, `EcAffArray`

- for safer handling of multiple EC points in different coordinates
- separate logic of elements of Fp (`Field`) from those of Fr (`FieldScalar`)

* extend doc string of `compile` taking a string

* add ConstantValue, MutableValue wrappers around ValueRef

Dealing with ValueRef and the fact that pointers are now opaque in
LLVM is extremely annoying. So here are 2 types that wrap the LLVM
values with their respective underlying types which also provide
easier load / write access.

* add `llvmFor` macro that produces code for a for loop in LLVM

* add helpers for arithmetic, boolean logic for ValueRef, M/CValue

* add `llvmIf` to generate code for if statements

It _wraps around_ a full if statement.

* add `to` type conversion helper which extends/truncates int types

* use `llvmForCountdown` in `genFpNsqrRt` instead of fixed countdown logic

* add `getWindowAt` helper required for baseline MSM implementation

* add serial MSM implementation for Nvidia using bucket method

This implementation is a bit of a proof of concept and playground to
investigate how easily we can generate code on the LLVM target with
the help of Nim macros.

* [tests] add mini test case for MSM on Nvidia

* whoops, revert local change to test CT error on `=copy`
  • Loading branch information
Vindaar authored Jan 8, 2025
1 parent d9db7ab commit 5d66b52
Show file tree
Hide file tree
Showing 11 changed files with 929 additions and 38 deletions.
39 changes: 34 additions & 5 deletions constantine/math_compiler/codegen_nvidia.nim
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,17 @@ export
# Cuda Driver API
# ------------------------------------------------------------

template check*(status: CUresult) =
template check*(status: CUresult, quitOnFailure = true) =
## Check the status code of a CUDA operation
## Exit program with error if failure

let code = status # ensure that the input expression is evaluated once only

if code != CUDA_SUCCESS:
writeStackTrace()
stderr.write(astToStr(status) & " " & $instantiationInfo() & " exited with error: " & $code & '\n')
quit 1
if quitOnFailure:
quit 1 # NOTE: this hides exceptions if they are thrown!

func cuModuleLoadData*(module: var CUmodule, sourceCode: openArray[char]): CUresult {.inline.}=
cuModuleLoadData(module, sourceCode[0].unsafeAddr)
Expand Down Expand Up @@ -448,6 +450,9 @@ proc execCudaImpl(jitFn, res, inputs: NimNode): NimNode =
x[0]
)
)
result = quote do:
block:
`result`

macro execCuda*(jitFn: CUfunction,
res: typed,
Expand Down Expand Up @@ -513,8 +518,18 @@ type

proc `=destroy`*(nv: NvidiaAssemblerObj) =
## XXX: Need to also call the finalizer for `asy` in the future!
check nv.cuMod.cuModuleUnload()
check nv.cuCtx.cuCtxDestroy()
# NOTE: In the destructor we don't want to quit on a `check` failure.
# The reason is that if we throw an exception with an `NvidiaAssembler`
# in scope, it will trigger the destructor here (with a likely invalid
# state in the CUDA module / context). However, in this case
# we will crash anyway and would just end up hiding the actual cause of
# the error.
# In the unlikely case that all CUDA operations worked correctly up
# to this point, but then fail to unload, we currently ignore this
# as a failure mode.
# Hopefully we find a better solution in the future.
check nv.cuMod.cuModuleUnload(), quitOnFailure = false
check nv.cuCtx.cuCtxDestroy(), quitOnFailure = false
`=destroy`(nv.asy)

proc initNvAsm*[Name: static Algebra](field: type FF[Name], wordSize: int = 32, backend = bkNvidiaPTX): NvidiaAssembler =
Expand Down Expand Up @@ -571,7 +586,8 @@ proc initNvAsm*[Name: static Algebra](field: type EC_ShortW_Jac[Fp[Name], G1], w
Fp[Name].getModulus().toHex(),
v = 1, w = wordSize,
coef_a = Fp[Name].Name.getCoefA(),
coef_B = Fp[Name].Name.getCoefB()
coef_B = Fp[Name].Name.getCoefB(),
curveOrderBitWidth = Fr[Name].bits()
)
result.fd = result.cd.fd
result.asy.definePrimitives(result.cd)
Expand All @@ -580,6 +596,19 @@ proc compile*(nv: NvidiaAssembler, kernName: string): CUfunction =
## Overload of `compile` below.
## Call this version if you have manually used the Assembler_LLVM object
## to build instructions and have a kernel name you wish to compile.
##
## Use this overload if your generator function does not match the `FieldFnGenerator` or
## `CurveFnGenerator` signatures. This is useful if your function requires additional
## arguments that are compile time values in the context of LLVM.
##
## Example:
##
## ```nim
## let nv = initNvAsm(EC, wordSize)
## let kernel = nv.compile(asy.genEcMSM(cd, 3, 1000) # window size, num. points
## ```
## where `genEcMSM` returns the name of the kernel.

let ptx = nv.asy.codegenNvidiaPTX(nv.sm) # convert to PTX

# GPU exec
Expand Down
29 changes: 29 additions & 0 deletions constantine/math_compiler/impl_curves_ops_affine.nim
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@ const SectionName = "ctt.curves_affine"
type
EcPointAff* {.borrow: `.`.} = distinct Array

proc `=copy`*(m: var EcPointAff, x: EcPointAff) {.error: "Copying an EcPointAff is not allowed. " &
"You likely want to copy the LLVM value. Use `dst.store(src)` instead.".}

proc asEcPointAff*(br: BuilderRef, arrayPtr: ValueRef, arrayTy: TypeRef): EcPointAff =
## Constructs an elliptic curve point in Affine coordinates from an array pointer.
##
## `arrayTy` is an `array[FieldTy, 2]` where `FieldTy` itsel is an array of
## `array[WordTy, NumWords]`.
result = EcPointAff(br.asArray(arrayPtr, arrayTy))

proc asEcPointAff*(asy: Assembler_LLVM, arrayPtr: ValueRef, arrayTy: TypeRef): EcPointAff =
## Constructs an elliptic curve point in Affine coordinates from an array pointer.
##
Expand Down Expand Up @@ -54,6 +64,25 @@ proc store*(dst: EcPointAff, src: EcPointAff) =
store(dst.getX(), src.getX())
store(dst.getY(), src.getY())

# Array of EC points in affine coordinates
type EcAffArray* {.borrow: `.`.} = distinct Array

proc `=copy`(m: var EcAffArray, x: EcAffArray) {.error: "Copying an EcAffArray is not allowed. " &
"You likely want to copy the LLVM value. Use `dst.store(src)` instead.".}

proc `[]`*(a: EcAffArray, index: SomeInteger | ValueRef): EcPointAff = a.builder.asEcPointAff((distinctBase(a).getPtr(index)), a.elemTy)
proc `[]=`*(a: EcAffArray, index: SomeInteger | ValueRef, val: EcPointAff) = distinctBase(a)[index] = val.buf

proc asEcAffArray*(asy: Assembler_LLVM, cd: CurveDescriptor, a: ValueRef, num: int): EcAffArray =
## Interpret the given value `a` as an array of EC elements in Affine coordinates.
let ty = array_t(cd.curveTyAff, num)
result = EcAffArray(asy.br.asArray(a, ty))

proc initEcAffArray*(asy: Assembler_LLVM, cd: CurveDescriptor, num: int): EcAffArray =
## Initialize a new EcAffArray for `num` elements
let ty = array_t(cd.curveTyAff, num)
result = EcAffArray(asy.makeArray(ty))

template declEllipticAffOps*(asy: Assembler_LLVM, cd: CurveDescriptor): untyped =
## This template can be used to make operations on `Field` elements
## more convenient.
Expand Down
43 changes: 42 additions & 1 deletion constantine/math_compiler/impl_curves_ops_jacobian.nim
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,16 @@ const SectionName = "ctt.curves_jacobian"
type
EcPointJac* {.borrow: `.`.} = distinct Array

proc `=copy`(m: var EcPointJac, x: EcPointJac) {.error: "Copying an EcPointJac is not allowed. " &
"You likely want to copy the LLVM value. Use `dst.store(src)` instead.".}

proc asEcPointJac*(br: BuilderRef, arrayPtr: ValueRef, arrayTy: TypeRef): EcPointJac =
## Constructs an elliptic curve point in Jacobian coordinates from an array pointer.
##
## `arrayTy` is an `array[FieldTy, 3]` where `FieldTy` itsel is an array of
## `array[WordTy, NumWords]`.
result = EcPointJac(br.asArray(arrayPtr, arrayTy))

proc asEcPointJac*(asy: Assembler_LLVM, arrayPtr: ValueRef, arrayTy: TypeRef): EcPointJac =
## Constructs an elliptic curve point in Jacobian coordinates from an array pointer.
##
Expand Down Expand Up @@ -57,17 +67,48 @@ proc store*(dst: EcPointJac, src: EcPointJac) =
store(dst.getY(), src.getY())
store(dst.getZ(), src.getZ())

# Representation of a finite field point with some utilities
type EcJacArray* {.borrow: `.`.} = distinct Array

proc `=copy`(m: var EcJacArray, x: EcJacArray) {.error: "Copying an EcJacArray is not allowed. " &
"You likely want to copy the LLVM value. Use `dst.store(src)` instead.".}

proc `[]`*(a: EcJacArray, index: SomeInteger | ValueRef): EcPointJac = a.builder.asEcPointJac((distinctBase(a).getPtr(index)), a.elemTy)
proc `[]=`*(a: EcJacArray, index: SomeInteger | ValueRef, val: EcPointJac) = distinctBase(a)[index] = val.buf

proc asEcJacArray*(asy: Assembler_LLVM, cd: CurveDescriptor, a: ValueRef, num: int): EcJacArray =
## Interpret the given value `a` as an array of EC elements in Jacobian coordinates.
let ty = array_t(cd.curveTy, num)
result = EcJacArray(asy.br.asArray(a, ty))

proc initEcJacArray*(asy: Assembler_LLVM, cd: CurveDescriptor, num: int): EcJacArray =
## Initialize a new EcJacArray for `num` elements
let ty = array_t(cd.curveTy, num)
result = EcJacArray(asy.makeArray(ty))

template declEllipticJacOps*(asy: Assembler_LLVM, cd: CurveDescriptor): untyped =
## This template can be used to make operations on `Field` elements
## more convenient.
## XXX: extend to include all ops
# Setters
template setNeutral(x: EcPointJac): untyped = asy.setNeutral(cd, x.buf)

# Boolean checks
template isNeutral(res, x: EcPointJac): untyped = asy.isNeutral(cd, res, x.buf)
template isNeutral(x: EcPointJac): untyped =
var res = asy.br.alloca(asy.ctx.int1_t())
asy.isNeutral(cd, res, x.buf)
res

# Mutating assignment ops
template sum(res, x, y: EcPointJac): untyped = asy.sum(cd, res.buf, x.buf, y.buf)
template `+=`(x, y: EcPointJac): untyped = x.sum(x, y)
template mixedSum(res, x: EcPointJac, y: EcPointAff): untyped = asy.mixedSum(cd, res.buf, x.buf, y.buf)
template `+=`(x: EcPointJac, y: EcPointAff): untyped = x.mixedSum(x, y)

# Arithmetic mutations
template double(res, x: EcPointJac): untyped = asy.double(cd, res.buf, x.buf)
template double(x: EcPointJac): untyped = x.double(x)

# Conditional ops
template ccopy(x, y: EcPointJac, c): untyped = asy.ccopy(cd, x.buf, y.buf, derefBool c)

Expand Down
44 changes: 44 additions & 0 deletions constantine/math_compiler/impl_fields_ops.nim
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# at your option. This file may not be copied, modified, or distributed except according to those terms.

import
constantine/platforms/bithacks, # for log2_vartime
constantine/platforms/llvm/[llvm, asm_nvidia],
./ir,
./impl_fields_globals,
Expand Down Expand Up @@ -725,3 +726,46 @@ proc scalarMul*(asy: Assembler_LLVM, fd: FieldDescriptor, a: ValueRef, b: int) =
asy.br.retVoid()

asy.callFn(name, [a])

proc getWindowAt*(asy: Assembler_LLVM, cd: CurveDescriptor, r, c, bI, wI: ValueRef) {.used.} =
## Generate an internal field `getWindowAt` function
## with signature
## void name(BaseType r, FieldType c, int bitIndex, int windowSize)
let name = cd.fd.name & "_getWindowAt"
asy.llvmInternalFnDef(
name, SectionName,
asy.void_t, toTypes([r, c, bI, wI]),
{kHot}):
tagParameter(1, "sret")

# Operations for numbers as `ValueRef`
declNumberOps(asy, cd.fd)

let (ri, ci, bitIndex, windowSize) = llvmParams
let rA = asy.asFieldScalar(cd, ri)
let cA = asy.asFieldScalar(cd, ci)
let fd = cd.fd

# Nim values
let SlotShift = log2_vartime(fd.w.uint32)
let WordMask = fd.w - 1
let WindowMask = (1 shl windowSize) - 1 # LLVM

# LLVM values
let slot = bitIndex shr SlotShift
let word = cA[slot] # word in limbs
let pos = bitIndex and WordMask # position in the word

# This is constant-time, the branch does not depend on secret data.
llvmIf(asy): # transforms an `if` statement body into llvm conditional branches
if pos + windowSize > fd.w and slot+1 < fd.numWords:
# Read next word as well
let x = ((word shr pos) or (cA[slot+1] shl (fd.w - pos))) and WindowMask
asy.store(ri, x)
else:
let x = (word shr pos) and WindowMask
asy.store(ri, x)

asy.br.retVoid()

asy.callFn(name, [r, c, bI, wI])
96 changes: 96 additions & 0 deletions constantine/math_compiler/impl_msm_nvidia.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.

import
constantine/platforms/llvm/[llvm, asm_nvidia],
constantine/platforms/[primitives],
./ir,
./impl_fields_globals,
./impl_fields_dispatch,
./impl_fields_ops,
./impl_curves_ops_affine,
./impl_curves_ops_jacobian,
std / typetraits # for distinctBase

## Section name used for `llvmInternalFnDef`
const SectionName = "ctt.msm_nvidia"

proc msm*(asy: Assembler_LLVM, cd: CurveDescriptor, r, coefs, points: ValueRef,
c, N: int) {.used.} =
## Inner implementation of MSM, for static dispatch over c, the bucket bit length
## This is a straightforward simple translation of BDLO12, section 4
##
## Entirely serial implementation!
##
## Important note: The coefficients given to this procedure must be in canonical
## representation instead of Montgomery representation! Thus, you cannot pass
## values of type `Fr[Curve]` directly, as they are internally stored in Montgomery
## rep. Convert to a `BigInt` using `fromField`.
let name = cd.name & "_msm_impl"
asy.llvmInternalFnDef(
name, SectionName,
asy.void_t, toTypes([r, coefs, points]),
{kHot}):
tagParameter(1, "sret")

# Inject templates for convenient access
declFieldOps(asy, cd.fd)
declEllipticJacOps(asy, cd)
declEllipticAffOps(asy, cd)
declNumberOps(asy, cd.fd)

let (ri, coefsIn, pointsIn) = llvmParams
let rA = asy.asEcPointJac(cd, ri)
let cs = asy.asFieldScalarArray(cd, coefsIn, N) # coefficients
let Ps = asy.asEcAffArray(cd, pointsIn, N) # EC points
# Prologue
# --------
let numBuckets = 1 shl c - 1 # bucket 0 is unused
let numWindows = cd.orderBitWidth.int.ceilDiv_vartime(c)

let miniMSMs = asy.initEcJacArray(cd, numWindows)
let buckets = asy.initEcJacArray(cd, numBuckets)

# Algorithm
# ---------
var cNonZero = asy.initMutVal(cd.fd.wordTy)
asy.llvmFor w, 0, numWindows - 1, true:
# Place our points in a bucket corresponding to
# how many times their bit pattern in the current window of size c
asy.llvmFor i, 0, numBuckets - 1, true:
buckets[i].setNeutral()

# 1. Bucket accumulation. Cost: n - (2ᶜ-1) => n points in 2ᶜ-1 buckets, first point per bucket is just copied
asy.llvmFor j, 0, N-1, true:
var b = asy.initMutVal(cd.fd.wordTy)
let w0 = asy.initConstVal(0, cd.fd.wordTy)
asy.getWindowAt(cd, b.buf, cs[j].buf, asy.to(w, cd.fd.wordTy) * c, constInt(cd.fd.wordTy, c))
llvmIf(asy):
if b != w0:
buckets[b-1] += Ps[j]

var accumBuckets = asy.newEcPointJac(cd)
var miniMSM = asy.newEcPointJac(cd)
accumBuckets.store(buckets[numBuckets-1])
miniMSM.store(buckets[numBuckets-1])

asy.llvmFor k, numBuckets-2, 0, false:
accumBuckets += buckets[k] # Stores S₈ then S₈+S₇ then S₈+S₇+S₆ then ...
miniMSM += accumBuckets # Stores S₈ then [2]S₈+S₇ then [3]S₈+[2]S₇+S₆ then ...

miniMSMs[w].store(miniMSM)

rA.store(miniMSMs[numWindows-1])
asy.llvmFor w, numWindows-2, 0, false:
asy.llvmFor j, 0, c-1:
rA.double()
rA += miniMSMs[w]

asy.br.retVoid()

asy.callFn(name, [r, coefs, points])
Loading

0 comments on commit 5d66b52

Please sign in to comment.