Skip to content

Commit

Permalink
small updates
Browse files Browse the repository at this point in the history
  • Loading branch information
brandon-b-miller committed Oct 8, 2024
1 parent ff18c5c commit b2f4245
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 16 deletions.
4 changes: 1 addition & 3 deletions numba_cuda/numba/cuda/cudadrv/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
system to freeze in some cases.
"""

import sys
import os
import ctypes
Expand Down Expand Up @@ -85,12 +84,11 @@ def _readenv(name, ctor, default):

ENABLE_PYNVJITLINK = (
_readenv("ENABLE_PYNVJITLINK", bool, False)
or getattr(config, "ENABLE_PYNVJITLINK", None)
or getattr(config, "ENABLE_PYNVJITLINK", False)
)
if not hasattr(config, "ENABLE_PYNVJITLINK"):
config.ENABLE_PYNVJITLINK = ENABLE_PYNVJITLINK


if ENABLE_PYNVJITLINK:
try:
from pynvjitlink.api import NvJitLinker, NvJitLinkError
Expand Down
16 changes: 3 additions & 13 deletions numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,8 @@
from numba import cuda
from numba import config

HAVE_PYNVJITLINK = False
try:
import pynvjitlink # noqa: F401
from pynvjitlink.api import NvJitLinkError

HAVE_PYNVJITLINK = True
except ImportError:
pass


@unittest.skipIf(not HAVE_PYNVJITLINK, "pynvjitlink not available")
@unittest.skipIf(config.ENABLE_PYNVJITLINK, "pynvjitlink not enabled")
@skip_on_cudasim("Linking unsupported in the simulator")
class TestLinker(CUDATestCase):
_NUMBA_NVIDIA_BINDING_0_ENV = {"NUMBA_CUDA_USE_NVIDIA_BINDING": "0"}
Expand All @@ -35,6 +26,8 @@ def test_nvjitlink_create_no_cc_error(self):
PyNvJitLinker()

def test_nvjitlink_invalid_arch_error(self):
from pynvjitlink.api import NvJitLinkError

# CC 0.0 is not a valid compute capability
with self.assertRaisesRegex(
NvJitLinkError, "NVJITLINK_ERROR_UNRECOGNIZED_OPTION error"
Expand Down Expand Up @@ -126,7 +119,6 @@ def test_nvjitlink_test_add_file_guess_ext_invalid_input(self):
# because there's no way to know what kind of file to treat it as
patched_linker.add_file_guess_ext(content)

@unittest.skipIf(not HAVE_PYNVJITLINK, "pynvjitlink not available")
def test_nvjitlink_jit_with_linkable_code(self):
files = (
"test_device_functions.a",
Expand All @@ -138,8 +130,6 @@ def test_nvjitlink_jit_with_linkable_code(self):
)
for file in files:
with self.subTest(file=file):
# TODO: unsafe teardown if test errors
config.ENABLE_PYNVJITLINK = True
sig = "uint32(uint32, uint32)"
add_from_numba = cuda.declare_device("add_from_numba", sig)

Expand Down

0 comments on commit b2f4245

Please sign in to comment.