Skip to content

Commit

Permalink
style
Browse files Browse the repository at this point in the history
  • Loading branch information
isVoid committed Aug 27, 2024
1 parent e50e8d8 commit c474146
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
7 changes: 5 additions & 2 deletions numba_cuda/numba/cuda/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,11 @@ def get_cubin(self, cc=None):
return cubin

if config.DUMP_ASSEMBLY:
linker = driver.Linker.new(max_registers=self._max_registers, cc=cc, additional_flags=["-ptx"])
linker = driver.Linker.new(
max_registers=self._max_registers,
cc=cc,
additional_flags=["-ptx"]
)
self._link_all(linker, cc)
ptx = linker.get_linked_ptx().decode('utf-8')

Expand All @@ -201,7 +205,6 @@ def get_cubin(self, cc=None):
print(ptx)
print('=' * 80)


linker = driver.Linker.new(max_registers=self._max_registers, cc=cc)
self._link_all(linker, cc)
cubin = linker.complete()
Expand Down
11 changes: 8 additions & 3 deletions numba_cuda/numba/cuda/cudadrv/nvrtc.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ class NVRTC:
# nvrtcResult nvrtcCompileProgram(nvrtcProgram prog,
# int numOptions,
# const char * const *options)
"nvrtcCompileProgram": (nvrtc_result, nvrtc_program, c_int, POINTER(c_char_p)),
"nvrtcCompileProgram": (
nvrtc_result, nvrtc_program, c_int, POINTER(c_char_p)
),
# nvrtcResult nvrtcGetPTXSize(nvrtcProgram prog, size_t *ptxSizeRet);
"nvrtcGetPTXSize": (nvrtc_result, nvrtc_program, POINTER(c_size_t)),
# nvrtcResult nvrtcGetPTX(nvrtcProgram prog, char *ptx);
Expand All @@ -106,7 +108,9 @@ class NVRTC:
"nvrtcGetCUBIN": (nvrtc_result, nvrtc_program, c_char_p),
# nvrtcResult nvrtcGetProgramLogSize(nvrtcProgram prog,
# size_t *logSizeRet);
"nvrtcGetProgramLogSize": (nvrtc_result, nvrtc_program, POINTER(c_size_t)),
"nvrtcGetProgramLogSize": (
nvrtc_result, nvrtc_program, POINTER(c_size_t)
),
# nvrtcResult nvrtcGetProgramLog(nvrtcProgram prog, char *log);
"nvrtcGetProgramLog": (nvrtc_result, nvrtc_program, c_char_p),
}
Expand Down Expand Up @@ -142,7 +146,8 @@ def checked_call(*args, func=func, name=name):
error_name = NvrtcResult(error).name
except ValueError:
error_name = (
"Unknown nvrtc_result " f"(error code: {error})"
"Unknown nvrtc_result "
f"(error code: {error})"
)
msg = f"Failed to call {name}: {error_name}"
raise NvrtcError(msg)
Expand Down

0 comments on commit c474146

Please sign in to comment.