Treat truncation signal when computing 'done' (environment reset) (#259)
Toni-SM authored Jan 16, 2025
1 parent 7f9992b commit 663f546
Showing 46 changed files with 409 additions and 253 deletions.
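
As context for the diffs below: the commit makes agents treat the truncation signal (time-outs) the same way as the termination signal when computing 'done', so an episode counts as finished, and the environment is reset, on either one. A minimal sketch of that combination, using plain PyTorch and made-up tensors rather than skrl's actual code:

    import torch

    # example batched signals for 4 parallel environments (assumed shape: (N, 1))
    terminated = torch.tensor([[False], [True], [False], [False]])  # d_end
    truncated  = torch.tensor([[False], [False], [True], [False]])  # d_timeout

    # 'done' used for the environment reset: an episode ends on either signal
    done = terminated | truncated        # -> [[False], [True], [True], [False]]

    # '~done' used when bootstrapping target values (see the agent diffs below)
    not_done = (terminated | truncated).logical_not().float()
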
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -33,6 +33,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
- Add missing `reduction` parameter to Gaussian model instantiator
- Optax's learning rate schedulers integration in JAX implementation
- Isaac Lab wrapper's multi-agent state retrieval with gymnasium 1.0
- Treat truncation signal when computing 'done' (environment reset)

### Removed
- Remove OpenAI Gym (`gym`) from dependencies and source code. **skrl** continues to support gym environments,
4 changes: 2 additions & 2 deletions docs/source/api/agents/a2c.rst
@@ -25,7 +25,7 @@ Algorithm implementation

| Main notation/symbols:
| - policy function approximator (:math:`\pi_\theta`), value function approximator (:math:`V_\phi`)
| - states (:math:`s`), actions (:math:`a`), rewards (:math:`r`), next states (:math:`s'`), dones (:math:`d`)
| - states (:math:`s`), actions (:math:`a`), rewards (:math:`r`), next states (:math:`s'`), terminated (:math:`d_{_{end}}`), truncated (:math:`d_{_{timeout}}`)
| - values (:math:`V`), advantages (:math:`A`), returns (:math:`R`)
| - log probabilities (:math:`logp`)
| - loss (:math:`L`)
@@ -59,7 +59,7 @@ Learning algorithm
| :literal:`_update(...)`
| :green:`# compute returns and advantages`
| :math:`V_{_{last}}' \leftarrow V_\phi(s')`
| :math:`R, A \leftarrow f_{GAE}(r, d, V, V_{_{last}}')`
| :math:`R, A \leftarrow f_{GAE}(r, d_{_{end}} \lor d_{_{timeout}}, V, V_{_{last}}')`
| :green:`# sample mini-batches from memory`
| [[:math:`s, a, logp, V, R, A`]] :math:`\leftarrow` states, actions, log_prob, values, returns, advantages
| :green:`# mini-batches loop`
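
For reference, the `f_GAE` step above now receives the combined terminated/truncated signal. A minimal sketch of such a GAE computation, assuming tensors of shape (timesteps, envs, 1) and names chosen for illustration rather than taken from skrl:

    import torch

    def compute_gae(rewards, dones, values, last_values,
                    discount_factor=0.99, lambda_coefficient=0.95):
        """Generalized Advantage Estimation where dones = terminated | truncated."""
        advantages = torch.zeros_like(rewards)
        advantage = 0
        num_steps = rewards.shape[0]
        for t in reversed(range(num_steps)):
            next_values = last_values if t == num_steps - 1 else values[t + 1]
            not_done = (~dones[t]).float()
            # TD error, with bootstrapping disabled at episode boundaries
            delta = rewards[t] + discount_factor * not_done * next_values - values[t]
            advantage = delta + discount_factor * lambda_coefficient * not_done * advantage
            advantages[t] = advantage
        returns = advantages + values
        return returns, advantages
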
4 changes: 2 additions & 2 deletions docs/source/api/agents/amp.rst
@@ -21,7 +21,7 @@ Algorithm implementation

| Main notation/symbols:
| - policy (:math:`\pi_\theta`), value (:math:`V_\phi`) and discriminator (:math:`D_\psi`) function approximators
| - states (:math:`s`), actions (:math:`a`), rewards (:math:`r`), next states (:math:`s'`), dones (:math:`d`)
| - states (:math:`s`), actions (:math:`a`), rewards (:math:`r`), next states (:math:`s'`), terminated (:math:`d_{_{end}}`), truncated (:math:`d_{_{timeout}}`)
| - values (:math:`V`), next values (:math:`V'`), advantages (:math:`A`), returns (:math:`R`)
| - log probabilities (:math:`logp`)
| - loss (:math:`L`)
@@ -57,7 +57,7 @@ Learning algorithm
| :math:`r_D \leftarrow -log(\text{max}( 1 - \hat{y}(D_\psi(s_{_{AMP}})), \, 10^{-4})) \qquad` with :math:`\; \hat{y}(x) = \dfrac{1}{1 + e^{-x}}`
| :math:`r' \leftarrow` :guilabel:`task_reward_weight` :math:`r \, +` :guilabel:`style_reward_weight` :guilabel:`discriminator_reward_scale` :math:`r_D`
| :green:`# compute returns and advantages`
| :math:`R, A \leftarrow f_{GAE}(r', d, V, V')`
| :math:`R, A \leftarrow f_{GAE}(r', d_{_{end}} \lor d_{_{timeout}}, V, V')`
| :green:`# sample mini-batches from memory`
| [[:math:`s, a, logp, V, R, A, s_{_{AMP}}`]] :math:`\leftarrow` states, actions, log_prob, values, returns, advantages, AMP states
| [[:math:`s_{_{AMP}}^{^M}`]] :math:`\leftarrow` AMP states from :math:`M`
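
The AMP-specific piece above is the style reward drawn from the discriminator, which is mixed with the task reward before the same GAE call. A rough sketch of that mixing, where the weight and scale defaults are placeholders for the agent's configuration values:

    import torch

    def amp_combined_rewards(task_rewards, discriminator_logits,
                             task_reward_weight=0.0, style_reward_weight=1.0,
                             discriminator_reward_scale=2.0):
        """r' = w_task * r + w_style * scale * r_D, with
        r_D = -log(max(1 - sigmoid(D(s_AMP)), 1e-4))."""
        style_rewards = -torch.log(
            torch.clamp(1.0 - torch.sigmoid(discriminator_logits), min=1e-4))
        return (task_reward_weight * task_rewards
                + style_reward_weight * discriminator_reward_scale * style_rewards)
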
4 changes: 2 additions & 2 deletions docs/source/api/agents/cem.rst
@@ -17,7 +17,7 @@ Algorithm implementation

| Main notation/symbols:
| - policy function approximator (:math:`\pi_\theta`)
| - states (:math:`s`), actions (:math:`a`), rewards (:math:`r`), next states (:math:`s'`), dones (:math:`d`)
| - states (:math:`s`), actions (:math:`a`), rewards (:math:`r`), next states (:math:`s'`), terminated (:math:`d_{_{end}}`), truncated (:math:`d_{_{timeout}}`)
| - loss (:math:`L`)
.. raw:: html
@@ -41,7 +41,7 @@ Learning algorithm
|
| :literal:`_update(...)`
| :green:`# sample all memory`
| :math:`s, a, r, s', d \leftarrow` states, actions, rewards, next_states, dones
| :math:`s, a, r \leftarrow` states, actions, rewards
| :green:`# compute discounted return threshold`
| :math:`[G] \leftarrow \sum_{t=0}^{E-1}` :guilabel:`discount_factor`:math:`^{t} \, r_t` for each episode
| :math:`G_{_{bound}} \leftarrow q_{th_{quantile}}([G])` at the given :guilabel:`percentile`
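
The CEM update above keeps only the episodes whose discounted return clears a quantile threshold. An illustrative sketch (representing the memory as a list of per-episode reward tensors is an assumption):

    import torch

    def elite_return_threshold(episode_rewards, discount_factor=0.99, percentile=0.70):
        """Discounted return per episode and the quantile used as the elite bound."""
        returns = torch.stack([
            torch.sum(rewards * discount_factor ** torch.arange(len(rewards),
                                                                dtype=rewards.dtype))
            for rewards in episode_rewards
        ])
        return returns, torch.quantile(returns, percentile)
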
6 changes: 3 additions & 3 deletions docs/source/api/agents/ddpg.rst
@@ -21,7 +21,7 @@ Algorithm implementation

| Main notation/symbols:
| - policy function approximator (:math:`\mu_\theta`), critic function approximator (:math:`Q_\phi`)
| - states (:math:`s`), actions (:math:`a`), rewards (:math:`r`), next states (:math:`s'`), dones (:math:`d`)
| - states (:math:`s`), actions (:math:`a`), rewards (:math:`r`), next states (:math:`s'`), terminated (:math:`d_{_{end}}`), truncated (:math:`d_{_{timeout}}`)
| - loss (:math:`L`)
.. raw:: html
@@ -50,11 +50,11 @@ Learning algorithm
| :green:`# gradient steps`
| **FOR** each gradient step up to :guilabel:`gradient_steps` **DO**
| :green:`# sample a batch from memory`
| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size`
| [:math:`s, a, r, s', d_{_{end}}, d_{_{timeout}}`] with size :guilabel:`batch_size`
| :green:`# compute target values`
| :math:`a' \leftarrow \mu_{\theta_{target}}(s')`
| :math:`Q_{_{target}} \leftarrow Q_{\phi_{target}}(s', a')`
| :math:`y \leftarrow r \;+` :guilabel:`discount_factor` :math:`\neg d \; Q_{_{target}}`
| :math:`y \leftarrow r \;+` :guilabel:`discount_factor` :math:`\neg (d_{_{end}} \lor d_{_{timeout}}) \; Q_{_{target}}`
| :green:`# compute critic loss`
| :math:`Q \leftarrow Q_\phi(s, a)`
| :math:`L_{Q_\phi} \leftarrow \frac{1}{N} \sum_{i=1}^N (Q - y)^2`
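
With this change, DDPG stops bootstrapping on truncation as well as on termination. A hedged sketch of the target-value computation (target_policy and target_critic are assumed callables, e.g. torch modules, not skrl's API):

    import torch

    @torch.no_grad()
    def ddpg_target_values(target_policy, target_critic, rewards, next_states,
                           terminated, truncated, discount_factor=0.99):
        """y = r + gamma * ~(terminated | truncated) * Q_target(s', mu_target(s'))."""
        next_actions = target_policy(next_states)
        q_target = target_critic(next_states, next_actions)
        not_done = (terminated | truncated).logical_not()
        return rewards + discount_factor * not_done * q_target
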
13 changes: 9 additions & 4 deletions docs/source/api/agents/ddqn.rst
@@ -19,6 +19,11 @@ Algorithm
Algorithm implementation
^^^^^^^^^^^^^^^^^^^^^^^^

| Main notation/symbols:
| - epsilon (:math:`\epsilon`), Q-network (:math:`Q_\phi`)
| - states (:math:`s`), actions (:math:`a`), rewards (:math:`r`), next states (:math:`s'`), terminated (:math:`d_{_{end}}`), truncated (:math:`d_{_{timeout}}`)
| - loss (:math:`L`)
.. raw:: html

<br>
@@ -43,16 +48,16 @@ Learning algorithm
| :green:`# gradient steps`
| **FOR** each gradient step up to :guilabel:`gradient_steps` **DO**
| :green:`# sample a batch from memory`
| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size`
| [:math:`s, a, r, s', d_{_{end}}, d_{_{timeout}}`] with size :guilabel:`batch_size`
| :green:`# compute target values`
| :math:`Q' \leftarrow Q_{\phi_{target}}(s')`
| :math:`Q_{_{target}} \leftarrow Q'[\underset{a}{\arg\max} \; Q_\phi(s')] \qquad` :gray:`# the only difference with DQN`
| :math:`y \leftarrow r \;+` :guilabel:`discount_factor` :math:`\neg d \; Q_{_{target}}`
| :math:`y \leftarrow r \;+` :guilabel:`discount_factor` :math:`\neg (d_{_{end}} \lor d_{_{timeout}}) \; Q_{_{target}}`
| :green:`# compute Q-network loss`
| :math:`Q \leftarrow Q_\phi(s)[a]`
| :math:`{Loss}_{Q_\phi} \leftarrow \frac{1}{N} \sum_{i=1}^N (Q - y)^2`
| :math:`L_{Q_\phi} \leftarrow \frac{1}{N} \sum_{i=1}^N (Q - y)^2`
| :green:`# optimize Q-network`
| :math:`\nabla_{\phi} {Loss}_{Q_\phi}`
| :math:`\nabla_{\phi} L_{Q_\phi}`
| :green:`# update target network`
| **IF** it's time to update target network **THEN**
| :math:`\phi_{target} \leftarrow` :guilabel:`polyak` :math:`\phi + (1 \;-` :guilabel:`polyak` :math:`) \phi_{target}`
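
A sketch of the Double DQN target above with the combined done signal (q_network and target_q_network are assumed to be torch modules mapping states to per-action values):

    import torch

    @torch.no_grad()
    def ddqn_target_values(q_network, target_q_network, rewards, next_states,
                           terminated, truncated, discount_factor=0.99):
        """Select the next action with the online network, evaluate it with the
        target network, and mask the bootstrap term with ~(terminated | truncated)."""
        next_actions = torch.argmax(q_network(next_states), dim=-1, keepdim=True)
        q_target = torch.gather(target_q_network(next_states), dim=-1, index=next_actions)
        not_done = (terminated | truncated).logical_not()
        return rewards + discount_factor * not_done * q_target
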
13 changes: 9 additions & 4 deletions docs/source/api/agents/dqn.rst
@@ -19,6 +19,11 @@ Algorithm
Algorithm implementation
^^^^^^^^^^^^^^^^^^^^^^^^

| Main notation/symbols:
| - epsilon (:math:`\epsilon`), Q-network (:math:`Q_\phi`)
| - states (:math:`s`), actions (:math:`a`), rewards (:math:`r`), next states (:math:`s'`), terminated (:math:`d_{_{end}}`), truncated (:math:`d_{_{timeout}}`)
| - loss (:math:`L`)
.. raw:: html

<br>
@@ -43,16 +48,16 @@ Learning algorithm
| :green:`# gradient steps`
| **FOR** each gradient step up to :guilabel:`gradient_steps` **DO**
| :green:`# sample a batch from memory`
| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size`
| [:math:`s, a, r, s', d_{_{end}}, d_{_{timeout}}`] with size :guilabel:`batch_size`
| :green:`# compute target values`
| :math:`Q' \leftarrow Q_{\phi_{target}}(s')`
| :math:`Q_{_{target}} \leftarrow \underset{a}{\max} \; Q' \qquad` :gray:`# the only difference with DDQN`
| :math:`y \leftarrow r \;+` :guilabel:`discount_factor` :math:`\neg d \; Q_{_{target}}`
| :math:`y \leftarrow r \;+` :guilabel:`discount_factor` :math:`\neg (d_{_{end}} \lor d_{_{timeout}}) \; Q_{_{target}}`
| :green:`# compute Q-network loss`
| :math:`Q \leftarrow Q_\phi(s)[a]`
| :math:`{Loss}_{Q_\phi} \leftarrow \frac{1}{N} \sum_{i=1}^N (Q - y)^2`
| :math:`L_{Q_\phi} \leftarrow \frac{1}{N} \sum_{i=1}^N (Q - y)^2`
| :green:`# optimize Q-network`
| :math:`\nabla_{\phi} {Loss}_{Q_\phi}`
| :math:`\nabla_{\phi} L_{Q_\phi}`
| :green:`# update target network`
| **IF** it's time to update target network **THEN**
| :math:`\phi_{target} \leftarrow` :guilabel:`polyak` :math:`\phi + (1 \;-` :guilabel:`polyak` :math:`) \phi_{target}`
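
And the vanilla DQN counterpart, which differs only in taking the maximum over the target network's own values:

    import torch

    @torch.no_grad()
    def dqn_target_values(target_q_network, rewards, next_states,
                          terminated, truncated, discount_factor=0.99):
        """Both action selection and evaluation use the target network."""
        q_target = target_q_network(next_states).max(dim=-1, keepdim=True).values
        not_done = (terminated | truncated).logical_not()
        return rewards + discount_factor * not_done * q_target
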
8 changes: 4 additions & 4 deletions docs/source/api/agents/ppo.rst
@@ -13,8 +13,8 @@ Algorithm
---------

| For each iteration do:
| :math:`\bullet \;` Collect, in a rollout memory, a set of states :math:`s`, actions :math:`a`, rewards :math:`r`, dones :math:`d`, log probabilities :math:`logp` and values :math:`V` on policy using :math:`\pi_\theta` and :math:`V_\phi`
| :math:`\bullet \;` Estimate returns :math:`R` and advantages :math:`A` using Generalized Advantage Estimation (GAE(:math:`\lambda`)) from the collected data [:math:`r, d, V`]
| :math:`\bullet \;` Collect, in a rollout memory, a set of states :math:`s`, actions :math:`a`, rewards :math:`r`, terminated :math:`d_{_{end}}`, truncated :math:`d_{_{timeout}}`, log probabilities :math:`logp` and values :math:`V` on policy using :math:`\pi_\theta` and :math:`V_\phi`
| :math:`\bullet \;` Estimate returns :math:`R` and advantages :math:`A` using Generalized Advantage Estimation (GAE(:math:`\lambda`)) from the collected data [:math:`r, d_{_{end}} \lor d_{_{timeout}}, V`]
| :math:`\bullet \;` Compute the entropy loss :math:`{L}_{entropy}`
| :math:`\bullet \;` Compute the clipped surrogate objective (policy loss) with :math:`ratio` as the probability ratio between the action under the current policy and the action under the previous policy: :math:`L^{clip}_{\pi_\theta} = \mathbb{E}[\min(A \; ratio, A \; \text{clip}(ratio, 1-c, 1+c))]`
| :math:`\bullet \;` Compute the value loss :math:`L_{V_\phi}` as the mean squared error (MSE) between the predicted values :math:`V_{_{predicted}}` and the estimated returns :math:`R`
@@ -29,7 +29,7 @@ Algorithm implementation

| Main notation/symbols:
| - policy function approximator (:math:`\pi_\theta`), value function approximator (:math:`V_\phi`)
| - states (:math:`s`), actions (:math:`a`), rewards (:math:`r`), next states (:math:`s'`), dones (:math:`d`)
| - states (:math:`s`), actions (:math:`a`), rewards (:math:`r`), next states (:math:`s'`), terminated (:math:`d_{_{end}}`), truncated (:math:`d_{_{timeout}}`)
| - values (:math:`V`), advantages (:math:`A`), returns (:math:`R`)
| - log probabilities (:math:`logp`)
| - loss (:math:`L`)
@@ -63,7 +63,7 @@ Learning algorithm
| :literal:`_update(...)`
| :green:`# compute returns and advantages`
| :math:`V_{_{last}}' \leftarrow V_\phi(s')`
| :math:`R, A \leftarrow f_{GAE}(r, d, V, V_{_{last}}')`
| :math:`R, A \leftarrow f_{GAE}(r, d_{_{end}} \lor d_{_{timeout}}, V, V_{_{last}}')`
| :green:`# sample mini-batches from memory`
| [[:math:`s, a, logp, V, R, A`]] :math:`\leftarrow` states, actions, log_prob, values, returns, advantages
| :green:`# learning epochs`
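
The PPO-specific piece, the clipped surrogate objective, can be sketched as follows (function and parameter names are illustrative, not skrl's internals):

    import torch

    def ppo_policy_loss(new_log_prob, old_log_prob, advantages, ratio_clip=0.2):
        """L^clip = E[min(A * ratio, A * clip(ratio, 1 - c, 1 + c))], returned
        with a negative sign so it can be minimized as a loss."""
        ratio = torch.exp(new_log_prob - old_log_prob)
        surrogate = advantages * ratio
        clipped_surrogate = advantages * torch.clamp(ratio, 1.0 - ratio_clip, 1.0 + ratio_clip)
        return -torch.min(surrogate, clipped_surrogate).mean()
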
4 changes: 2 additions & 2 deletions docs/source/api/agents/q_learning.rst
@@ -21,7 +21,7 @@ Algorithm implementation

| Main notation/symbols:
| - action-value function (:math:`Q`)
| - states (:math:`s`), actions (:math:`a`), rewards (:math:`r`), next states (:math:`s'`), dones (:math:`d`)
| - states (:math:`s`), actions (:math:`a`), rewards (:math:`r`), next states (:math:`s'`), terminated (:math:`d_{_{end}}`), truncated (:math:`d_{_{timeout}}`)
.. raw:: html

@@ -46,7 +46,7 @@ Learning algorithm
| :green:`# compute next actions`
| :math:`a' \leftarrow \underset{a}{\arg\max} \; Q[s'] \qquad` :gray:`# the only difference with SARSA`
| :green:`# update Q-table`
| :math:`Q[s,a] \leftarrow Q[s,a] \;+` :guilabel:`learning_rate` :math:`(r \;+` :guilabel:`discount_factor` :math:`\neg d \; Q[s',a'] - Q[s,a])`
| :math:`Q[s,a] \leftarrow Q[s,a] \;+` :guilabel:`learning_rate` :math:`(r \;+` :guilabel:`discount_factor` :math:`\neg (d_{_{end}} \lor d_{_{timeout}}) \; Q[s',a'] - Q[s,a])`
.. raw:: html

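
A tabular sketch of the Q-learning update above, where bootstrapping is disabled when the episode ended for either reason (storing Q as a NumPy array indexed by [state, action] is an assumption):

    import numpy as np

    def q_learning_update(Q, s, a, r, s_next, terminated, truncated,
                          learning_rate=0.5, discount_factor=0.99):
        """Off-policy update: the next action is the greedy one."""
        not_done = 0.0 if (terminated or truncated) else 1.0
        a_next = np.argmax(Q[s_next])
        td_target = r + discount_factor * not_done * Q[s_next, a_next]
        Q[s, a] += learning_rate * (td_target - Q[s, a])
        return Q
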
4 changes: 2 additions & 2 deletions docs/source/api/agents/rpo.rst
@@ -68,7 +68,7 @@ Algorithm implementation

| Main notation/symbols:
| - policy function approximator (:math:`\pi_\theta`), value function approximator (:math:`V_\phi`)
| - states (:math:`s`), actions (:math:`a`), rewards (:math:`r`), next states (:math:`s'`), dones (:math:`d`)
| - states (:math:`s`), actions (:math:`a`), rewards (:math:`r`), next states (:math:`s'`), terminated (:math:`d_{_{end}}`), truncated (:math:`d_{_{timeout}}`)
| - values (:math:`V`), advantages (:math:`A`), returns (:math:`R`)
| - log probabilities (:math:`logp`)
| - loss (:math:`L`)
@@ -102,7 +102,7 @@ Learning algorithm
| :literal:`_update(...)`
| :green:`# compute returns and advantages`
| :math:`V_{_{last}}' \leftarrow V_\phi(s')`
| :math:`R, A \leftarrow f_{GAE}(r, d, V, V_{_{last}}')`
| :math:`R, A \leftarrow f_{GAE}(r, d_{_{end}} \lor d_{_{timeout}}, V, V_{_{last}}')`
| :green:`# sample mini-batches from memory`
| [[:math:`s, a, logp, V, R, A`]] :math:`\leftarrow` states, actions, log_prob, values, returns, advantages
| :green:`# learning epochs`
6 changes: 3 additions & 3 deletions docs/source/api/agents/sac.rst
@@ -21,7 +21,7 @@ Algorithm implementation

| Main notation/symbols:
| - policy function approximator (:math:`\pi_\theta`), critic function approximator (:math:`Q_\phi`)
| - states (:math:`s`), actions (:math:`a`), rewards (:math:`r`), next states (:math:`s'`), dones (:math:`d`)
| - states (:math:`s`), actions (:math:`a`), rewards (:math:`r`), next states (:math:`s'`), terminated (:math:`d_{_{end}}`), truncated (:math:`d_{_{timeout}}`)
| - log probabilities (:math:`logp`), entropy coefficient (:math:`\alpha`)
| - loss (:math:`L`)
@@ -37,13 +37,13 @@ Learning algorithm
| :green:`# gradient steps`
| **FOR** each gradient step up to :guilabel:`gradient_steps` **DO**
| :green:`# sample a batch from memory`
| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size`
| [:math:`s, a, r, s', d_{_{end}}, d_{_{timeout}}`] with size :guilabel:`batch_size`
| :green:`# compute target values`
| :math:`a',\; logp' \leftarrow \pi_\theta(s')`
| :math:`Q_{1_{target}} \leftarrow Q_{{\phi 1}_{target}}(s', a')`
| :math:`Q_{2_{target}} \leftarrow Q_{{\phi 2}_{target}}(s', a')`
| :math:`Q_{_{target}} \leftarrow \text{min}(Q_{1_{target}}, Q_{2_{target}}) - \alpha \; logp'`
| :math:`y \leftarrow r \;+` :guilabel:`discount_factor` :math:`\neg d \; Q_{_{target}}`
| :math:`y \leftarrow r \;+` :guilabel:`discount_factor` :math:`\neg (d_{_{end}} \lor d_{_{timeout}}) \; Q_{_{target}}`
| :green:`# compute critic loss`
| :math:`Q_1 \leftarrow Q_{\phi 1}(s, a)`
| :math:`Q_2 \leftarrow Q_{\phi 2}(s, a)`
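
A sketch of SAC's entropy-regularized target with the combined done signal (the policy is assumed to return the next actions together with their log-probabilities, and alpha is the entropy coefficient):

    import torch

    @torch.no_grad()
    def sac_target_values(policy, target_critic_1, target_critic_2, alpha,
                          rewards, next_states, terminated, truncated,
                          discount_factor=0.99):
        """y = r + gamma * ~(terminated | truncated) * (min(Q1', Q2') - alpha * logp')."""
        next_actions, next_log_prob = policy(next_states)
        q1 = target_critic_1(next_states, next_actions)
        q2 = target_critic_2(next_states, next_actions)
        q_target = torch.min(q1, q2) - alpha * next_log_prob
        not_done = (terminated | truncated).logical_not()
        return rewards + discount_factor * not_done * q_target
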
4 changes: 2 additions & 2 deletions docs/source/api/agents/sarsa.rst
@@ -21,7 +21,7 @@ Algorithm implementation

| Main notation/symbols:
| - action-value function (:math:`Q`)
| - states (:math:`s`), actions (:math:`a`), rewards (:math:`r`), next states (:math:`s'`), dones (:math:`d`)
| - states (:math:`s`), actions (:math:`a`), rewards (:math:`r`), next states (:math:`s'`), terminated (:math:`d_{_{end}}`), truncated (:math:`d_{_{timeout}}`)
.. raw:: html

@@ -46,7 +46,7 @@ Learning algorithm
| :green:`# compute next actions`
| :math:`a' \leftarrow \pi_{Q[s,a]}(s') \qquad` :gray:`# the only difference with Q-learning`
| :green:`# update Q-table`
| :math:`Q[s,a] \leftarrow Q[s,a] \;+` :guilabel:`learning_rate` :math:`(r \;+` :guilabel:`discount_factor` :math:`\neg d \; Q[s',a'] - Q[s,a])`
| :math:`Q[s,a] \leftarrow Q[s,a] \;+` :guilabel:`learning_rate` :math:`(r \;+` :guilabel:`discount_factor` :math:`\neg (d_{_{end}} \lor d_{_{timeout}}) \; Q[s',a'] - Q[s,a])`
.. raw:: html

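
The SARSA counterpart of the sketch shown for Q-learning, where the next action comes from the behavior policy itself rather than a greedy argmax:

    import numpy as np

    def sarsa_update(Q, s, a, r, s_next, a_next, terminated, truncated,
                     learning_rate=0.5, discount_factor=0.99):
        """On-policy update: a_next is the action the policy actually takes in s'."""
        not_done = 0.0 if (terminated or truncated) else 1.0
        td_target = r + discount_factor * not_done * Q[s_next, a_next]
        Q[s, a] += learning_rate * (td_target - Q[s, a])
        return Q
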
6 changes: 3 additions & 3 deletions docs/source/api/agents/td3.rst
@@ -21,7 +21,7 @@ Algorithm implementation

| Main notation/symbols:
| - policy function approximator (:math:`\mu_\theta`), critic function approximator (:math:`Q_\phi`)
| - states (:math:`s`), actions (:math:`a`), rewards (:math:`r`), next states (:math:`s'`), dones (:math:`d`)
| - states (:math:`s`), actions (:math:`a`), rewards (:math:`r`), next states (:math:`s'`), terminated (:math:`d_{_{end}}`), truncated (:math:`d_{_{timeout}}`)
| - loss (:math:`L`)
.. raw:: html
@@ -50,7 +50,7 @@ Learning algorithm
| :green:`# gradient steps`
| **FOR** each gradient step up to :guilabel:`gradient_steps` **DO**
| :green:`# sample a batch from memory`
| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size`
| [:math:`s, a, r, s', d_{_{end}}, d_{_{timeout}}`] with size :guilabel:`batch_size`
| :green:`# target policy smoothing`
| :math:`a' \leftarrow \mu_{\theta_{target}}(s')`
| :math:`noise \leftarrow \text{clip}(` :guilabel:`smooth_regularization_noise` :math:`, -c, c) \qquad` with :math:`c` as :guilabel:`smooth_regularization_clip`
@@ -60,7 +60,7 @@ Learning algorithm
| :math:`Q_{1_{target}} \leftarrow Q_{{\phi 1}_{target}}(s', a')`
| :math:`Q_{2_{target}} \leftarrow Q_{{\phi 2}_{target}}(s', a')`
| :math:`Q_{_{target}} \leftarrow \text{min}(Q_{1_{target}}, Q_{2_{target}})`
| :math:`y \leftarrow r \;+` :guilabel:`discount_factor` :math:`\neg d \; Q_{_{target}}`
| :math:`y \leftarrow r \;+` :guilabel:`discount_factor` :math:`\neg (d_{_{end}} \lor d_{_{timeout}}) \; Q_{_{target}}`
| :green:`# compute critic loss`
| :math:`Q_1 \leftarrow Q_{\phi 1}(s, a)`
| :math:`Q_2 \leftarrow Q_{\phi 2}(s, a)`
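
Finally, a sketch of TD3's target computation, covering target policy smoothing and the clipped double-Q minimum with the combined done signal (Gaussian smoothing noise and an action range normalized to [-1, 1] are assumptions, as are the module names):

    import torch

    @torch.no_grad()
    def td3_target_values(target_policy, target_critic_1, target_critic_2,
                          rewards, next_states, terminated, truncated,
                          discount_factor=0.99, noise_std=0.2,
                          smooth_regularization_clip=0.5):
        """Smooth the target action with clipped noise, take the minimum of the
        twin target critics, and mask bootstrapping with ~(terminated | truncated)."""
        next_actions = target_policy(next_states)
        noise = torch.clamp(noise_std * torch.randn_like(next_actions),
                            -smooth_regularization_clip, smooth_regularization_clip)
        next_actions = torch.clamp(next_actions + noise, -1.0, 1.0)
        q_target = torch.min(target_critic_1(next_states, next_actions),
                             target_critic_2(next_states, next_actions))
        not_done = (terminated | truncated).logical_not()
        return rewards + discount_factor * not_done * q_target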