Take into account truncated when computing dones in multi-agent algorithms
Toni-SM committed Jan 16, 2025
1 parent 0eb7cca commit 52bea23
Showing 4 changed files with 22 additions and 38 deletions.
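
To show where the combined flag matters, here is a minimal, generic GAE sketch (not skrl's actual compute_gae implementation; compute_gae_sketch is a hypothetical helper, and only the parameter names are taken from the call sites below). The dones argument is what cuts bootstrapping and advantage accumulation at episode boundaries, which is why it now covers truncation as well as termination.

import torch


def compute_gae_sketch(rewards, dones, values, next_values, discount_factor=0.99, lambda_coefficient=0.95):
    # rewards, dones, values: shape (num_steps, 1); dones is boolean (terminated OR truncated)
    # next_values: value estimate used to bootstrap after the last stored step
    advantages = torch.zeros_like(rewards)
    not_dones = (~dones).float()  # 0.0 where the episode ended for any reason, 1.0 otherwise
    last_advantage = 0.0
    num_steps = rewards.shape[0]
    for t in reversed(range(num_steps)):
        next_value = values[t + 1] if t < num_steps - 1 else next_values
        # TD error: the mask zeroes the bootstrap term at episode boundaries
        delta = rewards[t] + discount_factor * not_dones[t] * next_value - values[t]
        # GAE recursion: the same mask stops advantages from flowing across episodes
        last_advantage = delta + discount_factor * lambda_coefficient * not_dones[t] * last_advantage
        advantages[t] = last_advantage
    returns = advantages + values
    return returns, advantages


# Example call on a 3-step rollout; dones would be terminated | truncated, as in this commit
rewards = torch.ones(3, 1)
values = torch.zeros(3, 1)
dones = torch.tensor([[False], [True], [False]])
returns, advantages = compute_gae_sketch(rewards, dones, values, torch.zeros(1))
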
27 changes: 9 additions & 18 deletions skrl/multi_agents/jax/ippo/ippo.py
@@ -359,6 +359,7 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
self.memories[uid].create_tensor(name="actions", size=self.action_spaces[uid], dtype=jnp.float32)
self.memories[uid].create_tensor(name="rewards", size=1, dtype=jnp.float32)
self.memories[uid].create_tensor(name="terminated", size=1, dtype=jnp.int8)
self.memories[uid].create_tensor(name="truncated", size=1, dtype=jnp.int8)
self.memories[uid].create_tensor(name="log_prob", size=1, dtype=jnp.float32)
self.memories[uid].create_tensor(name="values", size=1, dtype=jnp.float32)
self.memories[uid].create_tensor(name="returns", size=1, dtype=jnp.float32)
@@ -535,24 +536,14 @@ def _update(self, timestep: int, timesteps: int) -> None:
last_values = self._value_preprocessor[uid](last_values, inverse=True)

values = memory.get_tensor_by_name("values")
- if self._jax:
-     returns, advantages = _compute_gae(
-         rewards=memory.get_tensor_by_name("rewards"),
-         dones=memory.get_tensor_by_name("terminated"),
-         values=values,
-         next_values=last_values,
-         discount_factor=self._discount_factor[uid],
-         lambda_coefficient=self._lambda[uid],
-     )
- else:
-     returns, advantages = compute_gae(
-         rewards=memory.get_tensor_by_name("rewards"),
-         dones=memory.get_tensor_by_name("terminated"),
-         values=values,
-         next_values=last_values,
-         discount_factor=self._discount_factor[uid],
-         lambda_coefficient=self._lambda[uid],
-     )
+ returns, advantages = (_compute_gae if self._jax else compute_gae)(
+     rewards=memory.get_tensor_by_name("rewards"),
+     dones=memory.get_tensor_by_name("terminated") | memory.get_tensor_by_name("truncated"),
+     values=values,
+     next_values=last_values,
+     discount_factor=self._discount_factor[uid],
+     lambda_coefficient=self._lambda[uid],
+ )

memory.set_tensor_by_name("values", self._value_preprocessor[uid](values, train=True))
memory.set_tensor_by_name("returns", self._value_preprocessor[uid](returns, train=True))
27 changes: 9 additions & 18 deletions skrl/multi_agents/jax/mappo/mappo.py
@@ -379,6 +379,7 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
self.memories[uid].create_tensor(name="actions", size=self.action_spaces[uid], dtype=jnp.float32)
self.memories[uid].create_tensor(name="rewards", size=1, dtype=jnp.float32)
self.memories[uid].create_tensor(name="terminated", size=1, dtype=jnp.int8)
self.memories[uid].create_tensor(name="truncated", size=1, dtype=jnp.int8)
self.memories[uid].create_tensor(name="log_prob", size=1, dtype=jnp.float32)
self.memories[uid].create_tensor(name="values", size=1, dtype=jnp.float32)
self.memories[uid].create_tensor(name="returns", size=1, dtype=jnp.float32)
@@ -565,24 +566,14 @@ def _update(self, timestep: int, timesteps: int) -> None:
last_values = self._value_preprocessor[uid](last_values, inverse=True)

values = memory.get_tensor_by_name("values")
- if self._jax:
-     returns, advantages = _compute_gae(
-         rewards=memory.get_tensor_by_name("rewards"),
-         dones=memory.get_tensor_by_name("terminated"),
-         values=values,
-         next_values=last_values,
-         discount_factor=self._discount_factor[uid],
-         lambda_coefficient=self._lambda[uid],
-     )
- else:
-     returns, advantages = compute_gae(
-         rewards=memory.get_tensor_by_name("rewards"),
-         dones=memory.get_tensor_by_name("terminated"),
-         values=values,
-         next_values=last_values,
-         discount_factor=self._discount_factor[uid],
-         lambda_coefficient=self._lambda[uid],
-     )
+ returns, advantages = (_compute_gae if self._jax else compute_gae)(
+     rewards=memory.get_tensor_by_name("rewards"),
+     dones=memory.get_tensor_by_name("terminated") | memory.get_tensor_by_name("truncated"),
+     values=values,
+     next_values=last_values,
+     discount_factor=self._discount_factor[uid],
+     lambda_coefficient=self._lambda[uid],
+ )

memory.set_tensor_by_name("values", self._value_preprocessor[uid](values, train=True))
memory.set_tensor_by_name("returns", self._value_preprocessor[uid](returns, train=True))
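
A note on the two JAX files above: terminated and truncated are stored as jnp.int8 tensors, so the | in the new call is a bitwise OR, which for 0/1 flags is equivalent to an elementwise logical OR. A small illustrative check (not part of the commit):

import jax.numpy as jnp

terminated = jnp.array([[0], [1], [0]], dtype=jnp.int8)
truncated = jnp.array([[0], [0], [1]], dtype=jnp.int8)

# bitwise OR on 0/1 int8 flags: a step counts as done if it terminated or was truncated
dones = terminated | truncated  # [[0], [1], [1]], still int8
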
3 changes: 2 additions & 1 deletion skrl/multi_agents/torch/ippo/ippo.py
@@ -220,6 +220,7 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
self.memories[uid].create_tensor(name="actions", size=self.action_spaces[uid], dtype=torch.float32)
self.memories[uid].create_tensor(name="rewards", size=1, dtype=torch.float32)
self.memories[uid].create_tensor(name="terminated", size=1, dtype=torch.bool)
self.memories[uid].create_tensor(name="truncated", size=1, dtype=torch.bool)
self.memories[uid].create_tensor(name="log_prob", size=1, dtype=torch.float32)
self.memories[uid].create_tensor(name="values", size=1, dtype=torch.float32)
self.memories[uid].create_tensor(name="returns", size=1, dtype=torch.float32)
@@ -433,7 +434,7 @@ def compute_gae(
values = memory.get_tensor_by_name("values")
returns, advantages = compute_gae(
rewards=memory.get_tensor_by_name("rewards"),
dones=memory.get_tensor_by_name("terminated"),
dones=memory.get_tensor_by_name("terminated") | memory.get_tensor_by_name("truncated"),
values=values,
next_values=last_values,
discount_factor=self._discount_factor[uid],
3 changes: 2 additions & 1 deletion skrl/multi_agents/torch/mappo/mappo.py
@@ -240,6 +240,7 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
self.memories[uid].create_tensor(name="actions", size=self.action_spaces[uid], dtype=torch.float32)
self.memories[uid].create_tensor(name="rewards", size=1, dtype=torch.float32)
self.memories[uid].create_tensor(name="terminated", size=1, dtype=torch.bool)
self.memories[uid].create_tensor(name="truncated", size=1, dtype=torch.bool)
self.memories[uid].create_tensor(name="log_prob", size=1, dtype=torch.float32)
self.memories[uid].create_tensor(name="values", size=1, dtype=torch.float32)
self.memories[uid].create_tensor(name="returns", size=1, dtype=torch.float32)
@@ -464,7 +465,7 @@ def compute_gae(
values = memory.get_tensor_by_name("values")
returns, advantages = compute_gae(
rewards=memory.get_tensor_by_name("rewards"),
dones=memory.get_tensor_by_name("terminated"),
dones=memory.get_tensor_by_name("terminated") | memory.get_tensor_by_name("truncated"),
values=values,
next_values=last_values,
discount_factor=self._discount_factor[uid],
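
In the two PyTorch files, both flags are torch.bool tensors, so terminated | truncated is an elementwise logical OR and yields a boolean dones tensor directly. A small illustrative check (not part of the commit):

import torch

terminated = torch.tensor([[False], [True], [False]])
truncated = torch.tensor([[False], [False], [True]])

# logical OR on boolean flags: done when the episode terminated or was truncated
dones = terminated | truncated  # tensor([[False], [True], [True]])
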
