Fix Optax's learning rate schedulers integration in JAX #245

Merged · 3 commits · Jan 5, 2025
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
- Update model instantiators definitions to process supported fundamental and composite Gymnasium spaces
- Make flattened tensor storage in memory the default option (revert changes introduced in version 1.3.0)
- Drop support for PyTorch versions prior to 1.10 (the previous supported version was 1.9)
- Update KL Adaptive learning rate scheduler implementation to match Optax's behavior in JAX
- Speed up PyTorch implementation:
- Disable argument checking when instantiating distributions
- Replace PyTorch's `BatchSampler` by Python slice when sampling data from memory
@@ -27,6 +28,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
- Move the batch sampling inside gradient step loop for DQN, DDQN, DDPG (RNN), TD3 (RNN), SAC and SAC (RNN)
- Model state dictionary initialization for composite Gymnasium spaces in JAX
- Add missing `reduction` parameter to Gaussian model instantiator
- Optax's learning rate schedulers integration in JAX implementation

### Removed
- Remove OpenAI Gym (`gym`) from dependencies and source code. **skrl** continues to support gym environments,
2 changes: 1 addition & 1 deletion docs/source/api/resources/schedulers.rst
@@ -28,7 +28,7 @@ Learning rate schedulers are techniques that adjust the learning rate over time

- **PyTorch**: The implemented schedulers inherit from the PyTorch :literal:`_LRScheduler` class. Visit `How to adjust learning rate <https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate>`_ in the PyTorch documentation for more details.

- **JAX**: The implemented schedulers must parameterize and return a function that maps step counts to values. Visit `Schedules <https://optax.readthedocs.io/en/latest/api.html#schedules>`_ in the Optax documentation for more details.
- **JAX**: The implemented schedulers must parameterize and return a function that maps step counts to values. Visit `Optimizer Schedules <https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html>`_ in the Optax documentation for more details.

.. raw:: html

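To make the JAX bullet above concrete, here is a minimal sketch of the Optax schedule contract it refers to (schedule choice and values are illustrative, not part of this diff):

```python
import optax

# an Optax schedule is simply a callable that maps an integer step count to a value
schedule = optax.linear_schedule(init_value=1e-3, end_value=1e-4, transition_steps=10_000)

print(schedule(0))       # 0.001   (start of the transition)
print(schedule(5_000))   # 0.00055 (halfway through the transition)
print(schedule(10_000))  # 0.0001  (transition complete)
```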
5 changes: 1 addition & 4 deletions docs/source/api/resources/schedulers/kl_adaptive.rst
@@ -78,7 +78,4 @@ API (PyTorch)
API (JAX)
---------

.. autoclass:: skrl.resources.schedulers.jax.kl_adaptive.KLAdaptiveLR
:show-inheritance:
:inherited-members:
:members:
.. autofunction:: skrl.resources.schedulers.jax.kl_adaptive.KLAdaptiveLR
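Since the directive above now documents `KLAdaptiveLR` as a function, a short usage sketch consistent with the agent changes below may help: calling the function with its keyword arguments returns a schedule-like callable that maps the current timestep, learning rate and KL divergence to a new learning rate (the `kl_threshold` value is illustrative).

```python
from skrl.resources.schedulers.jax.kl_adaptive import KLAdaptiveLR

# build the scheduler from its keyword arguments (value shown is illustrative)
scheduler = KLAdaptiveLR(kl_threshold=0.008)

# each update: pass the timestep, the current learning rate and the measured KL
# divergence; the returned value becomes the new learning rate
learning_rate = 1e-3
learning_rate = scheduler(0, learning_rate, 0.02)
```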
38 changes: 18 additions & 20 deletions skrl/agents/jax/a2c/a2c.py
@@ -26,7 +26,7 @@
"lambda": 0.95, # TD(lambda) coefficient (lam) for computing returns and advantages

"learning_rate": 1e-3, # learning rate
"learning_rate_scheduler": None, # learning rate scheduler class (see torch.optim.lr_scheduler)
"learning_rate_scheduler": None, # learning rate scheduler function (see optax.schedules)
"learning_rate_scheduler_kwargs": {}, # learning rate scheduler's kwargs (e.g. {"step_size": 1e-3})

"state_preprocessor": None, # state preprocessor class (see skrl.resources.preprocessors)
@@ -270,25 +270,21 @@ def __init__(
# set up optimizer and learning rate scheduler
if self.policy is not None and self.value is not None:
# scheduler
scale = True
self.scheduler = None
if self._learning_rate_scheduler is not None:
if self._learning_rate_scheduler == KLAdaptiveLR:
scale = False
self.scheduler = self._learning_rate_scheduler(
self._learning_rate, **self.cfg["learning_rate_scheduler_kwargs"]
)
else:
self._learning_rate = self._learning_rate_scheduler(
self._learning_rate, **self.cfg["learning_rate_scheduler_kwargs"]
)
if self._learning_rate_scheduler:
self.scheduler = self._learning_rate_scheduler(**self.cfg["learning_rate_scheduler_kwargs"])
# optimizer
with jax.default_device(self.device):
self.policy_optimizer = Adam(
model=self.policy, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale
model=self.policy,
lr=self._learning_rate,
grad_norm_clip=self._grad_norm_clip,
scale=not self._learning_rate_scheduler,
)
self.value_optimizer = Adam(
model=self.value, lr=self._learning_rate, grad_norm_clip=self._grad_norm_clip, scale=scale
model=self.value,
lr=self._learning_rate,
grad_norm_clip=self._grad_norm_clip,
scale=not self._learning_rate_scheduler,
)

self.checkpoint_modules["policy_optimizer"] = self.policy_optimizer
@@ -541,7 +537,7 @@ def _update(self, timestep: int, timesteps: int) -> None:
if config.jax.is_distributed:
grad = self.policy.reduce_parameters(grad)
self.policy_optimizer = self.policy_optimizer.step(
grad, self.policy, self.scheduler._lr if self.scheduler else None
grad, self.policy, self._learning_rate if self._learning_rate_scheduler else None
)

# compute value loss
@@ -551,7 +547,7 @@ def _update(self, timestep: int, timesteps: int) -> None:
if config.jax.is_distributed:
grad = self.value.reduce_parameters(grad)
self.value_optimizer = self.value_optimizer.step(
grad, self.value, self.scheduler._lr if self.scheduler else None
grad, self.value, self._learning_rate if self._learning_rate_scheduler else None
)

# update cumulative losses
@@ -562,13 +558,15 @@ def _update(self, timestep: int, timesteps: int) -> None:

# update learning rate
if self._learning_rate_scheduler:
if isinstance(self.scheduler, KLAdaptiveLR):
if self._learning_rate_scheduler is KLAdaptiveLR:
kl = np.mean(kl_divergences)
# reduce (collect from all workers/processes) KL in distributed runs
if config.jax.is_distributed:
kl = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(kl.reshape(1)).item()
kl /= config.jax.world_size
self.scheduler.step(kl)
self._learning_rate = self.scheduler(timestep, self._learning_rate, kl)
else:
self._learning_rate *= self.scheduler(timestep)

# record data
self.track_data("Loss / Policy loss", cumulative_policy_loss / len(sampled_batches))
@@ -580,4 +578,4 @@ def _update(self, timestep: int, timesteps: int) -> None:
self.track_data("Policy / Standard deviation", stddev.mean().item())

if self._learning_rate_scheduler:
self.track_data("Learning / Learning rate", self.scheduler._lr)
self.track_data("Learning / Learning rate", self._learning_rate)
20 changes: 12 additions & 8 deletions skrl/agents/jax/cem/cem.py
@@ -24,7 +24,7 @@
"discount_factor": 0.99, # discount factor (gamma)

"learning_rate": 1e-2, # learning rate
"learning_rate_scheduler": None, # learning rate scheduler class (see torch.optim.lr_scheduler)
"learning_rate_scheduler": None, # learning rate scheduler function (see optax.schedules)
"learning_rate_scheduler_kwargs": {}, # learning rate scheduler's kwargs (e.g. {"step_size": 1e-3})

"state_preprocessor": None, # state preprocessor class (see skrl.resources.preprocessors)
@@ -122,11 +122,13 @@ def __init__(

# set up optimizer and learning rate scheduler
if self.policy is not None:
# scheduler
if self._learning_rate_scheduler:
self.scheduler = self._learning_rate_scheduler(**self.cfg["learning_rate_scheduler_kwargs"])
# optimizer
with jax.default_device(self.device):
self.optimizer = Adam(model=self.policy, lr=self._learning_rate)
if self._learning_rate_scheduler is not None:
self.scheduler = self._learning_rate_scheduler(
self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"]
self.optimizer = Adam(
model=self.policy, lr=self._learning_rate, scale=not self._learning_rate_scheduler
)

self.checkpoint_modules["optimizer"] = self.optimizer
@@ -338,11 +340,13 @@ def _policy_loss(params):
policy_loss, grad = jax.value_and_grad(_policy_loss, has_aux=False)(self.policy.state_dict.params)

# optimization step (policy)
self.optimizer = self.optimizer.step(grad, self.policy)
self.optimizer = self.optimizer.step(
grad, self.policy, self._learning_rate if self._learning_rate_scheduler else None
)

# update learning rate
if self._learning_rate_scheduler:
self.scheduler.step()
self._learning_rate *= self.scheduler(timestep)

# record data
self.track_data("Loss / Policy loss", policy_loss.item())
@@ -351,4 +355,4 @@ def _policy_loss(params):
self.track_data("Coefficient / Mean discounted returns", returns.mean().item())

if self._learning_rate_scheduler:
self.track_data("Learning / Learning rate", self.scheduler.get_last_lr()[0])
self.track_data("Learning / Learning rate", self._learning_rate)
40 changes: 24 additions & 16 deletions skrl/agents/jax/ddpg/ddpg.py
@@ -26,7 +26,7 @@

"actor_learning_rate": 1e-3, # actor learning rate
"critic_learning_rate": 1e-3, # critic learning rate
"learning_rate_scheduler": None, # learning rate scheduler class (see torch.optim.lr_scheduler)
"learning_rate_scheduler": None, # learning rate scheduler function (see optax.schedules)
"learning_rate_scheduler_kwargs": {}, # learning rate scheduler's kwargs (e.g. {"step_size": 1e-3})

"state_preprocessor": None, # state preprocessor class (see skrl.resources.preprocessors)
@@ -202,19 +202,23 @@ def __init__(

# set up optimizers and learning rate schedulers
if self.policy is not None and self.critic is not None:
# schedulers
if self._learning_rate_scheduler:
self.policy_scheduler = self._learning_rate_scheduler(**self.cfg["learning_rate_scheduler_kwargs"])
self.critic_scheduler = self._learning_rate_scheduler(**self.cfg["learning_rate_scheduler_kwargs"])
# optimizers
with jax.default_device(self.device):
self.policy_optimizer = Adam(
model=self.policy, lr=self._actor_learning_rate, grad_norm_clip=self._grad_norm_clip
model=self.policy,
lr=self._actor_learning_rate,
grad_norm_clip=self._grad_norm_clip,
scale=not self._learning_rate_scheduler,
)
self.critic_optimizer = Adam(
model=self.critic, lr=self._critic_learning_rate, grad_norm_clip=self._grad_norm_clip
)
if self._learning_rate_scheduler is not None:
self.policy_scheduler = self._learning_rate_scheduler(
self.policy_optimizer, **self.cfg["learning_rate_scheduler_kwargs"]
)
self.critic_scheduler = self._learning_rate_scheduler(
self.critic_optimizer, **self.cfg["learning_rate_scheduler_kwargs"]
model=self.critic,
lr=self._critic_learning_rate,
grad_norm_clip=self._grad_norm_clip,
scale=not self._learning_rate_scheduler,
)

self.checkpoint_modules["policy_optimizer"] = self.policy_optimizer
@@ -458,7 +462,9 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (critic)
if config.jax.is_distributed:
grad = self.critic.reduce_parameters(grad)
self.critic_optimizer = self.critic_optimizer.step(grad, self.critic)
self.critic_optimizer = self.critic_optimizer.step(
grad, self.critic, self._critic_learning_rate if self._learning_rate_scheduler else None
)

# compute policy (actor) loss
grad, policy_loss = _update_policy(
@@ -468,16 +474,18 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (policy)
if config.jax.is_distributed:
grad = self.policy.reduce_parameters(grad)
self.policy_optimizer = self.policy_optimizer.step(grad, self.policy)
self.policy_optimizer = self.policy_optimizer.step(
grad, self.policy, self._actor_learning_rate if self._learning_rate_scheduler else None
)

# update target networks
self.target_policy.update_parameters(self.policy, polyak=self._polyak)
self.target_critic.update_parameters(self.critic, polyak=self._polyak)

# update learning rate
if self._learning_rate_scheduler:
self.policy_scheduler.step()
self.critic_scheduler.step()
self._actor_learning_rate *= self.policy_scheduler(timestep)
self._critic_learning_rate *= self.critic_scheduler(timestep)

# record data
self.track_data("Loss / Policy loss", policy_loss.item())
@@ -492,5 +500,5 @@ def _update(self, timestep: int, timesteps: int) -> None:
self.track_data("Target / Target (mean)", target_values.mean().item())

if self._learning_rate_scheduler:
self.track_data("Learning / Policy learning rate", self.policy_scheduler.get_last_lr()[0])
self.track_data("Learning / Critic learning rate", self.critic_scheduler.get_last_lr()[0])
self.track_data("Learning / Policy learning rate", self._actor_learning_rate)
self.track_data("Learning / Critic learning rate", self._critic_learning_rate)
22 changes: 14 additions & 8 deletions skrl/agents/jax/dqn/ddqn.py
@@ -25,7 +25,7 @@
"polyak": 0.005, # soft update hyperparameter (tau)

"learning_rate": 1e-3, # learning rate
"learning_rate_scheduler": None, # learning rate scheduler class (see torch.optim.lr_scheduler)
"learning_rate_scheduler": None, # learning rate scheduler function (see optax.schedules)
"learning_rate_scheduler_kwargs": {}, # learning rate scheduler's kwargs (e.g. {"step_size": 1e-3})

"state_preprocessor": None, # state preprocessor class (see skrl.resources.preprocessors)
@@ -175,11 +175,15 @@ def __init__(

# set up optimizer and learning rate scheduler
if self.q_network is not None:
# scheduler
if self._learning_rate_scheduler:
self.scheduler = self._learning_rate_scheduler(**self.cfg["learning_rate_scheduler_kwargs"])
# optimizer
with jax.default_device(self.device):
self.optimizer = Adam(model=self.q_network, lr=self._learning_rate)
if self._learning_rate_scheduler is not None:
self.scheduler = self._learning_rate_scheduler(
self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"]
self.optimizer = Adam(
model=self.q_network,
lr=self._learning_rate,
scale=not self._learning_rate_scheduler,
)

self.checkpoint_modules["optimizer"] = self.optimizer
@@ -392,15 +396,17 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (Q-network)
if config.jax.is_distributed:
grad = self.q_network.reduce_parameters(grad)
self.optimizer = self.optimizer.step(grad, self.q_network)
self.optimizer = self.optimizer.step(
grad, self.q_network, self._learning_rate if self._learning_rate_scheduler else None
)

# update target network
if not timestep % self._target_update_interval:
self.target_q_network.update_parameters(self.q_network, polyak=self._polyak)

# update learning rate
if self._learning_rate_scheduler:
self.scheduler.step()
self._learning_rate *= self.scheduler(timestep)

# record data
self.track_data("Loss / Q-network loss", q_network_loss.item())
@@ -410,4 +416,4 @@ def _update(self, timestep: int, timesteps: int) -> None:
self.track_data("Target / Target (mean)", target_values.mean().item())

if self._learning_rate_scheduler:
self.track_data("Learning / Learning rate", self.scheduler._lr)
self.track_data("Learning / Learning rate", self._learning_rate)
22 changes: 14 additions & 8 deletions skrl/agents/jax/dqn/dqn.py
@@ -25,7 +25,7 @@
"polyak": 0.005, # soft update hyperparameter (tau)

"learning_rate": 1e-3, # learning rate
"learning_rate_scheduler": None, # learning rate scheduler class (see torch.optim.lr_scheduler)
"learning_rate_scheduler": None, # learning rate scheduler function (see optax.schedules)
"learning_rate_scheduler_kwargs": {}, # learning rate scheduler's kwargs (e.g. {"step_size": 1e-3})

"state_preprocessor": None, # state preprocessor class (see skrl.resources.preprocessors)
@@ -172,11 +172,15 @@ def __init__(

# set up optimizer and learning rate scheduler
if self.q_network is not None:
# scheduler
if self._learning_rate_scheduler:
self.scheduler = self._learning_rate_scheduler(**self.cfg["learning_rate_scheduler_kwargs"])
# optimizer
with jax.default_device(self.device):
self.optimizer = Adam(model=self.q_network, lr=self._learning_rate)
if self._learning_rate_scheduler is not None:
self.scheduler = self._learning_rate_scheduler(
self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"]
self.optimizer = Adam(
model=self.q_network,
lr=self._learning_rate,
scale=not self._learning_rate_scheduler,
)

self.checkpoint_modules["optimizer"] = self.optimizer
@@ -388,15 +392,17 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (Q-network)
if config.jax.is_distributed:
grad = self.q_network.reduce_parameters(grad)
self.optimizer = self.optimizer.step(grad, self.q_network)
self.optimizer = self.optimizer.step(
grad, self.q_network, self._learning_rate if self._learning_rate_scheduler else None
)

# update target network
if not timestep % self._target_update_interval:
self.target_q_network.update_parameters(self.q_network, polyak=self._polyak)

# update learning rate
if self._learning_rate_scheduler:
self.scheduler.step()
self._learning_rate *= self.scheduler(timestep)

# record data
self.track_data("Loss / Q-network loss", q_network_loss.item())
@@ -406,4 +412,4 @@ def _update(self, timestep: int, timesteps: int) -> None:
self.track_data("Target / Target (mean)", target_values.mean().item())

if self._learning_rate_scheduler:
self.track_data("Learning / Learning rate", self.scheduler._lr)
self.track_data("Learning / Learning rate", self._learning_rate)