diff --git a/.github/ISSUE_TEMPLATE/bug_report.yaml b/.github/ISSUE_TEMPLATE/bug_report.yaml
index c303faf0..33459361 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.yaml
+++ b/.github/ISSUE_TEMPLATE/bug_report.yaml
@@ -30,6 +30,7 @@ body:
       description: The skrl version can be obtained with the command `pip show skrl`.
       options:
         - ---
+        - 1.4.1
         - 1.4.0
         - 1.3.0
         - 1.2.0
diff --git a/CHANGELOG.md b/CHANGELOG.md
index e2b81d65..5ae853ab 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,10 @@
 
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 
+## [1.4.1] - Unreleased
+### Fixed
+- Force the use of the device local to the process in distributed runs in JAX
+
 ## [1.4.0] - 2025-01-16
 ### Added
 - Utilities to operate on Gymnasium spaces (`Box`, `Discrete`, `MultiDiscrete`, `Tuple` and `Dict`)
diff --git a/docs/source/api/config/frameworks.rst b/docs/source/api/config/frameworks.rst
index 72095643..15d773cd 100644
--- a/docs/source/api/config/frameworks.rst
+++ b/docs/source/api/config/frameworks.rst
@@ -92,11 +92,12 @@ API
 
 .. py:data:: skrl.config.jax.device
     :type: jax.Device
-    :value: "cuda:${LOCAL_RANK}" | "cpu"
+    :value: "cuda:${JAX_LOCAL_RANK}" | "cpu"
 
     Default device.
 
-    The default device, unless specified, is ``cuda:0`` (or ``cuda:JAX_LOCAL_RANK`` in a distributed environment) if CUDA is available, ``cpu`` otherwise.
+    The default device, unless specified, is ``cuda:0`` if CUDA is available, ``cpu`` otherwise.
+    However, in a distributed environment, it is the device local to the process with index ``JAX_RANK``.
 
 .. py:data:: skrl.config.jax.backend
     :type: str
diff --git a/docs/source/conf.py b/docs/source/conf.py
index b7670098..43c9c803 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -16,7 +16,7 @@
 
 if skrl.__version__ != "unknown":
     release = version = skrl.__version__
 else:
-    release = version = "1.4.0"
+    release = version = "1.4.1"
 
 master_doc = "index"
diff --git a/pyproject.toml b/pyproject.toml
index 5e0da0f2..4ca9ef33 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "skrl"
-version = "1.4.0"
+version = "1.4.1"
 description = "Modular and flexible library for reinforcement learning on PyTorch and JAX"
 readme = "README.md"
 requires-python = ">=3.6"
diff --git a/skrl/__init__.py b/skrl/__init__.py
index bd424ffe..5931b8eb 100644
--- a/skrl/__init__.py
+++ b/skrl/__init__.py
@@ -188,6 +188,12 @@ def __init__(self) -> None:
                 process_id=self._rank,
                 local_device_ids=self._local_rank,
             )
+            # get the device local to the process
+            try:
+                self._device = jax.local_devices(process_index=self._rank)[0]
+                logger.info(f"Using device local to the process with index/rank {self._rank} ({self._device})")
+            except Exception as e:
+                logger.warning(f"Failed to get the device local to the process with index/rank {self._rank}: {e}")
 
     @staticmethod
     def parse_device(device: Union[str, "jax.Device", None]) -> "jax.Device":
@@ -197,6 +203,10 @@ def parse_device(device: Union[str, "jax.Device", None]) -> "jax.Device":
 
         This function supports the PyTorch-like ``"type:ordinal"`` string specification (e.g.: ``"cuda:0"``).
 
+        .. warning::
+
+            This method returns (forces the use of) the device local to the process in a distributed environment.
+
         :param device: Device specification. If the specified device is ``None`` or it cannot be resolved,
                        the default available device will be returned instead.
@@ -204,6 +214,15 @@ def parse_device(device: Union[str, "jax.Device", None]) -> "jax.Device":
         """
         import jax
 
+        # force the use of the device local to the process in distributed runs
+        if config.jax.is_distributed:
+            try:
+                return jax.local_devices(process_index=config.jax.rank)[0]
+            except Exception as e:
+                logger.warning(
+                    f"Failed to get the device local to the process with index/rank {config.jax.rank}: {e}"
+                )
+
         if isinstance(device, jax.Device):
             return device
         elif isinstance(device, str):
@@ -218,8 +237,8 @@ def parse_device(device: Union[str, "jax.Device", None]) -> "jax.Device":
     def device(self) -> "jax.Device":
         """Default device.
 
-        The default device, unless specified, is ``cuda:0`` (or ``cuda:JAX_LOCAL_RANK`` in a distributed environment)
-        if CUDA is available, ``cpu`` otherwise.
+        The default device, unless specified, is ``cuda:0`` if CUDA is available, ``cpu`` otherwise.
+        However, in a distributed environment, it is the device local to the process with index ``JAX_RANK``.
         """
         self._device = self.parse_device(self._device)
         return self._device
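
For context, a minimal sketch of the behavior this patch enforces. It uses only `skrl.config` names that appear in the diff; the distributed launch itself (one process per accelerator, with skrl's `JAX_*` environment variables set) is assumed and not shown.

```python
from skrl import config

# Single-process run: the default device is still cuda:0 if CUDA is
# available, cpu otherwise (unchanged behavior).
print(config.jax.device)

# Distributed run (config.jax.is_distributed is True): parse_device() now
# ignores the requested ordinal and returns the device local to this
# process, i.e. jax.local_devices(process_index=config.jax.rank)[0],
# so each process lands on its own accelerator even if every one of
# them asks for "cuda:0".
if config.jax.is_distributed:
    device = config.jax.parse_device("cuda:0")  # forced to the process-local device
    print(device)
```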