From c5b7df4495b0ce98e71b98e466d27bfa463e5767 Mon Sep 17 00:00:00 2001
From: isVoid
Date: Sun, 27 Oct 2024 22:40:47 -0700
Subject: [PATCH] determine cuda include paths based on the machine kind

---
 numba_cuda/numba/cuda/cuda_paths.py    | 38 +++++++++++++++++++++++++-
 numba_cuda/numba/cuda/cudadrv/nvrtc.py |  9 ++++----
 2 files changed, 42 insertions(+), 5 deletions(-)

diff --git a/numba_cuda/numba/cuda/cuda_paths.py b/numba_cuda/numba/cuda/cuda_paths.py
index ac5475a..a45c175 100644
--- a/numba_cuda/numba/cuda/cuda_paths.py
+++ b/numba_cuda/numba/cuda/cuda_paths.py
@@ -2,6 +2,7 @@
 import re
 import os
 from collections import namedtuple
+import platform
 
 from numba.core.config import IS_WIN32
 from numba.misc.findlib import find_lib, find_file
@@ -259,14 +260,49 @@ def get_debian_pkg_libdevice():
     return pkg_libdevice_location
 
 
+def get_current_cuda_target_name():
+    """Determine conda's CTK target folder based on system and machine arch.
+
+    CTK's conda package delivers headers based on its architecture type.
+    For example, an `x86_64` machine places headers under
+    `$CONDA_PREFIX/targets/x86_64-linux`, and an `aarch64` machine places
+    them under `$CONDA_PREFIX/targets/sbsa-linux`. Read more about the
+    nuances at cudart's conda feedstock:
+    https://github.com/conda-forge/cuda-cudart-feedstock/blob/main/recipe/meta.yaml#L8-L11  # noqa: E501
+
+    Returns the target folder name, or None when the system/machine
+    combination is not a known target.
+    """
+    system = platform.system()
+    machine = platform.machine()
+
+    if system == "Linux":
+        arch_to_targets = {
+            'x86_64': 'x86_64-linux',
+            'aarch64': 'sbsa-linux'
+        }
+        return arch_to_targets.get(machine)
+
+    return None
+
+
 def get_conda_include_dir():
     """
     Return the include directory in the current conda environment, if one
     is active and it exists.
     """
     conda_prefix = os.environ.get('CONDA_PREFIX')
+    target_name = get_current_cuda_target_name()
+
     if conda_prefix:
-        include_dir = os.path.join(conda_prefix, 'include')
+        if target_name:
+            include_dir = os.path.join(
+                conda_prefix, 'targets', target_name, 'include'
+            )
+        else:
+            # Fall back to the environment's top-level include directory
+            # when the target cannot be determined.
+            include_dir = os.path.join(conda_prefix, 'include')
         if os.path.exists(include_dir):
             return include_dir
     return None
diff --git a/numba_cuda/numba/cuda/cudadrv/nvrtc.py b/numba_cuda/numba/cuda/cudadrv/nvrtc.py
index 2483b4a..d6d64bd 100644
--- a/numba_cuda/numba/cuda/cudadrv/nvrtc.py
+++ b/numba_cuda/numba/cuda/cudadrv/nvrtc.py
@@ -3,7 +3,7 @@
 from numba.core import config
 from numba.cuda.cudadrv.error import (NvrtcError, NvrtcCompilationError,
                                       NvrtcSupportError)
-from numba.cuda.cuda_paths import _get_include_dir
+from numba.cuda.cuda_paths import get_cuda_paths
 import functools
 import os
 import threading
@@ -233,12 +233,13 @@ def compile(src, name, cc):
     # being optimized away.
     major, minor = cc
     arch = f'--gpu-architecture=compute_{major}{minor}'
-    include = f'-I{config.CUDA_INCLUDE_PATH}'
+
+    cuda_include = f"-I{get_cuda_paths()['include_dir'].info}"
 
     cudadrv_path = os.path.dirname(os.path.abspath(__file__))
     numba_cuda_path = os.path.dirname(cudadrv_path)
-    numba_include = f'-I{numba_cuda_path} -I{_get_include_dir()}'
-    options = [arch, include, numba_include, '-rdc', 'true']
+    numba_include = f'-I{numba_cuda_path}'
+    options = [arch, cuda_include, numba_include, '-rdc', 'true']
 
     # Compile the program
     compile_error = nvrtc.compile_program(program, options)