Commit e7125bf — Elide kernel wrappers by altering device function IR

gmarkall committed Nov 29, 2024
1 parent 793d238 commit e7125bf

Showing 6 changed files with 198 additions and 33 deletions.
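Before the per-file diff, a note on the core idea: kernel_fixup below rewrites a compiled device function's LLVM IR in place so that it can be launched directly as a kernel, where previously a separate wrapper kernel was generated to call it. The following is a minimal, self-contained llvmlite sketch of the same rewrite; the module, function name and body are hypothetical examples, not code from this commit:

    from llvmlite import ir

    i32 = ir.IntType(32)
    mod = ir.Module(name="demo")

    # A Numba-ABI device function: returns a status code and writes its
    # Python-level return value through the first (pointer) argument.
    fnty = ir.FunctionType(i32, (i32.as_pointer(), i32))
    fn = ir.Function(mod, fnty, name="f")
    retptr, x = fn.args
    block = fn.append_basic_block("entry")
    builder = ir.IRBuilder(block)
    builder.store(x, retptr)
    builder.ret(ir.Constant(i32, 0))
    print(mod)   # define i32 @"f"(i32* %".1", i32 %".2") { ... }

    # Pass 1: replace `ret <status>` with `ret void`
    block.instructions.pop()
    block.terminator = None
    ir.IRBuilder(block).ret_void()

    # Pass 2: drop stores through the return-value pointer
    for inst in list(block.instructions):
        if isinstance(inst, ir.StoreInstr) and inst.operands[1] is retptr:
            block.instructions.remove(inst)

    # Rewrite the signature: void return, return-value argument removed
    fn.type = ir.PointerType(ir.FunctionType(ir.VoidType(),
                                             fn.type.pointee.args[1:]))
    fn.return_value = ir.ReturnValue(fn, ir.VoidType())
    fn.args = fn.args[1:]
    print(mod)   # define void @"f"(i32 %".2") { ... }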
190 changes: 180 additions & 10 deletions numba_cuda/numba/cuda/compiler.py
@@ -1,6 +1,7 @@
from llvmlite import ir
from numba.core.typing.templates import ConcreteTemplate
from numba.core import types, typing, funcdesc, config, compiler, sigutils
from numba.core import (cgutils, types, typing, funcdesc, config, compiler,
                        sigutils, utils)
from numba.core.compiler import (sanitize_compile_result_entries, CompilerBase,
                                 DefaultPassBuilder, Flags, Option,
                                 CompileResult)
@@ -11,7 +12,10 @@
from numba.core.typed_passes import (IRLegalization, NativeLowering,
                                     AnnotateTypes)
from warnings import warn
from numba.cuda import nvvmutils
from numba.cuda.api import get_current_device
from numba.cuda.cudadrv import nvvm
from numba.cuda.descriptor import cuda_target
from numba.cuda.target import CUDACABICallConv


@@ -24,6 +28,15 @@ def _nvvm_options_type(x):
    return x


def _optional_int_type(x):
    if x is None:
        return None
    else:
        assert isinstance(x, int)
        return x


class CUDAFlags(Flags):
    nvvm_options = Option(
        type=_nvvm_options_type,
@@ -35,6 +48,16 @@ class CUDAFlags(Flags):
        default=None,
        doc="Compute Capability",
    )
    max_registers = Option(
        type=_optional_int_type,
        default=None,
        doc="Max registers"
    )
    lto = Option(
        type=bool,
        default=False,
        doc="Enable Link-time Optimization"
    )


# The CUDACompileResult (CCR) has a specially-defined entry point equal to its
@@ -109,7 +132,11 @@ def run_pass(self, state):
        codegen = state.targetctx.codegen()
        name = state.func_id.func_qualname
        nvvm_options = state.flags.nvvm_options
        state.library = codegen.create_library(name, nvvm_options=nvvm_options)
        max_registers = state.flags.max_registers
        lto = state.flags.lto
        state.library = codegen.create_library(name, nvvm_options=nvvm_options,
                                               max_registers=max_registers,
                                               lto=lto)
        # Enable object caching upfront so that the library can be serialized.
        state.library.enable_object_caching()

@@ -152,7 +179,7 @@ def define_cuda_lowering_pipeline(self, state):
@global_compiler_lock
def compile_cuda(pyfunc, return_type, args, debug=False, lineinfo=False,
                 inline=False, fastmath=False, nvvm_options=None,
                 cc=None):
                 cc=None, max_registers=None, lto=False):
    if cc is None:
        raise ValueError('Compute Capability must be supplied')

@@ -189,6 +216,8 @@ def compile_cuda(pyfunc, return_type, args, debug=False, lineinfo=False,
    if nvvm_options:
        flags.nvvm_options = nvvm_options
    flags.compute_capability = cc
    flags.max_registers = max_registers
    flags.lto = lto

    # Run compilation pipeline
    from numba.core.target_extension import target_override
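For illustration, a hedged sketch of driving compile_cuda with the new keyword arguments (the kernel body, signature and compute capability are assumptions for the example; an actual compilation requires a CUDA toolkit):

    from numba.core import types

    def add_one(x):
        x[0] += 1

    # max_registers and lto now travel via CUDAFlags instead of being applied
    # when the (now removed) wrapper kernel was prepared.
    cres = compile_cuda(add_one, types.void, (types.float32[::1],),
                        cc=(7, 5), max_registers=32, lto=False)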
@@ -247,11 +276,155 @@ def cabi_wrap_function(context, lib, fndesc, wrapper_function_name,
        builder, func, restype, argtypes, callargs)
    builder.ret(return_value)

    if config.DUMP_LLVM:
        utils.dump_llvm(fndesc, wrapper_module)

    library.add_ir_module(wrapper_module)
    library.finalize()
    return library


def kernel_fixup(kernel, debug):
    if debug:
        exc_helper = add_exception_store_helper(kernel)

    # Pass 1 - replace:
    #
    #     ret <value>
    #
    # with:
    #
    #     exc_helper(<value>)
    #     ret void

    for block in kernel.blocks:
        for i, inst in enumerate(block.instructions):
            if isinstance(inst, ir.Ret):
                old_ret = block.instructions.pop()
                block.terminator = None

                # The original return's metadata will be set on the new
                # instructions in order to preserve debug info
                metadata = old_ret.metadata

                builder = ir.IRBuilder(block)
                if debug:
                    status_code = old_ret.operands[0]
                    exc_helper_call = builder.call(exc_helper, (status_code,))
                    exc_helper_call.metadata = metadata

                new_ret = builder.ret_void()
                new_ret.metadata = old_ret.metadata

                # Need to break out so we don't carry on modifying what we
                # are iterating over. There can only be one return in a block
                # anyway.
                break

    # Pass 2: remove stores of null pointer to the return value argument
    # pointer

    return_value = kernel.args[0]

    for block in kernel.blocks:
        remove_list = []

        # Find all stores first
        for inst in block.instructions:
            if (isinstance(inst, ir.StoreInstr)
                    and inst.operands[1] == return_value):
                remove_list.append(inst)

        # Remove all stores
        for to_remove in remove_list:
            block.instructions.remove(to_remove)

    # Replace the non-void return type with a void return type and remove the
    # return value argument

    if isinstance(kernel.type, ir.PointerType):
        new_type = ir.PointerType(ir.FunctionType(ir.VoidType(),
                                                  kernel.type.pointee.args[1:]))
    else:
        new_type = ir.FunctionType(ir.VoidType(), kernel.type.args[1:])

    kernel.type = new_type
    kernel.return_value = ir.ReturnValue(kernel, ir.VoidType())
    kernel.args = kernel.args[1:]

    # Mark as a kernel for NVVM

    nvvm.set_cuda_kernel(kernel)

    if config.DUMP_LLVM:
        print(f"LLVM DUMP: Post kernel fixup {kernel.name}".center(80, '-'))
        print(kernel.module)
        print('=' * 80)


def add_exception_store_helper(kernel):

    # Create global variables for exception state

    def define_error_gv(postfix):
        name = kernel.name + postfix
        gv = cgutils.add_global_variable(kernel.module, ir.IntType(32),
                                         name)
        gv.initializer = ir.Constant(gv.type.pointee, None)
        return gv

    gv_exc = define_error_gv("__errcode__")
    gv_tid = []
    gv_ctaid = []
    for i in 'xyz':
        gv_tid.append(define_error_gv("__tid%s__" % i))
        gv_ctaid.append(define_error_gv("__ctaid%s__" % i))

    # Create the exception store helper function

    helper_name = kernel.name + "__exc_helper__"
    helper_type = ir.FunctionType(ir.VoidType(), (ir.IntType(32),))
    helper_func = ir.Function(kernel.module, helper_type, helper_name)

    block = helper_func.append_basic_block(name="entry")
    builder = ir.IRBuilder(block)

    # Implement status check / exception store logic

    status_code = helper_func.args[0]
    call_conv = cuda_target.target_context.call_conv
    status = call_conv._get_return_status(builder, status_code)

    # Check error status
    with cgutils.if_likely(builder, status.is_ok):
        builder.ret_void()

    with builder.if_then(builder.not_(status.is_python_exc)):
        # User exception raised
        old = ir.Constant(gv_exc.type.pointee, None)

        # Use an atomic cmpxchg to prevent overwriting the error status -
        # only the first error is recorded

        xchg = builder.cmpxchg(gv_exc, old, status.code,
                               'monotonic', 'monotonic')
        changed = builder.extract_value(xchg, 1)

        # If the exchange succeeded, save the thread and block IDs.
        sreg = nvvmutils.SRegBuilder(builder)
        with builder.if_then(changed):
            for dim, ptr in zip("xyz", gv_tid):
                val = sreg.tid(dim)
                builder.store(val, ptr)

            for dim, ptr in zip("xyz", gv_ctaid):
                val = sreg.ctaid(dim)
                builder.store(val, ptr)

    builder.ret_void()

    return helper_func
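The cmpxchg above is what ensures only the first failing thread records its error state. A CPU-side Python analogue of that protocol, purely illustrative and with hypothetical names:

    import threading

    _exc_lock = threading.Lock()
    _exc_code = 0           # 0 means "no error recorded yet"
    _exc_tid = None

    def record_first_error(code, tid):
        # Compare-and-swap: succeed only if no error has been stored yet,
        # mirroring `cmpxchg gv_exc, 0, status.code` in the helper above.
        global _exc_code, _exc_tid
        with _exc_lock:
            if _exc_code == 0:
                _exc_code = code
                _exc_tid = tid
                return True
        return False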


@global_compiler_lock
def compile(pyfunc, sig, debug=None, lineinfo=False, device=True,
            fastmath=False, cc=None, opt=None, abi="c", abi_info=None,
@@ -347,13 +520,10 @@ def compile(pyfunc, sig, debug=None, lineinfo=False, device=True,
        lib = cabi_wrap_function(tgt, lib, cres.fndesc, wrapper_name,
                                 nvvm_options)
    else:
        code = pyfunc.__code__
        filename = code.co_filename
        linenum = code.co_firstlineno

        lib, kernel = tgt.prepare_cuda_kernel(cres.library, cres.fndesc, debug,
                                              lineinfo, nvvm_options, filename,
                                              linenum)
        lib = cres.library
        kernel = lib.get_function(cres.fndesc.llvm_func_name)
        lib._entry_name = cres.fndesc.llvm_func_name
        kernel_fixup(kernel, debug)

    if lto:
        code = lib.get_ltoir(cc=cc)
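One observable effect of the compile() change: the device function itself now becomes the PTX .entry point, where previously a wrapper .entry called it as a .func. A hedged sketch of checking this via the public compile_ptx API (kernel body and signature are assumptions):

    from numba import types
    from numba.cuda import compile_ptx

    def add_one(x):
        x[0] += 1

    ptx, resty = compile_ptx(add_one, (types.float32[::1],), cc=(7, 5))
    # The device function is now the kernel entry point; no separate
    # wrapper function appears in the PTX.
    assert ".entry" in ptx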
17 changes: 8 additions & 9 deletions numba_cuda/numba/cuda/dispatcher.py
@@ -14,7 +14,7 @@

from numba.cuda.api import get_current_device
from numba.cuda.args import wrap_arg
from numba.cuda.compiler import compile_cuda, CUDACompiler
from numba.cuda.compiler import compile_cuda, CUDACompiler, kernel_fixup
from numba.cuda.cudadrv import driver
from numba.cuda.cudadrv.devices import get_context
from numba.cuda.descriptor import cuda_target
@@ -102,15 +102,14 @@ def __init__(self, py_func, argtypes, link=None, debug=False,
                            inline=inline,
                            fastmath=fastmath,
                            nvvm_options=nvvm_options,
                            cc=cc)
                            cc=cc,
                            max_registers=max_registers,
                            lto=lto)
        tgt_ctx = cres.target_context
        code = self.py_func.__code__
        filename = code.co_filename
        linenum = code.co_firstlineno
        lib, kernel = tgt_ctx.prepare_cuda_kernel(cres.library, cres.fndesc,
                                                  debug, lineinfo, nvvm_options,
                                                  filename, linenum,
                                                  max_registers, lto)
        lib = cres.library
        kernel = lib.get_function(cres.fndesc.llvm_func_name)
        lib._entry_name = cres.fndesc.llvm_func_name
        kernel_fixup(kernel, self.debug)

        if not link:
            link = []
1 change: 1 addition & 0 deletions numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py
@@ -72,6 +72,7 @@ def test_issue_5835(self):
        def f(x):
            x[0] = 0

    @unittest.skip("Wrappers no longer exist")
    def test_wrapper_has_debuginfo(self):
        sig = (types.int32[::1],)
13 changes: 3 additions & 10 deletions numba_cuda/numba/cuda/tests/cudapy/test_inspect.py
@@ -33,10 +33,7 @@ def foo(x, y):
        self.assertIn("foo", llvm)

        # Kernel in LLVM
        self.assertIn('cuda.kernel.wrapper', llvm)

        # Wrapped device function body in LLVM
        self.assertIn("define linkonce_odr i32", llvm)
        self.assertIn("define void @", llvm)

        asm = foo.inspect_asm(sig)

@@ -72,12 +69,8 @@ def foo(x, y):
        self.assertIn("foo", llvmirs[float64, float64])

        # Kernels in LLVM
        self.assertIn('cuda.kernel.wrapper', llvmirs[intp, intp])
        self.assertIn('cuda.kernel.wrapper', llvmirs[float64, float64])

        # Wrapped device function bodies in LLVM
        self.assertIn("define linkonce_odr i32", llvmirs[intp, intp])
        self.assertIn("define linkonce_odr i32", llvmirs[float64, float64])
        self.assertIn("define void @", llvmirs[intp, intp])
        self.assertIn("define void @", llvmirs[float64, float64])

        asmdict = foo.inspect_asm()

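For context, a sketch of what user code inspecting kernels this way might look like (the kernel body is an assumption; the dict key convention follows the test above):

    from numba import cuda, types

    @cuda.jit((types.int32[::1],))
    def foo(x):
        x[0] = 0

    llvm = foo.inspect_llvm()[(types.int32[::1],)]
    # Kernels now have a void entry signature rather than a wrapper.
    assert "define void @" in llvm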
3 changes: 1 addition & 2 deletions numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py
@@ -170,10 +170,9 @@ def caller(x):
                subprograms += 1

        # One DISubprogram for each of:
        # - The kernel wrapper
        # - The caller
        # - The callee
        expected_subprograms = 3
        expected_subprograms = 2

        self.assertEqual(subprograms, expected_subprograms,
                         f'"Expected {expected_subprograms} DISubprograms; '
7 changes: 5 additions & 2 deletions numba_cuda/numba/cuda/tests/cudapy/test_optimization.py
@@ -14,8 +14,11 @@ def device_func(x, y, z):


# Fragments of code that are removed from kernel_func's PTX when optimization
# is on
removed_by_opt = ( '__local_depot0', 'call.uni', 'st.param.b64')
# is on. Previously this list was longer when kernel wrappers were used - if
# the test function were more complex it may be possible to isolate additional
# fragments of PTX we could check for the absence / presence of, but removal of
# the use of local memory is a good indicator that optimization was applied.
removed_by_opt = ( '__local_depot0',)
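A sketch of how such a fragment check might be written (assuming compile_ptx and a kernel_func similar to the one in this test file; this is not code from the diff):

    from numba import types
    from numba.cuda import compile_ptx

    def kernel_func(x):
        x[0] = 1.0

    ptx, _ = compile_ptx(kernel_func, (types.float32[::1],), opt=True)
    for fragment in removed_by_opt:
        # Optimized PTX should not spill to local memory
        assert fragment not in ptx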


@skip_on_cudasim('Simulator does not optimize code')
