[Hardware][NVIDIA] Add non-NVML CUDA mode for Jetson #9735

Merged · 12 commits · Nov 26, 2024
10 changes: 5 additions & 5 deletions CMakeLists.txt
@@ -34,7 +34,7 @@ install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11" "3.12")

# Supported NVIDIA architectures.
-set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0")
+set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0")

# Supported AMD GPU architectures.
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100")
@@ -234,7 +234,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# Only build Marlin kernels if we are building for at least some compatible archs.
# Keep building Marlin for 9.0 as there are some group sizes and shapes that
# are not supported by Machete yet.
-  cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.9;9.0" ${CUDA_ARCHS})
+  cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0" ${CUDA_ARCHS})
if (MARLIN_ARCHS)
set(MARLIN_SRCS
"csrc/quantization/fp8/fp8_marlin.cu"
@@ -285,8 +285,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
#
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
# kernels for the remaining archs that are not already built for 3x.
-  cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
-    "7.5;8.0;8.6;8.9;9.0" "${CUDA_ARCHS}")
+  cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
+    "7.5;8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}")
# subtract out the archs that are already built for 3x
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
if (SCALED_MM_2X_ARCHS)
@@ -412,7 +412,7 @@ set_gencode_flags_for_srcs(
CUDA_ARCHS "${CUDA_ARCHS}")

if(VLLM_GPU_LANG STREQUAL "CUDA")
-  cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.9;9.0" "${CUDA_ARCHS}")
+  cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}")
if (MARLIN_MOE_ARCHS)
set(MARLIN_MOE_SRC
"csrc/moe/marlin_kernels/marlin_moe_kernel.h"
10 changes: 9 additions & 1 deletion vllm/platforms/__init__.py
@@ -27,7 +27,15 @@
    finally:
        pynvml.nvmlShutdown()
except Exception:
-    pass
+    # CUDA is supported on Jetson, but NVML may not be.
+    import os
+
+    def cuda_is_jetson() -> bool:
+        return os.path.isfile("/etc/nv_tegra_release") \
+            or os.path.exists("/sys/class/tegra-firmware")
+
+    if cuda_is_jetson():
+        is_cuda = True

is_rocm = False

Review thread on the cuda_is_jetson check:

Member: I need to check with the NVIDIA folks how robust this is.

Contributor Author: The check came from this thread: rapidsai/dask-cuda#400 (comment)
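Since the robustness of these marker files came up in review, here is a minimal standalone sketch for sanity-checking the heuristic on a device. It only restates the two paths used by this PR; the banner format shown for /etc/nv_tegra_release (an L4T release string such as "# R35 (release), REVISION: 4.1") is an assumption, not something the PR relies on.

import os


def cuda_is_jetson() -> bool:
    # Same heuristic as the PR: both paths exist only on Tegra/L4T systems.
    return os.path.isfile("/etc/nv_tegra_release") \
        or os.path.exists("/sys/class/tegra-firmware")


if __name__ == "__main__":
    if not cuda_is_jetson():
        print("No Jetson markers found")
    elif os.path.isfile("/etc/nv_tegra_release"):
        # On L4T this file starts with a release banner (format assumed).
        with open("/etc/nv_tegra_release") as f:
            print("Jetson detected:", f.readline().strip())
    else:
        print("Jetson detected via /sys/class/tegra-firmware")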
222 changes: 149 additions & 73 deletions vllm/platforms/cuda.py
@@ -3,8 +3,10 @@
"""

import os
-from functools import lru_cache, wraps
-from typing import Callable, List, Tuple, TypeVar
+from collections.abc import Iterator
+from contextlib import contextmanager
+from functools import lru_cache
+from typing import List, Tuple, TypeVar

import pynvml
import torch
@@ -31,67 +33,6 @@
# see https://github.com/huggingface/diffusers/issues/9704 for details
torch.backends.cuda.enable_cudnn_sdp(False)

-# NVML utils
-# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
-# all the related functions work on real physical device ids.
-# the major benefit of using NVML is that it will not initialize CUDA
-
-
-def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
-
-    @wraps(fn)
-    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
-        pynvml.nvmlInit()
-        try:
-            return fn(*args, **kwargs)
-        finally:
-            pynvml.nvmlShutdown()
-
-    return wrapper
-
-
-@lru_cache(maxsize=8)
-@with_nvml_context
-def get_physical_device_capability(device_id: int = 0) -> Tuple[int, int]:
-    handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
-    return pynvml.nvmlDeviceGetCudaComputeCapability(handle)
-
-
-@lru_cache(maxsize=8)
-@with_nvml_context
-def get_physical_device_name(device_id: int = 0) -> str:
-    handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
-    return pynvml.nvmlDeviceGetName(handle)
-
-
-@lru_cache(maxsize=8)
-@with_nvml_context
-def get_physical_device_total_memory(device_id: int = 0) -> int:
-    handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
-    return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)
-
-
-@with_nvml_context
-def warn_if_different_devices():
-    device_ids: int = pynvml.nvmlDeviceGetCount()
-    if device_ids > 1:
-        device_names = [get_physical_device_name(i) for i in range(device_ids)]
-        if len(set(device_names)) > 1 and os.environ.get(
-                "CUDA_DEVICE_ORDER") != "PCI_BUS_ID":
-            logger.warning(
-                "Detected different devices in the system: \n%s\nPlease"
-                " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
-                "avoid unexpected behavior.", "\n".join(device_names))
-
-
-try:
-    from sphinx.ext.autodoc.mock import _MockModule
-
-    if not isinstance(pynvml, _MockModule):
-        warn_if_different_devices()
-except ModuleNotFoundError:
-    warn_if_different_devices()


def device_id_to_physical_device_id(device_id: int) -> int:
    if "CUDA_VISIBLE_DEVICES" in os.environ:

@@ -105,27 +46,51 @@ def device_id_to_physical_device_id(device_id: int) -> int:
        return device_id


-class CudaPlatform(Platform):
-    _enum = PlatformEnum.CUDA
+class BaseContext:

    @classmethod
-    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
+    def get_device_capability(cls, device_id: int = 0) -> Tuple[int, int]:
+        raise NotImplementedError
+
+    @classmethod
+    def get_device_name(cls, device_id: int = 0) -> str:
+        raise NotImplementedError
+
+    @classmethod
+    def get_device_total_memory(cls, device_id: int = 0) -> int:
+        raise NotImplementedError
+
+    @classmethod
+    def is_full_nvlink(cls, device_ids: List[int]) -> bool:
+        raise NotImplementedError
+
+    @classmethod
+    def log_warnings(cls):
+        pass
+
+
+# NVML utils
+# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
+# all the related functions work on real physical device ids.
+# the major benefit of using NVML is that it will not initialize CUDA
+class NVMLContext(BaseContext):
+
+    @classmethod
+    def get_device_capability(cls, device_id: int = 0) -> Tuple[int, int]:
        physical_device_id = device_id_to_physical_device_id(device_id)
-        major, minor = get_physical_device_capability(physical_device_id)
-        return DeviceCapability(major=major, minor=minor)
+        return cls._get_physical_device_capability(physical_device_id)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        physical_device_id = device_id_to_physical_device_id(device_id)
-        return get_physical_device_name(physical_device_id)
+        return cls._get_physical_device_name(physical_device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        physical_device_id = device_id_to_physical_device_id(device_id)
-        return get_physical_device_total_memory(physical_device_id)
+        return cls._get_physical_device_total_memory(physical_device_id)

    @classmethod
-    @with_nvml_context
    def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
        """
        query if the set of gpus are fully connected by nvlink (1 hop)
@@ -144,7 +109,118 @@ def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
                            return False
                    except pynvml.NVMLError:
                        logger.exception(
-                            "NVLink detection failed. This is normal if your"
-                            " machine has no NVLink equipped.")
+                            "NVLink detection failed. This is normal if"
+                            " your machine has no NVLink equipped.")
                        return False
        return True

+    @classmethod
+    @lru_cache(maxsize=8)
+    def _get_physical_device_capability(cls,
+                                        device_id: int = 0) -> Tuple[int, int]:
+        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
+        return pynvml.nvmlDeviceGetCudaComputeCapability(handle)
+
+    @classmethod
+    @lru_cache(maxsize=8)
+    def _get_physical_device_name(cls, device_id: int = 0) -> str:
+        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
+        return pynvml.nvmlDeviceGetName(handle)
+
+    @classmethod
+    @lru_cache(maxsize=8)
+    def _get_physical_device_total_memory(cls, device_id: int = 0) -> int:
+        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
+        return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)
+
+    @classmethod
+    def log_warnings(cls):
+        device_ids: int = pynvml.nvmlDeviceGetCount()
+        if device_ids > 1:
+            device_names = [
+                cls._get_physical_device_name(i) for i in range(device_ids)
+            ]
+            if len(set(device_names)) > 1 and os.environ.get(
+                    "CUDA_DEVICE_ORDER") != "PCI_BUS_ID":
+                logger.warning(
+                    "Detected different devices in the system: \n%s\nPlease"
+                    " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
+                    "avoid unexpected behavior.", "\n".join(device_names))
+
+
+class NonNVMLContext(BaseContext):
+
+    @classmethod
+    def get_device_capability(cls, device_id: int = 0) -> Tuple[int, int]:
+        return torch.cuda.get_device_capability(device_id)
+
+    @classmethod
+    def get_device_name(cls, device_id: int = 0) -> str:
+        return torch.cuda.get_device_name(device_id)
+
+    @classmethod
+    def get_device_total_memory(cls, device_id: int = 0) -> int:
+        device_props = torch.cuda.get_device_properties(device_id)
+        return device_props.total_memory
+
+    @classmethod
+    def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
+        logger.exception(
+            "NVLink detection not possible, as context support was"
+            " not found. Assuming no NVLink available.")
+        return False
+
+
+@contextmanager
+def get_context() -> Iterator[BaseContext]:
+    nvml_init_ok = False
+    try:
+        try:
+            pynvml.nvmlInit()
+            nvml_init_ok = True
+            yield NVMLContext()
+        except Exception:
+            # On Jetson, NVML is not supported.
+            yield NonNVMLContext()
+    finally:
+        if nvml_init_ok:
+            pynvml.nvmlShutdown()
+
+
+try:
+    from sphinx.ext.autodoc.mock import _MockModule
+
+    if not isinstance(pynvml, _MockModule):
+        with get_context() as context:
+            context.log_warnings()
+except ModuleNotFoundError:
+    with get_context() as context:
+        context.log_warnings()


+class CudaPlatform(Platform):

Member (review comment): we can have NvmlCudaPlatform and NonNvmlCudaPlatform inheriting from Platform, and at the end of this file define a variable CudaPlatform that points to either NvmlCudaPlatform or NonNvmlCudaPlatform, based on whether we are on Jetson or not. (See the sketch at the end of this diff.)

+    _enum = PlatformEnum.CUDA

+    @classmethod
+    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
+        with get_context() as context:
+            major, minor = context.get_device_capability(device_id)
+            return DeviceCapability(major=major, minor=minor)
+
+    @classmethod
+    def get_device_name(cls, device_id: int = 0) -> str:
+        with get_context() as context:
+            return context.get_device_name(device_id)
+
+    @classmethod
+    def get_device_total_memory(cls, device_id: int = 0) -> int:
+        with get_context() as context:
+            return context.get_device_total_memory(device_id)
+
+    @classmethod
+    def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
+        """
+        query if the set of gpus are fully connected by nvlink (1 hop)
+        """
+        with get_context() as context:
+            return context.is_full_nvlink(physical_device_ids)
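To make the Member's suggestion above concrete, here is a minimal sketch of that alternative design, not the code merged in this PR. The class names NvmlCudaPlatform and NonNvmlCudaPlatform come from the comment; the stripped-down Platform base and the single get_device_name method are stand-ins for vLLM's full interface, and only the module-level aliasing at the bottom is the point.

import os

import torch


class Platform:
    """Stand-in for vllm.platforms.interface.Platform (assumption)."""

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        raise NotImplementedError


class NvmlCudaPlatform(Platform):

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        # NVML path: query the driver without initializing CUDA.
        import pynvml
        pynvml.nvmlInit()
        try:
            handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
            name = pynvml.nvmlDeviceGetName(handle)
            # Older pynvml versions return bytes rather than str.
            return name.decode() if isinstance(name, bytes) else name
        finally:
            pynvml.nvmlShutdown()


class NonNvmlCudaPlatform(Platform):

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        # Jetson path: fall back to torch, which does initialize CUDA.
        return torch.cuda.get_device_name(device_id)


def cuda_is_jetson() -> bool:
    return os.path.isfile("/etc/nv_tegra_release") \
        or os.path.exists("/sys/class/tegra-firmware")


# The suggestion: choose the implementation once, at the end of the module.
CudaPlatform = NonNvmlCudaPlatform if cuda_is_jetson() else NvmlCudaPlatform

This achieves the same fallback as the merged get_context() approach, but moves the NVML/non-NVML decision to import time instead of making it on every call.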