Skip to content

Commit

Permalink
Merge pull request #568 from lgleim/vulkan_no_cuda_fix
Browse files Browse the repository at this point in the history
[BUG FIX][FEATURE] Rework backend & device selection logic
  • Loading branch information
zswang666 authored Jan 22, 2025
2 parents 93d5945 + 50ebafd commit 806d0a8
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 27 deletions.
18 changes: 7 additions & 11 deletions genesis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def fake_print(*args, **kwargs):
from .constants import backend as gs_backend
from .logging import Logger
from .version import __version__
from .utils import set_random_seed, get_platform, get_cpu_device, get_gpu_device
from .utils import set_random_seed, get_platform, get_device

_initialized = False
backend = None
Expand All @@ -43,7 +43,6 @@ def init(
theme="dark",
logger_verbose_time=False,
):

# genesis._initialized
global _initialized
if _initialized:
Expand Down Expand Up @@ -72,19 +71,16 @@ def init(

first_init = False

# get default device and compute total device memory
# genesis.backend
global platform
global device
platform = get_platform()
if backend == gs_backend.cpu:
device, device_name, total_mem = get_cpu_device()
else:
device, device_name, total_mem = get_gpu_device()

# genesis.backend
if backend not in GS_ARCH[platform]:
raise_exception(f"backend ~~<{backend}>~~ not supported for platform ~~<{platform}>~~")
backend = GS_ARCH[platform][backend]

# get default device and compute total device memory
global device
device, device_name, total_mem, backend = get_device(backend)

_globalize_backend(backend)

logger.info(
Expand Down
48 changes: 32 additions & 16 deletions genesis/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import torch

import genesis as gs
from genesis.constants import backend as gs_backend


def raise_exception(msg="Something went wrong."):
Expand Down Expand Up @@ -96,32 +97,47 @@ def get_cpu_name():
return platform.processor()


def get_device(backend: gs_backend):
    """Resolve a simulation backend to a concrete torch device.

    Parameters
    ----------
    backend : gs_backend
        Requested backend (cuda, metal, vulkan, gpu, or cpu).

    Returns
    -------
    tuple
        (device, device_name, total_mem, backend) where ``device`` is a
        ``torch.device``, ``device_name`` is a human-readable name,
        ``total_mem`` is the device memory in GiB, and ``backend`` is the
        (possibly narrowed) backend actually selected.

    Raises
    ------
    Exception
        Via ``gs.raise_exception`` if the explicitly requested device
        (cuda or metal) is not available.
    """
    if backend == gs_backend.cuda:
        if not torch.cuda.is_available():
            gs.raise_exception("cuda device not available")

        device = torch.device("cuda")
        device_property = torch.cuda.get_device_properties(0)
        device_name = device_property.name
        total_mem = device_property.total_memory / 1024**3

    elif backend == gs_backend.metal:
        if not torch.backends.mps.is_available():
            gs.raise_exception("metal device not available")

        # on mac, cpu and gpu are in the same device, so report CPU
        # name/memory but hand back an MPS torch device
        _, device_name, total_mem, _ = get_device(gs_backend.cpu)
        device = torch.device("mps")

    elif backend == gs_backend.vulkan:
        if torch.xpu.is_available():  # pytorch 2.5+ Intel XPU device
            device = torch.device("xpu")
            device_property = torch.xpu.get_device_properties(0)
            device_name = device_property.name
            total_mem = device_property.total_memory / 1024**3
        else:  # no XPU: keep the vulkan backend but keep torch tensors on cpu
            device, device_name, total_mem, _ = get_device(gs_backend.cpu)

    elif backend == gs_backend.gpu:
        # generic "gpu" request: narrow to the best available concrete
        # backend (cuda > metal on macOS > vulkan fallback)
        if torch.cuda.is_available():
            return get_device(gs_backend.cuda)
        elif get_platform() == "macOS":
            return get_device(gs_backend.metal)
        else:
            return get_device(gs_backend.vulkan)

    else:
        # cpu (or anything unrecognized): report host CPU name and RAM
        device_name = get_cpu_name()
        total_mem = psutil.virtual_memory().total / 1024**3
        device = torch.device("cpu")

    return device, device_name, total_mem, backend


def get_src_dir():
Expand Down

0 comments on commit 806d0a8

Please sign in to comment.