-
Notifications
You must be signed in to change notification settings - Fork 921
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use
pynvjitlink
for CUDA 12+ MVC (#13650)
Fixes #12822 This PR provides minor version compatibility in the CUDA 12.x range through `nvjitlink` via the preliminary [nvjiitlink python binding](https://github.com/gmarkall/nvjitlink). Thus far this PR merely leverages a local installation of the library and should not be merged until `nvjitlink` is hosted on `conda-forge` and cuDF's dependencies are adjusted accordingly, likely as part of this PR. Authors: - https://github.com/brandon-b-miller - Ashwin Srinath (https://github.com/shwina) Approvers: - Bradley Dice (https://github.com/bdice) - Ashwin Srinath (https://github.com/shwina) URL: #13650
- Loading branch information
1 parent
3ef13d0
commit 823d321
Showing
3 changed files
with
128 additions
and
72 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
# Copyright (c) 2023, NVIDIA CORPORATION. | ||
import subprocess | ||
import sys | ||
|
||
import pytest | ||
|
||
IS_CUDA_11 = False | ||
IS_CUDA_12 = False | ||
try: | ||
from ptxcompiler.patch import safe_get_versions | ||
except ModuleNotFoundError: | ||
from cudf.utils._ptxcompiler import safe_get_versions | ||
|
||
# do not test cuda 12 if pynvjitlink isn't present | ||
HAVE_PYNVJITLINK = False | ||
try: | ||
import pynvjitlink # noqa: F401 | ||
|
||
HAVE_PYNVJITLINK = True | ||
except ModuleNotFoundError: | ||
pass | ||
|
||
|
||
versions = safe_get_versions() | ||
driver_version, runtime_version = versions | ||
|
||
if (11, 0) <= driver_version < (12, 0): | ||
IS_CUDA_11 = True | ||
if (12, 0) <= driver_version < (13, 0): | ||
IS_CUDA_12 = True | ||
|
||
|
||
TEST_BODY = """ | ||
@numba.cuda.jit | ||
def test_kernel(x): | ||
id = numba.cuda.grid(1) | ||
if id < len(x): | ||
x[id] += 1 | ||
s = cudf.Series([1, 2, 3]) | ||
with _CUDFNumbaConfig(): | ||
test_kernel.forall(len(s))(s) | ||
""" | ||
|
||
CUDA_11_TEST = ( | ||
""" | ||
import numba.cuda | ||
import cudf | ||
from cudf.utils._numba import _CUDFNumbaConfig, patch_numba_linker_cuda_11 | ||
patch_numba_linker_cuda_11() | ||
""" | ||
+ TEST_BODY | ||
) | ||
|
||
|
||
CUDA_12_TEST = ( | ||
""" | ||
import numba.cuda | ||
import cudf | ||
from cudf.utils._numba import _CUDFNumbaConfig | ||
from pynvjitlink.patch import ( | ||
patch_numba_linker as patch_numba_linker_pynvjitlink, | ||
) | ||
patch_numba_linker_pynvjitlink() | ||
""" | ||
+ TEST_BODY | ||
) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"test", | ||
[ | ||
pytest.param( | ||
CUDA_11_TEST, | ||
marks=pytest.mark.skipif( | ||
not IS_CUDA_11, | ||
reason="Minor Version Compatibility test for CUDA 11", | ||
), | ||
), | ||
pytest.param( | ||
CUDA_12_TEST, | ||
marks=pytest.mark.skipif( | ||
not IS_CUDA_12 or not HAVE_PYNVJITLINK, | ||
reason="Minor Version Compatibility test for CUDA 12", | ||
), | ||
), | ||
], | ||
) | ||
def test_numba_mvc(test): | ||
cp = subprocess.run( | ||
[sys.executable, "-c", test], | ||
capture_output=True, | ||
cwd="/", | ||
) | ||
|
||
assert cp.returncode == 0 |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters