From e7125bf89c04d00f93e3856fd07167f4870ed5a0 Mon Sep 17 00:00:00 2001
From: Graham Markall
Date: Tue, 13 Aug 2024 23:38:37 +0100
Subject: [PATCH] Elide kernel wrappers by altering device function IR

---
 numba_cuda/numba/cuda/compiler.py             | 190 +++++++++++++++++-
 numba_cuda/numba/cuda/dispatcher.py           |  17 +-
 .../numba/cuda/tests/cudapy/test_debuginfo.py |   1 +
 .../numba/cuda/tests/cudapy/test_inspect.py   |  13 +-
 .../numba/cuda/tests/cudapy/test_lineinfo.py  |   3 +-
 .../cuda/tests/cudapy/test_optimization.py    |   7 +-
 6 files changed, 198 insertions(+), 33 deletions(-)

diff --git a/numba_cuda/numba/cuda/compiler.py b/numba_cuda/numba/cuda/compiler.py
index ec41598..7d830aa 100644
--- a/numba_cuda/numba/cuda/compiler.py
+++ b/numba_cuda/numba/cuda/compiler.py
@@ -1,6 +1,7 @@
 from llvmlite import ir
 from numba.core.typing.templates import ConcreteTemplate
-from numba.core import types, typing, funcdesc, config, compiler, sigutils
+from numba.core import (cgutils, types, typing, funcdesc, config, compiler,
+                        sigutils, utils)
 from numba.core.compiler import (sanitize_compile_result_entries,
                                  CompilerBase, DefaultPassBuilder, Flags,
                                  Option, CompileResult)
@@ -11,7 +12,10 @@ from numba.core.typed_passes import (IRLegalization, NativeLowering,
                                      AnnotateTypes)
 from warnings import warn
 
+from numba.cuda import nvvmutils
 from numba.cuda.api import get_current_device
+from numba.cuda.cudadrv import nvvm
+from numba.cuda.descriptor import cuda_target
 from numba.cuda.target import CUDACABICallConv
 
 
@@ -24,6 +28,15 @@ def _nvvm_options_type(x):
     return x
 
 
+def _optional_int_type(x):
+    if x is None:
+        return None
+
+    else:
+        assert isinstance(x, int)
+        return x
+
+
 class CUDAFlags(Flags):
     nvvm_options = Option(
         type=_nvvm_options_type,
         default=None,
         doc="NVVM options",
     )
@@ -35,6 +48,16 @@ class CUDAFlags(Flags):
     compute_capability = Option(
         type=tuple,
         default=None,
        doc="Compute Capability",
     )
+    max_registers = Option(
+        type=_optional_int_type,
+        default=None,
+        doc="Max registers"
+    )
+    lto = Option(
+        type=bool,
+        default=False,
+        doc="Enable Link-time Optimization"
+    )
 
 
 # The CUDACompileResult (CCR) has a specially-defined entry point equal to its
@@ -109,7 +132,11 @@ def run_pass(self, state):
         codegen = state.targetctx.codegen()
         name = state.func_id.func_qualname
         nvvm_options = state.flags.nvvm_options
-        state.library = codegen.create_library(name, nvvm_options=nvvm_options)
+        max_registers = state.flags.max_registers
+        lto = state.flags.lto
+        state.library = codegen.create_library(name, nvvm_options=nvvm_options,
+                                               max_registers=max_registers,
+                                               lto=lto)
         # Enable object caching upfront so that the library can be serialized.
         state.library.enable_object_caching()
 
@@ -152,7 +179,7 @@ def define_cuda_lowering_pipeline(self, state):
 @global_compiler_lock
 def compile_cuda(pyfunc, return_type, args, debug=False, lineinfo=False,
                  inline=False, fastmath=False, nvvm_options=None,
-                 cc=None):
+                 cc=None, max_registers=None, lto=False):
     if cc is None:
         raise ValueError('Compute Capability must be supplied')
@@ -189,6 +216,8 @@ def compile_cuda(pyfunc, return_type, args, debug=False, lineinfo=False,
     if nvvm_options:
         flags.nvvm_options = nvvm_options
     flags.compute_capability = cc
+    flags.max_registers = max_registers
+    flags.lto = lto
 
     # Run compilation pipeline
     from numba.core.target_extension import target_override
@@ -247,11 +276,155 @@ def cabi_wrap_function(context, lib, fndesc, wrapper_function_name,
         builder, func, restype, argtypes, callargs)
     builder.ret(return_value)
 
+    if config.DUMP_LLVM:
+        utils.dump_llvm(fndesc, wrapper_module)
+
     library.add_ir_module(wrapper_module)
     library.finalize()
 
     return library
 
 
+def kernel_fixup(kernel, debug):
+    if debug:
+        exc_helper = add_exception_store_helper(kernel)
+
+    # Pass 1 - replace:
+    #
+    #     ret <value>
+    #
+    # with:
+    #
+    #     exc_helper(<value>)
+    #     ret void
+
+    for block in kernel.blocks:
+        for i, inst in enumerate(block.instructions):
+            if isinstance(inst, ir.Ret):
+                old_ret = block.instructions.pop()
+                block.terminator = None
+
+                # The original return's metadata will be set on the new
+                # instructions in order to preserve debug info
+                metadata = old_ret.metadata
+
+                builder = ir.IRBuilder(block)
+                if debug:
+                    status_code = old_ret.operands[0]
+                    exc_helper_call = builder.call(exc_helper, (status_code,))
+                    exc_helper_call.metadata = metadata
+
+                new_ret = builder.ret_void()
+                new_ret.metadata = old_ret.metadata
+
+                # Need to break out so we don't carry on modifying what we are
+                # iterating over. There can only be one return in a block
+                # anyway.
+                break
+
+    # Pass 2: remove stores of null pointer to return value argument pointer
+
+    return_value = kernel.args[0]
+
+    for block in kernel.blocks:
+        remove_list = []
+
+        # Find all stores first
+        for inst in block.instructions:
+            if (isinstance(inst, ir.StoreInstr)
+                    and inst.operands[1] == return_value):
+                remove_list.append(inst)
+
+        # Remove all stores
+        for to_remove in remove_list:
+            block.instructions.remove(to_remove)
+
+    # Replace non-void return type with void return type and remove return
+    # value
+
+    if isinstance(kernel.type, ir.PointerType):
+        new_type = ir.PointerType(ir.FunctionType(ir.VoidType(),
+                                                  kernel.type.pointee.args[1:]))
+    else:
+        new_type = ir.FunctionType(ir.VoidType(), kernel.type.args[1:])
+
+    kernel.type = new_type
+    kernel.return_value = ir.ReturnValue(kernel, ir.VoidType())
+    kernel.args = kernel.args[1:]
+
+    # Mark as a kernel for NVVM
+
+    nvvm.set_cuda_kernel(kernel)
+
+    if config.DUMP_LLVM:
+        print(f"LLVM DUMP: Post kernel fixup {kernel.name}".center(80, '-'))
+        print(kernel.module)
+        print('=' * 80)
+
+
+def add_exception_store_helper(kernel):
+
+    # Create global variables for exception state
+
+    def define_error_gv(postfix):
+        name = kernel.name + postfix
+        gv = cgutils.add_global_variable(kernel.module, ir.IntType(32),
+                                         name)
+        gv.initializer = ir.Constant(gv.type.pointee, None)
+        return gv
+
+    gv_exc = define_error_gv("__errcode__")
+    gv_tid = []
+    gv_ctaid = []
+    for i in 'xyz':
+        gv_tid.append(define_error_gv("__tid%s__" % i))
+        gv_ctaid.append(define_error_gv("__ctaid%s__" % i))
+
+    # Create exception store helper function
+
+    helper_name = kernel.name + "__exc_helper__"
+    helper_type = ir.FunctionType(ir.VoidType(), (ir.IntType(32),))
+    helper_func = ir.Function(kernel.module, helper_type, helper_name)
+
+    block = helper_func.append_basic_block(name="entry")
+    builder = ir.IRBuilder(block)
+
+    # Implement status check / exception store logic
+
+    status_code = helper_func.args[0]
+    call_conv = cuda_target.target_context.call_conv
+    status = call_conv._get_return_status(builder, status_code)
+
+    # Check error status
+    with cgutils.if_likely(builder, status.is_ok):
+        builder.ret_void()
+
+    with builder.if_then(builder.not_(status.is_python_exc)):
+        # User exception raised
+        old = ir.Constant(gv_exc.type.pointee, None)
+
+        # Use atomic cmpxchg to prevent rewriting the error status
+        # Only the first error is recorded
+
+        xchg = builder.cmpxchg(gv_exc, old, status.code,
+                               'monotonic', 'monotonic')
+        changed = builder.extract_value(xchg, 1)
+
+        # If the exchange is successful, save the thread ID.
+        sreg = nvvmutils.SRegBuilder(builder)
+        with builder.if_then(changed):
+            for dim, ptr in zip("xyz", gv_tid):
+                val = sreg.tid(dim)
+                builder.store(val, ptr)
+
+            for dim, ptr in zip("xyz", gv_ctaid):
+                val = sreg.ctaid(dim)
+                builder.store(val, ptr)
+
+    builder.ret_void()
+
+    return helper_func
+
+
 @global_compiler_lock
 def compile(pyfunc, sig, debug=None, lineinfo=False, device=True,
             fastmath=False, cc=None, opt=None, abi="c", abi_info=None,
@@ -347,13 +520,10 @@ def compile(pyfunc, sig, debug=None, lineinfo=False, device=True,
         lib = cabi_wrap_function(tgt, lib, cres.fndesc, wrapper_name,
                                  nvvm_options)
     else:
-        code = pyfunc.__code__
-        filename = code.co_filename
-        linenum = code.co_firstlineno
-
-        lib, kernel = tgt.prepare_cuda_kernel(cres.library, cres.fndesc, debug,
-                                              lineinfo, nvvm_options, filename,
-                                              linenum)
+        lib = cres.library
+        kernel = lib.get_function(cres.fndesc.llvm_func_name)
+        lib._entry_name = cres.fndesc.llvm_func_name
+        kernel_fixup(kernel, debug)
 
     if lto:
         code = lib.get_ltoir(cc=cc)

diff --git a/numba_cuda/numba/cuda/dispatcher.py b/numba_cuda/numba/cuda/dispatcher.py
index a483119..d4f28fc 100644
--- a/numba_cuda/numba/cuda/dispatcher.py
+++ b/numba_cuda/numba/cuda/dispatcher.py
@@ -14,7 +14,7 @@ from numba.cuda.api import get_current_device
 from numba.cuda.args import wrap_arg
-from numba.cuda.compiler import compile_cuda, CUDACompiler
+from numba.cuda.compiler import compile_cuda, CUDACompiler, kernel_fixup
 from numba.cuda.cudadrv import driver
 from numba.cuda.cudadrv.devices import get_context
 from numba.cuda.descriptor import cuda_target
@@ -102,15 +102,14 @@ def __init__(self, py_func, argtypes, link=None, debug=False,
                             inline=inline,
                             fastmath=fastmath,
                             nvvm_options=nvvm_options,
-                            cc=cc)
+                            cc=cc,
+                            max_registers=max_registers,
+                            lto=lto)
         tgt_ctx = cres.target_context
-        code = self.py_func.__code__
-        filename = code.co_filename
-        linenum = code.co_firstlineno
-        lib, kernel = tgt_ctx.prepare_cuda_kernel(cres.library, cres.fndesc,
-                                                  debug, lineinfo, nvvm_options,
-                                                  filename, linenum,
-                                                  max_registers, lto)
+        lib = cres.library
+        kernel = lib.get_function(cres.fndesc.llvm_func_name)
+        lib._entry_name = cres.fndesc.llvm_func_name
+        kernel_fixup(kernel, self.debug)
 
         if not link:
             link = []

diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py b/numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py
index efe42b5..44a2020 100644
--- a/numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py
+++ b/numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py
@@ -72,6 +72,7 @@ def test_issue_5835(self):
         def f(x):
             x[0] = 0
 
+    @unittest.skip("Wrappers no longer exist")
     def test_wrapper_has_debuginfo(self):
         sig = (types.int32[::1],)

diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_inspect.py b/numba_cuda/numba/cuda/tests/cudapy/test_inspect.py
index 20d6792..5c122db 100644
--- a/numba_cuda/numba/cuda/tests/cudapy/test_inspect.py
+++ b/numba_cuda/numba/cuda/tests/cudapy/test_inspect.py
@@ -33,10 +33,7 @@ def foo(x, y):
         self.assertIn("foo", llvm)
 
         # Kernel in LLVM
-        self.assertIn('cuda.kernel.wrapper', llvm)
-
-        # Wrapped device function body in LLVM
-        self.assertIn("define linkonce_odr i32", llvm)
+        self.assertIn("define void @", llvm)
 
         asm = foo.inspect_asm(sig)
@@ -72,12 +69,8 @@ def foo(x, y):
         self.assertIn("foo", llvmirs[float64, float64])
 
         # Kernels in LLVM
-        self.assertIn('cuda.kernel.wrapper', llvmirs[intp, intp])
-        self.assertIn('cuda.kernel.wrapper', llvmirs[float64, float64])
-
-        # Wrapped device function bodies in LLVM
-        self.assertIn("define linkonce_odr i32",
-                      llvmirs[intp, intp])
-        self.assertIn("define linkonce_odr i32", llvmirs[float64, float64])
+        self.assertIn("define void @", llvmirs[intp, intp])
+        self.assertIn("define void @", llvmirs[float64, float64])
 
         asmdict = foo.inspect_asm()

diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py b/numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py
index b509387..182873b 100644
--- a/numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py
+++ b/numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py
@@ -170,10 +170,9 @@ def caller(x):
                     subprograms += 1
 
             # One DISubprogram for each of:
-            # - The kernel wrapper
             # - The caller
             # - The callee
-            expected_subprograms = 3
+            expected_subprograms = 2
 
             self.assertEqual(subprograms, expected_subprograms,
                              f'"Expected {expected_subprograms} DISubprograms; '

diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_optimization.py b/numba_cuda/numba/cuda/tests/cudapy/test_optimization.py
index 812b1cf..2739972 100644
--- a/numba_cuda/numba/cuda/tests/cudapy/test_optimization.py
+++ b/numba_cuda/numba/cuda/tests/cudapy/test_optimization.py
@@ -14,8 +14,11 @@ def device_func(x, y, z):
 
 
 # Fragments of code that are removed from kernel_func's PTX when optimization
-# is on
-removed_by_opt = ( '__local_depot0', 'call.uni', 'st.param.b64')
+# is on. Previously this list was longer when kernel wrappers were used - if
+# the test function were more complex it may be possible to isolate additional
+# fragments of PTX we could check for the absence / presence of, but removal of
+# the use of local memory is a good indicator that optimization was applied.
+removed_by_opt = ( '__local_depot0',)
 
 
 @skip_on_cudasim('Simulator does not optimize code')
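
The following standalone sketch (not part of the patch) may help in following
the IR surgery that kernel_fixup performs: it reproduces the two passes on a
toy function using plain llvmlite. The toy function, its arguments, and its
name are hypothetical; only the llvmlite calls mirror the ones used in the
patch, and the assumed ABI is the one described above (argument 0 is a pointer
through which the return value is stored, and an i32 status code is returned).

from llvmlite import ir

i32 = ir.IntType(32)

# Toy stand-in for a Numba-compiled device function: arg 0 is the return
# value pointer, arg 1 is a payload, and an i32 status code is returned.
mod = ir.Module(name="sketch")
fnty = ir.FunctionType(i32, (i32.as_pointer(), i32))
fn = ir.Function(mod, fnty, name="toy_device_func")
builder = ir.IRBuilder(fn.append_basic_block("entry"))
builder.store(fn.args[1], fn.args[0])   # store through the return pointer
builder.ret(ir.Constant(i32, 0))        # "no error" status code

# Pass 1: replace `ret i32 0` with `ret void`.
for block in fn.blocks:
    for inst in list(block.instructions):
        if isinstance(inst, ir.Ret):
            block.instructions.remove(inst)
            block.terminator = None
            ir.IRBuilder(block).ret_void()
            break

# Pass 2: drop stores through the return value pointer; after the signature
# rewrite below they would reference a removed argument.
retptr = fn.args[0]
for block in fn.blocks:
    for inst in list(block.instructions):
        if isinstance(inst, ir.StoreInstr) and inst.operands[1] is retptr:
            block.instructions.remove(inst)

# Rewrite the signature in place, as kernel_fixup does: a void return type,
# with the return value pointer argument removed.
fn.type = ir.PointerType(ir.FunctionType(ir.VoidType(),
                                         fn.type.pointee.args[1:]))
fn.return_value = ir.ReturnValue(fn, ir.VoidType())
fn.args = fn.args[1:]

print(mod)  # the prototype is now: define void @"toy_device_func"(i32 ...)

Rewriting the device function in place rather than emitting a separate wrapper
is also what reduces the expected DISubprogram count from 3 to 2 in
test_lineinfo.py above: there is no longer a wrapper to carry its own
debug info.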