Skip to content

Commit

Permalink
Merge pull request #1 from isVoid/numba-cuda-runtime
Browse files Browse the repository at this point in the history
Determine conda include path based on machine kind
  • Loading branch information
brandon-b-miller authored Oct 28, 2024
2 parents 5833c87 + c5b7df4 commit f882ea1
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 5 deletions.
30 changes: 29 additions & 1 deletion numba_cuda/numba/cuda/cuda_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import re
import os
from collections import namedtuple
import platform

from numba.core.config import IS_WIN32
from numba.misc.findlib import find_lib, find_file
Expand Down Expand Up @@ -259,14 +260,41 @@ def get_debian_pkg_libdevice():
return pkg_libdevice_location


def get_current_cuda_target_name():
"""Determine conda's CTK target folder based on system and machine arch.
CTK's conda package delivers headers based on its architecture type. For example,
`x86_64` machine places header under `$CONDA_PREFIX/targets/x86_64-linux`, and
`aarch64` places under `$CONDA_PREFIX/targets/sbsa-linux`. Read more about the
nuances at cudart's conda feedstock:
https://github.com/conda-forge/cuda-cudart-feedstock/blob/main/recipe/meta.yaml#L8-L11 # noqa: E501
"""
system = platform.system()
machine = platform.machine()

if system == "Linux":
arch_to_targets = {
'x86_64': 'x86_64-linux',
'aarch64': 'sbsa-linux'
}
return arch_to_targets.get(machine)

return None

def get_conda_include_dir():
"""
Return the include directory in the current conda environment, if one
is active and it exists.
"""
conda_prefix = os.environ.get('CONDA_PREFIX')
target_name = get_current_cuda_target_name()

if conda_prefix:
include_dir = os.path.join(conda_prefix, 'include')
if target_name:
include_dir = os.path.join(conda_prefix, f'targets/{target_name}/include')
else:
# A fallback when target cannot determined, though usually it shouldn't.
include_dir = os.path.join(conda_prefix, f'include')
if os.path.exists(include_dir):
return include_dir
return None
Expand Down
9 changes: 5 additions & 4 deletions numba_cuda/numba/cuda/cudadrv/nvrtc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from numba.core import config
from numba.cuda.cudadrv.error import (NvrtcError, NvrtcCompilationError,
NvrtcSupportError)
from numba.cuda.cuda_paths import _get_include_dir
from numba.cuda.cuda_paths import get_cuda_paths
import functools
import os
import threading
Expand Down Expand Up @@ -233,12 +233,13 @@ def compile(src, name, cc):
# being optimized away.
major, minor = cc
arch = f'--gpu-architecture=compute_{major}{minor}'
include = f'-I{config.CUDA_INCLUDE_PATH}'

cuda_include = f"-I{get_cuda_paths()['include_dir'].info}"

cudadrv_path = os.path.dirname(os.path.abspath(__file__))
numba_cuda_path = os.path.dirname(cudadrv_path)
numba_include = f'-I{numba_cuda_path} -I{_get_include_dir()}'
options = [arch, include, numba_include, '-rdc', 'true']
numba_include = f'-I{numba_cuda_path}'
options = [arch, cuda_include, numba_include, '-rdc', 'true']

# Compile the program
compile_error = nvrtc.compile_program(program, options)
Expand Down

0 comments on commit f882ea1

Please sign in to comment.