Add graph-version of sb3 DQN algorithm
This is an off-policy algorithm, so we need to implement graph versions of [Dict]ReplayBuffer and OffPolicyAlgorithm.
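To make the need concrete: gymnasium Graph observations are variable-size GraphInstance objects, so they cannot be stacked into the fixed-shape numpy arrays used by sb3's standard ReplayBuffer, hence the dedicated buffer classes. A minimal illustrative sketch (not part of this commit; the observation space below is made up):

import numpy as np
from gymnasium import spaces

# Hypothetical graph observation space: 3 float features per node, one
# discrete label per edge.
obs_space = spaces.Graph(
    node_space=spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32),
    edge_space=spaces.Discrete(4),
)

obs = obs_space.sample()      # a gymnasium.spaces.GraphInstance
print(obs.nodes.shape)        # (num_nodes, 3): node count varies per observation
print(obs.edge_links.shape)   # (num_edges, 2): edge count varies as well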
Showing 7 changed files with 484 additions and 63 deletions.
44 changes: 44 additions & 0 deletions
skdecide/hub/solver/stable_baselines/gnn/common/off_policy_algorithm.py
@@ -0,0 +1,44 @@
from typing import Optional, Union

from gymnasium import spaces
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv

from .buffers import DictGraphReplayBuffer, GraphReplayBuffer
from .vec_env.dummy_vec_env import wrap_graph_env


class GraphOffPolicyAlgorithm(OffPolicyAlgorithm):
    """Base class for Off-Policy algorithms (ex: SAC/TD3) with graph observations."""

    def __init__(
        self,
        policy: Union[str, type[ActorCriticPolicy]],
        env: GymEnv,
        replay_buffer_class: Optional[type[ReplayBuffer]] = None,
        **kwargs,
    ):
        # Use the proper default replay buffer class for graph observations
        if replay_buffer_class is None:
            if isinstance(env.observation_space, spaces.Graph):
                replay_buffer_class = GraphReplayBuffer
            elif isinstance(env.observation_space, spaces.Dict):
                replay_buffer_class = DictGraphReplayBuffer

        # Use the proper VecEnv wrapper for envs with Graph spaces
        env = wrap_graph_env(env)
        if env.num_envs > 1:
            raise NotImplementedError(
                "GraphOffPolicyAlgorithm not implemented for real vectorized environments "
                "(i.e. with more than 1 wrapped environment)"
            )

        super().__init__(
            policy=policy,
            env=env,
            replay_buffer_class=replay_buffer_class,
            **kwargs,
        )
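For orientation, here is a minimal sketch of how this mixin could be combined with stable-baselines3's DQN to obtain the graph version of DQN named in the commit title. This is an assumption about the overall design, not code shown in this diff, and the class name GraphDQN is hypothetical:

from stable_baselines3 import DQN

class GraphDQN(GraphOffPolicyAlgorithm, DQN):
    """DQN variant accepting gymnasium Graph observation spaces (hypothetical name)."""

Listing GraphOffPolicyAlgorithm before DQN puts it first in the MRO, so its __init__ selects the graph replay buffer and wraps the environment before handing off to DQN.__init__ via super(); the policy would additionally need a GNN-based features extractor, which this sketch leaves out.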