From c474146e78c59cda25aa9464a2670c5ec2d5423f Mon Sep 17 00:00:00 2001 From: Michael Yh Wang Date: Tue, 27 Aug 2024 09:00:04 +0800 Subject: [PATCH] style --- numba_cuda/numba/cuda/codegen.py | 7 +++++-- numba_cuda/numba/cuda/cudadrv/nvrtc.py | 11 ++++++++--- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/numba_cuda/numba/cuda/codegen.py b/numba_cuda/numba/cuda/codegen.py index 4fd9c89..1087f37 100644 --- a/numba_cuda/numba/cuda/codegen.py +++ b/numba_cuda/numba/cuda/codegen.py @@ -192,7 +192,11 @@ def get_cubin(self, cc=None): return cubin if config.DUMP_ASSEMBLY: - linker = driver.Linker.new(max_registers=self._max_registers, cc=cc, additional_flags=["-ptx"]) + linker = driver.Linker.new( + max_registers=self._max_registers, + cc=cc, + additional_flags=["-ptx"] + ) self._link_all(linker, cc) ptx = linker.get_linked_ptx().decode('utf-8') @@ -201,7 +205,6 @@ def get_cubin(self, cc=None): print(ptx) print('=' * 80) - linker = driver.Linker.new(max_registers=self._max_registers, cc=cc) self._link_all(linker, cc) cubin = linker.complete() diff --git a/numba_cuda/numba/cuda/cudadrv/nvrtc.py b/numba_cuda/numba/cuda/cudadrv/nvrtc.py index 4c19e5a..9b14b6f 100644 --- a/numba_cuda/numba/cuda/cudadrv/nvrtc.py +++ b/numba_cuda/numba/cuda/cudadrv/nvrtc.py @@ -90,7 +90,9 @@ class NVRTC: # nvrtcResult nvrtcCompileProgram(nvrtcProgram prog, # int numOptions, # const char * const *options) - "nvrtcCompileProgram": (nvrtc_result, nvrtc_program, c_int, POINTER(c_char_p)), + "nvrtcCompileProgram": ( + nvrtc_result, nvrtc_program, c_int, POINTER(c_char_p) + ), # nvrtcResult nvrtcGetPTXSize(nvrtcProgram prog, size_t *ptxSizeRet); "nvrtcGetPTXSize": (nvrtc_result, nvrtc_program, POINTER(c_size_t)), # nvrtcResult nvrtcGetPTX(nvrtcProgram prog, char *ptx); @@ -106,7 +108,9 @@ class NVRTC: "nvrtcGetCUBIN": (nvrtc_result, nvrtc_program, c_char_p), # nvrtcResult nvrtcGetProgramLogSize(nvrtcProgram prog, # size_t *logSizeRet); - "nvrtcGetProgramLogSize": (nvrtc_result, nvrtc_program, POINTER(c_size_t)), + "nvrtcGetProgramLogSize": ( + nvrtc_result, nvrtc_program, POINTER(c_size_t) + ), # nvrtcResult nvrtcGetProgramLog(nvrtcProgram prog, char *log); "nvrtcGetProgramLog": (nvrtc_result, nvrtc_program, c_char_p), } @@ -142,7 +146,8 @@ def checked_call(*args, func=func, name=name): error_name = NvrtcResult(error).name except ValueError: error_name = ( - "Unknown nvrtc_result " f"(error code: {error})" + "Unknown nvrtc_result " + f"(error code: {error})" ) msg = f"Failed to call {name}: {error_name}" raise NvrtcError(msg)