Commit

Merge branch 'develop' into toni/runner_noise

Toni-SM committed Jan 28, 2025
2 parents ee204db + d13dbd7 commit d520b98
Showing 6 changed files with 31 additions and 6 deletions.
1 change: 1 addition & 0 deletions .github/ISSUE_TEMPLATE/bug_report.yaml
@@ -30,6 +30,7 @@ body:
       description: The skrl version can be obtained with the command `pip show skrl`.
       options:
         - ---
+        - 1.4.1
         - 1.4.0
         - 1.3.0
         - 1.2.0
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,10 @@

 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

+## [1.4.1] - Unreleased
+### Fixed
+- Force the use of the device local to the process in distributed runs in JAX
+
 ## [1.4.0] - 2025-01-16
 ### Added
 - Utilities to operate on Gymnasium spaces (`Box`, `Discrete`, `MultiDiscrete`, `Tuple` and `Dict`)
5 changes: 3 additions & 2 deletions docs/source/api/config/frameworks.rst
@@ -92,11 +92,12 @@ API

 .. py:data:: skrl.config.jax.device
    :type: jax.Device
-   :value: "cuda:${LOCAL_RANK}" | "cpu"
+   :value: "cuda:${JAX_LOCAL_RANK}" | "cpu"

    Default device.

-   The default device, unless specified, is ``cuda:0`` (or ``cuda:JAX_LOCAL_RANK`` in a distributed environment) if CUDA is available, ``cpu`` otherwise.
+   The default device, unless specified, is ``cuda:0`` if CUDA is available, ``cpu`` otherwise.
+   However, in a distributed environment, it is the device local to the process with index ``JAX_RANK``.

 .. py:data:: skrl.config.jax.backend
    :type: str
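The documented default-device rule above can be sketched as follows. This is a minimal illustration, not skrl's actual code: `default_device` is a hypothetical helper, and the assumption is that `jax` is installed (device names and kinds vary by backend):

```python
import jax


def default_device(is_distributed: bool = False, rank: int = 0):
    """Sketch of the documented rule: outside a distributed run, use the
    first available device (cuda:0 if CUDA is available, cpu otherwise);
    inside one, use the device local to the process with the given rank."""
    if is_distributed:
        # jax.local_devices(process_index=...) lists the devices attached
        # to that process; its first entry is the local default.
        return jax.local_devices(process_index=rank)[0]
    return jax.devices()[0]


print(default_device())  # prints the selected default device
```

In a single-process run both branches resolve to the same device; the distinction only matters once `jax.distributed.initialize` has been called and each process owns a subset of the devices.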
2 changes: 1 addition & 1 deletion docs/source/conf.py
@@ -16,7 +16,7 @@
 if skrl.__version__ != "unknown":
     release = version = skrl.__version__
 else:
-    release = version = "1.4.0"
+    release = version = "1.4.1"

 master_doc = "index"
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "skrl"
-version = "1.4.0"
+version = "1.4.1"
 description = "Modular and flexible library for reinforcement learning on PyTorch and JAX"
 readme = "README.md"
 requires-python = ">=3.6"
23 changes: 21 additions & 2 deletions skrl/__init__.py
@@ -188,6 +188,12 @@ def __init__(self) -> None:
             process_id=self._rank,
             local_device_ids=self._local_rank,
         )
+        # get the device local to the process
+        try:
+            self._device = jax.local_devices(process_index=self._rank)[0]
+            logger.info(f"Using device local to process with index/rank {self._rank} ({self._device})")
+        except Exception as e:
+            logger.warning(f"Failed to get the device local to process with index/rank {self._rank}: {e}")

     @staticmethod
     def parse_device(device: Union[str, "jax.Device", None]) -> "jax.Device":
@@ -197,13 +203,26 @@ def parse_device(device: Union[str, "jax.Device", None]) -> "jax.Device":

         This function supports the PyTorch-like ``"type:ordinal"`` string specification (e.g.: ``"cuda:0"``).

+        .. warning::
+
+            This method returns (forces the use of) the device local to the process in a distributed environment.
+
         :param device: Device specification. If the specified device is ``None`` or it cannot be resolved,
                        the default available device will be returned instead.

         :return: JAX Device.
         """
         import jax

+        # force the use of the device local to the process in distributed runs
+        if config.jax.is_distributed:
+            try:
+                return jax.local_devices(process_index=config.jax.rank)[0]
+            except Exception as e:
+                logger.warning(
+                    f"Failed to get the device local to process with index/rank {config.jax.rank}: {e}"
+                )
+
         if isinstance(device, jax.Device):
             return device
         elif isinstance(device, str):
@@ -218,8 +237,8 @@ def parse_device(device: Union[str, "jax.Device", None]) -> "jax.Device":
     def device(self) -> "jax.Device":
         """Default device.

-        The default device, unless specified, is ``cuda:0`` (or ``cuda:JAX_LOCAL_RANK`` in a distributed environment)
-        if CUDA is available, ``cpu`` otherwise.
+        The default device, unless specified, is ``cuda:0`` if CUDA is available, ``cpu`` otherwise.
+        However, in a distributed environment, it is the device local to the process with index ``JAX_RANK``.
         """
         self._device = self.parse_device(self._device)
         return self._device
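The ``"type:ordinal"`` string handling that the `parse_device` docstring describes is truncated in the diff above. A hedged re-implementation of just that parsing rule (a sketch under the stated docstring contract, not skrl's actual code; `parse_device_sketch` is a hypothetical name) could look like:

```python
import jax


def parse_device_sketch(device=None):
    """Resolve a PyTorch-like "type:ordinal" spec (e.g. "cpu:0") to a
    jax.Device. Unresolvable or None specs fall back to the default device,
    matching the contract stated in the docstring above."""
    if isinstance(device, jax.Device):
        return device
    if isinstance(device, str):
        try:
            # split "type:ordinal" into a platform name and a device index
            platform, _, ordinal = device.partition(":")
            return jax.devices(platform)[int(ordinal) if ordinal else 0]
        except Exception:
            pass  # unknown platform or out-of-range index: use the default
    return jax.devices()[0]
```

Note how the distributed short-circuit added in this commit sits before this parsing in the real method, so in a multi-process run the spec string is effectively ignored in favor of the process-local device.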