Commit f9c4e00

Update

[ghstack-poisoned]

vmoens committed Nov 20, 2024
1 parent b59d5de

Showing 5 changed files with 17 additions and 17 deletions.
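The change is mechanical: every `env.action_spec_unbatched` becomes `env.full_action_spec_unbatched`. In TorchRL, `action_spec` resolves to the single leaf action spec, whereas `full_action_spec` is the composite spec that preserves the nested key layout (here `("agents", "action")`); the `_unbatched` variants strip the environment's batch dimensions. A minimal sketch of the difference, assuming a VMAS-style multi-agent env (scenario name and sizes are illustrative, not taken from this commit):

    from torchrl.envs.libs.vmas import VmasEnv

    env = VmasEnv(scenario="balance", num_envs=4, n_agents=3)

    # Leaf spec: just the tensor spec stored at ("agents", "action"),
    # with the env batch dimension (num_envs) stripped.
    leaf = env.action_spec_unbatched

    # Composite spec: keeps the ("agents", "action") nesting, so modules
    # that look up entries by env.action_key can resolve the nested key.
    full = env.full_action_spec_unbatched
    assert full[env.action_key].shape == leaf.shape

The scripts below pass nested action keys such as `env.action_key`; a composite spec can be indexed by such a key, while a bare leaf spec cannot.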
4 changes: 2 additions & 2 deletions sota-implementations/multiagent/iql.py
@@ -91,7 +91,7 @@ def train(cfg: "DictConfig"): # noqa: F821
("agents", "action_value"),
("agents", "chosen_action_value"),
],
- spec=env.action_spec_unbatched,
+ spec=env.full_action_spec_unbatched,
action_space=None,
)
qnet = SafeSequential(module, value_module)
@@ -103,7 +103,7 @@ def train(cfg: "DictConfig"): # noqa: F821
eps_end=0,
annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
action_key=env.action_key,
- spec=env.action_spec_unbatched,
+ spec=env.full_action_spec_unbatched,
),
)

8 changes: 4 additions & 4 deletions sota-implementations/multiagent/maddpg_iddpg.py
@@ -91,21 +91,21 @@ def train(cfg: "DictConfig"): # noqa: F821
)
policy = ProbabilisticActor(
module=policy_module,
- spec=env.action_spec_unbatched,
+ spec=env.full_action_spec_unbatched,
in_keys=[("agents", "param")],
out_keys=[env.action_key],
distribution_class=TanhDelta,
distribution_kwargs={
"low": env.action_spec_unbatched[("agents", "action")].space.low,
"high": env.action_spec_unbatched[("agents", "action")].space.high,
"low": env.full_action_spec_unbatched[("agents", "action")].space.low,
"high": env.full_action_spec_unbatched[("agents", "action")].space.high,
},
return_log_prob=False,
)

policy_explore = TensorDictSequential(
policy,
AdditiveGaussianModule(
- spec=env.action_spec_unbatched,
+ spec=env.full_action_spec_unbatched,
annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
action_key=env.action_key,
device=cfg.train.device,
6 changes: 3 additions & 3 deletions sota-implementations/multiagent/mappo_ippo.py
@@ -92,13 +92,13 @@ def train(cfg: "DictConfig"): # noqa: F821
)
policy = ProbabilisticActor(
module=policy_module,
- spec=env.action_spec_unbatched,
+ spec=env.full_action_spec_unbatched,
in_keys=[("agents", "loc"), ("agents", "scale")],
out_keys=[env.action_key],
distribution_class=TanhNormal,
distribution_kwargs={
"low": env.action_spec_unbatched[("agents", "action")].space.low,
"high": env.action_spec_unbatched[("agents", "action")].space.high,
"low": env.full_action_spec_unbatched[("agents", "action")].space.low,
"high": env.full_action_spec_unbatched[("agents", "action")].space.high,
},
return_log_prob=True,
)
4 changes: 2 additions & 2 deletions sota-implementations/multiagent/qmix_vdn.py
@@ -91,7 +91,7 @@ def train(cfg: "DictConfig"): # noqa: F821
("agents", "action_value"),
("agents", "chosen_action_value"),
],
- spec=env.action_spec_unbatched,
+ spec=env.full_action_spec_unbatched,
action_space=None,
)
qnet = SafeSequential(module, value_module)
@@ -103,7 +103,7 @@ def train(cfg: "DictConfig"): # noqa: F821
eps_end=0,
annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
action_key=env.action_key,
- spec=env.action_spec_unbatched,
+ spec=env.full_action_spec_unbatched,
),
)

12 changes: 6 additions & 6 deletions sota-implementations/multiagent/sac.py
@@ -96,13 +96,13 @@ def train(cfg: "DictConfig"): # noqa: F821

policy = ProbabilisticActor(
module=policy_module,
- spec=env.action_spec_unbatched,
+ spec=env.full_action_spec_unbatched,
in_keys=[("agents", "loc"), ("agents", "scale")],
out_keys=[env.action_key],
distribution_class=TanhNormal,
distribution_kwargs={
"low": env.action_spec_unbatched[("agents", "action")].space.low,
"high": env.action_spec_unbatched[("agents", "action")].space.high,
"low": env.full_action_spec_unbatched[("agents", "action")].space.low,
"high": env.full_action_spec_unbatched[("agents", "action")].space.high,
},
return_log_prob=True,
)
@@ -146,7 +146,7 @@ def train(cfg: "DictConfig"): # noqa: F821
)
policy = ProbabilisticActor(
module=policy_module,
- spec=env.action_spec_unbatched,
+ spec=env.full_action_spec_unbatched,
in_keys=[("agents", "logits")],
out_keys=[env.action_key],
distribution_class=OneHotCategorical
@@ -194,7 +194,7 @@ def train(cfg: "DictConfig"): # noqa: F821
actor_network=policy,
qvalue_network=value_module,
delay_qvalue=True,
- action_spec=env.action_spec_unbatched,
+ action_spec=env.full_action_spec_unbatched,
)
loss_module.set_keys(
state_action_value=("agents", "state_action_value"),
@@ -209,7 +209,7 @@ def train(cfg: "DictConfig"): # noqa: F821
qvalue_network=value_module,
delay_qvalue=True,
num_actions=env.action_spec.space.n,
- action_space=env.action_spec_unbatched,
+ action_space=env.full_action_spec_unbatched,
)
loss_module.set_keys(
action_value=("agents", "action_value"),
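The same composite spec feeds the epsilon-greedy exploration wrappers in the Q-learning scripts (iql.py and qmix_vdn.py). A sketch of that call pattern, assuming a discrete-action multi-agent env bound to `env` and a value network `qnet` built as in the hunks above (hyperparameters are illustrative):

    from tensordict.nn import TensorDictSequential
    from torchrl.modules import EGreedyModule

    policy_explore = TensorDictSequential(
        qnet,
        EGreedyModule(
            eps_init=1.0,
            eps_end=0.0,
            annealing_num_steps=50_000,
            action_key=env.action_key,            # nested key, e.g. ("agents", "action")
            spec=env.full_action_spec_unbatched,  # composite spec resolves the nested key
        ),
    )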
