Skip to content

Commit

Permalink
[Hardware][NVIDIA] Add non-NVML CUDA mode for Jetson (vllm-project#9735)
Browse files Browse the repository at this point in the history
Signed-off-by: Conroy Cheers <[email protected]>
  • Loading branch information
conroy-cheers authored and weilong.yu committed Dec 13, 2024
1 parent 35bcd10 commit 36bf87e
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 87 deletions.
10 changes: 5 additions & 5 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12")

# Supported NVIDIA architectures.
set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0")
set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0")

# Supported AMD GPU architectures.
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101")
Expand Down Expand Up @@ -249,7 +249,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# Only build Marlin kernels if we are building for at least some compatible archs.
# Keep building Marlin for 9.0 as there are some group sizes and shapes that
# are not supported by Machete yet.
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.9;9.0" ${CUDA_ARCHS})
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0" ${CUDA_ARCHS})
if (MARLIN_ARCHS)
set(MARLIN_SRCS
"csrc/quantization/fp8/fp8_marlin.cu"
Expand Down Expand Up @@ -300,8 +300,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
#
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
# kernels for the remaining archs that are not already built for 3x.
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
"7.5;8.0;8.6;8.9;9.0" "${CUDA_ARCHS}")
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
"7.5;8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}")
# subtract out the archs that are already built for 3x
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
if (SCALED_MM_2X_ARCHS)
Expand Down Expand Up @@ -427,7 +427,7 @@ set_gencode_flags_for_srcs(
CUDA_ARCHS "${CUDA_ARCHS}")

if(VLLM_GPU_LANG STREQUAL "CUDA")
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.9;9.0" "${CUDA_ARCHS}")
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}")
if (MARLIN_MOE_ARCHS)
set(MARLIN_MOE_SRC
"csrc/moe/marlin_kernels/marlin_moe_kernel.h"
Expand Down
10 changes: 9 additions & 1 deletion vllm/platforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,15 @@
finally:
pynvml.nvmlShutdown()
except Exception:
pass
# CUDA is supported on Jetson, but NVML may not be.
import os

def cuda_is_jetson() -> bool:
return os.path.isfile("/etc/nv_tegra_release") \
or os.path.exists("/sys/class/tegra-firmware")

if cuda_is_jetson():
is_cuda = True

is_rocm = False

Expand Down
222 changes: 141 additions & 81 deletions vllm/platforms/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import os
from functools import lru_cache, wraps
from typing import TYPE_CHECKING, Callable, List, Tuple, TypeVar
from typing import TYPE_CHECKING, Callable, List, TypeVar

import pynvml
import torch
Expand Down Expand Up @@ -38,10 +38,23 @@
# see https://github.com/huggingface/diffusers/issues/9704 for details
torch.backends.cuda.enable_cudnn_sdp(False)

# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
# all the related functions work on real physical device ids.
# the major benefit of using NVML is that it will not initialize CUDA

def device_id_to_physical_device_id(device_id: int) -> int:
if "CUDA_VISIBLE_DEVICES" in os.environ:
device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
if device_ids == [""]:
msg = (
"CUDA_VISIBLE_DEVICES is set to empty string, which means"
" GPU support is disabled. If you are using ray, please unset"
" the environment variable `CUDA_VISIBLE_DEVICES` inside the"
" worker/actor. "
"Check https://github.com/vllm-project/vllm/issues/8402 for"
" more information.")
raise RuntimeError(msg)
physical_device_id = device_ids[device_id]
return int(physical_device_id)
else:
return device_id


def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
Expand All @@ -57,87 +70,75 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
return wrapper


@lru_cache(maxsize=8)
@with_nvml_context
def get_physical_device_capability(device_id: int = 0) -> Tuple[int, int]:
handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
return pynvml.nvmlDeviceGetCudaComputeCapability(handle)


@lru_cache(maxsize=8)
@with_nvml_context
def get_physical_device_name(device_id: int = 0) -> str:
handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
return pynvml.nvmlDeviceGetName(handle)


@lru_cache(maxsize=8)
@with_nvml_context
def get_physical_device_total_memory(device_id: int = 0) -> int:
handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)

class CudaPlatformBase(Platform):
_enum = PlatformEnum.CUDA
device_type: str = "cuda"
dispatch_key: str = "CUDA"

@with_nvml_context
def warn_if_different_devices():
device_ids: int = pynvml.nvmlDeviceGetCount()
if device_ids > 1:
device_names = [get_physical_device_name(i) for i in range(device_ids)]
if len(set(device_names)) > 1 and os.environ.get(
"CUDA_DEVICE_ORDER") != "PCI_BUS_ID":
logger.warning(
"Detected different devices in the system: \n%s\nPlease"
" make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
"avoid unexpected behavior.", "\n".join(device_names))
@classmethod
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
raise NotImplementedError

@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
raise NotImplementedError

try:
from sphinx.ext.autodoc.mock import _MockModule
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
raise NotImplementedError

if not isinstance(pynvml, _MockModule):
warn_if_different_devices()
except ModuleNotFoundError:
warn_if_different_devices()
@classmethod
def is_full_nvlink(cls, device_ids: List[int]) -> bool:
raise NotImplementedError

@classmethod
def log_warnings(cls):
pass

def device_id_to_physical_device_id(device_id: int) -> int:
if "CUDA_VISIBLE_DEVICES" in os.environ:
device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
if device_ids == [""]:
msg = (
"CUDA_VISIBLE_DEVICES is set to empty string, which means"
" GPU support is disabled. If you are using ray, please unset"
" the environment variable `CUDA_VISIBLE_DEVICES` inside the"
" worker/actor. "
"Check https://github.com/vllm-project/vllm/issues/8402 for"
" more information.")
raise RuntimeError(msg)
physical_device_id = device_ids[device_id]
return int(physical_device_id)
else:
return device_id
@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
parallel_config = vllm_config.parallel_config
scheduler_config = vllm_config.scheduler_config
if parallel_config.worker_cls == "auto":
if scheduler_config.is_multi_step:
parallel_config.worker_cls = \
"vllm.worker.multi_step_worker.MultiStepWorker"
elif vllm_config.speculative_config:
parallel_config.worker_cls = \
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
else:
parallel_config.worker_cls = "vllm.worker.worker.Worker"


class CudaPlatform(Platform):
_enum = PlatformEnum.CUDA
device_type: str = "cuda"
dispatch_key: str = "CUDA"
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
# all the related functions work on real physical device ids.
# the major benefit of using NVML is that it will not initialize CUDA
class NvmlCudaPlatform(CudaPlatformBase):

@classmethod
@lru_cache(maxsize=8)
@with_nvml_context
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
physical_device_id = device_id_to_physical_device_id(device_id)
major, minor = get_physical_device_capability(physical_device_id)
handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
return DeviceCapability(major=major, minor=minor)

@classmethod
@lru_cache(maxsize=8)
@with_nvml_context
def get_device_name(cls, device_id: int = 0) -> str:
physical_device_id = device_id_to_physical_device_id(device_id)
return get_physical_device_name(physical_device_id)
return cls._get_physical_device_name(physical_device_id)

@classmethod
@lru_cache(maxsize=8)
@with_nvml_context
def get_device_total_memory(cls, device_id: int = 0) -> int:
physical_device_id = device_id_to_physical_device_id(device_id)
return get_physical_device_total_memory(physical_device_id)
handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)

@classmethod
@with_nvml_context
Expand All @@ -153,27 +154,86 @@ def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
if i < j:
try:
p2p_status = pynvml.nvmlDeviceGetP2PStatus(
handle, peer_handle,
pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
handle,
peer_handle,
pynvml.NVML_P2P_CAPS_INDEX_NVLINK,
)
if p2p_status != pynvml.NVML_P2P_STATUS_OK:
return False
except pynvml.NVMLError:
logger.exception(
"NVLink detection failed. This is normal if your"
" machine has no NVLink equipped.")
"NVLink detection failed. This is normal if"
" your machine has no NVLink equipped.")
return False
return True

@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
parallel_config = vllm_config.parallel_config
scheduler_config = vllm_config.scheduler_config
if parallel_config.worker_cls == "auto":
if scheduler_config.is_multi_step:
parallel_config.worker_cls = \
"vllm.worker.multi_step_worker.MultiStepWorker"
elif vllm_config.speculative_config:
parallel_config.worker_cls = \
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
else:
parallel_config.worker_cls = "vllm.worker.worker.Worker"
def _get_physical_device_name(cls, device_id: int = 0) -> str:
handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
return pynvml.nvmlDeviceGetName(handle)

@classmethod
@with_nvml_context
def log_warnings(cls):
device_ids: int = pynvml.nvmlDeviceGetCount()
if device_ids > 1:
device_names = [
cls._get_physical_device_name(i) for i in range(device_ids)
]
if (len(set(device_names)) > 1
and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"):
logger.warning(
"Detected different devices in the system: \n%s\nPlease"
" make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
"avoid unexpected behavior.",
"\n".join(device_names),
)


class NonNvmlCudaPlatform(CudaPlatformBase):

@classmethod
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
major, minor = torch.cuda.get_device_capability(device_id)
return DeviceCapability(major=major, minor=minor)

@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
return torch.cuda.get_device_name(device_id)

@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
device_props = torch.cuda.get_device_properties(device_id)
return device_props.total_memory

@classmethod
def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
logger.exception(
"NVLink detection not possible, as context support was"
" not found. Assuming no NVLink available.")
return False


# Autodetect either NVML-enabled or non-NVML platform
# based on whether NVML is available.
nvml_available = False
try:
try:
pynvml.nvmlInit()
nvml_available = True
except Exception:
# On Jetson, NVML is not supported.
nvml_available = False
finally:
if nvml_available:
pynvml.nvmlShutdown()

CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform

try:
from sphinx.ext.autodoc.mock import _MockModule

if not isinstance(pynvml, _MockModule):
CudaPlatform.log_warnings()
except ModuleNotFoundError:
CudaPlatform.log_warnings()

0 comments on commit 36bf87e

Please sign in to comment.