Skip to content

Commit

Permalink
Minor fixups following #23 / #56
Browse files Browse the repository at this point in the history
- Update the codegen class docstring for LTO.
- Simplify / correct some logic in `_readenv()` (`value.lower()` could
  never be `"True"`, only `"true"`).
- Simplify additional flags and linker checks.
- Setting `self._linker._complete` in `complete()` is unnecessary, as
  calling `get_linked_cubin()` sets the link as complete already.
  • Loading branch information
gmarkall committed Oct 22, 2024
1 parent 7e01ab0 commit 197c80b
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
4 changes: 3 additions & 1 deletion numba_cuda/numba/cuda/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,16 @@ def __init__(
):
"""
codegen:
Codegen object.
Codegen object.
name:
Name of the function in the source.
entry_name:
Name of the kernel function in the binary, if this is a global
kernel and not a device function.
max_registers:
The maximum register usage to aim for when linking.
lto:
Whether to enable link-time optimization.
nvvm_options:
Dict of options to pass to NVVM.
"""
Expand Down
8 changes: 3 additions & 5 deletions numba_cuda/numba/cuda/cudadrv/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def _readenv(name, ctor, default):
return default() if callable(default) else default
try:
if ctor is bool:
return bool(value.lower() in {'1', "True"})
return value.lower() in {'1', "true"}
return ctor(value)
except Exception:
warnings.warn(
Expand Down Expand Up @@ -2631,7 +2631,7 @@ def new(cls,

if linker is PyNvJitLinker:
return linker(max_registers, lineinfo, cc, lto, additional_flags)
elif additional_flags is not None or lto is True:
elif additional_flags or lto:
raise ValueError("LTO and additional flags require PyNvJitLinker")
else:
return linker(max_registers, lineinfo, cc)
Expand Down Expand Up @@ -3088,9 +3088,7 @@ def add_data(self, data, kind, name):

def complete(self):
try:
cubin = self._linker.get_linked_cubin()
self._linker._complete = True
return cubin
return self._linker.get_linked_cubin()
except NvJitLinkError as e:
raise LinkerError from e

Expand Down

0 comments on commit 197c80b

Please sign in to comment.