Add graph-version of sb3 DQN algorithm
This is an off-policy algorithm, so we need to implement graph versions of
[Dict]ReplayBuffer and OffPolicyAlgorithm.
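
For context, here is a minimal usage sketch (illustrative only: the spaces, sizes and values below are assumptions based on the signatures in this diff, not code from this commit) showing how the new replay buffer stores GraphInstance observations and returns them as torch_geometric batches when sampled:

import numpy as np
from gymnasium import spaces

from skdecide.hub.solver.stable_baselines.gnn.common.buffers import GraphReplayBuffer

# Illustrative graph observation space and discrete action space (assumptions).
obs_space = spaces.Graph(
    node_space=spaces.Box(low=-1.0, high=1.0, shape=(3,)),
    edge_space=spaces.Box(low=-1.0, high=1.0, shape=(1,)),
)
action_space = spaces.Discrete(4)

buffer = GraphReplayBuffer(
    buffer_size=1_000,
    observation_space=obs_space,
    action_space=action_space,
    device="cpu",
)

# One GraphInstance per environment; only a single (non-vectorized) env is supported.
obs = [obs_space.sample(num_nodes=5)]
next_obs = [obs_space.sample(num_nodes=5)]
buffer.add(
    obs=obs,
    next_obs=next_obs,
    action=np.array([1]),
    reward=np.array([0.0]),
    done=np.array([False]),
    infos=[{}],
)

# Sampled observations come back as torch_geometric Batch objects ready for a GNN.
samples = buffer.sample(batch_size=1)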
nhuet authored and fteicht committed Jan 17, 2025
1 parent 4701c12 commit 9d825f7
Showing 7 changed files with 484 additions and 63 deletions.
266 changes: 234 additions & 32 deletions skdecide/hub/solver/stable_baselines/gnn/common/buffers.py
@@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Optional, TypeVar, Union
from typing import Any, Dict, List, Optional, TypeVar, Union

import numpy as np
import torch as th
@@ -10,9 +10,18 @@
MaskableRolloutBuffer,
MaskableRolloutBufferSamples,
)
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
from stable_baselines3.common.buffers import (
BaseBuffer,
DictReplayBuffer,
DictRolloutBuffer,
ReplayBuffer,
RolloutBuffer,
)
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.type_aliases import RolloutBufferSamples
from stable_baselines3.common.type_aliases import (
ReplayBufferSamples,
RolloutBufferSamples,
)
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import VecNormalize

@@ -51,26 +60,13 @@ def get_obs_shape(
)


class GraphRolloutBuffer(RolloutBuffer):
"""Rollout buffer used in on-policy algorithms like A2C/PPO with graph observations.
Handles cases where observation space is:
- a Graph space
- a Dict space whose subspaces include a Graph space
"""

observations: Union[list[spaces.GraphInstance], list[list[spaces.GraphInstance]]]
tensor_names = ["actions", "values", "log_probs", "advantages", "returns"]

class GraphBaseBuffer(BaseBuffer):
def __init__(
self,
buffer_size: int,
observation_space: Union[spaces.Graph, spaces.Dict],
action_space: spaces.Space,
device: Union[th.device, str] = "auto",
gae_lambda: float = 1,
gamma: float = 0.99,
n_envs: int = 1,
):
self.buffer_size = buffer_size
@@ -82,11 +78,29 @@ def __init__(
self.full = False
self.device = get_device(device)
self.n_envs = n_envs
self.gae_lambda = gae_lambda
self.gamma = gamma
self.generator_ready = False

self.reset()
def _graphlist_to_torch(
self, graph_list: list[spaces.GraphInstance], batch_inds: np.ndarray
) -> thg.data.Data:
return thg.data.Batch.from_data_list(
[
graph_obs_to_thg_data(graph_list[idx], device=self.device)
for idx in batch_inds
]
)


class GraphRolloutBuffer(RolloutBuffer, GraphBaseBuffer):
"""Rollout buffer used in on-policy algorithms like A2C/PPO with graph observations.
Handles cases where observation space is:
- a Graph space
- a Dict space whose subspaces include a Graph space
"""

observations: Union[list[spaces.GraphInstance], list[list[spaces.GraphInstance]]]
tensor_names = ["actions", "values", "log_probs", "advantages", "returns"]

def reset(self) -> None:
assert isinstance(
@@ -165,16 +179,6 @@ def _get_samples(
def _get_observations_samples(self, batch_inds: np.ndarray) -> thg.data.Data:
return self._graphlist_to_torch(self.observations, batch_inds=batch_inds)

def _graphlist_to_torch(
self, graph_list: list[spaces.GraphInstance], batch_inds: np.ndarray
) -> thg.data.Data:
return thg.data.Batch.from_data_list(
[
graph_obs_to_thg_data(graph_list[idx], device=self.device)
for idx in batch_inds
]
)


class DictGraphRolloutBuffer(GraphRolloutBuffer, DictRolloutBuffer):

@@ -185,7 +189,6 @@ class DictGraphRolloutBuffer(GraphRolloutBuffer, DictRolloutBuffer):
np.ndarray,
],
]
obs_shape: dict[str, tuple[int, ...]]

def __init__(
self,
@@ -295,6 +298,205 @@ class MaskableDictGraphRolloutBuffer(
...


class GraphReplayBuffer(ReplayBuffer, GraphBaseBuffer):
observations: list[spaces.GraphInstance]
next_observations: list[spaces.GraphInstance]

def __init__(
self,
buffer_size: int,
observation_space: Union[spaces.Graph, spaces.Dict],
action_space: spaces.Space,
device: Union[th.device, str] = "auto",
n_envs: int = 1,
optimize_memory_usage: bool = False,
handle_timeout_termination: bool = True,
):
super().__init__(
buffer_size=buffer_size,
observation_space=observation_space,
action_space=action_space,
device=device,
n_envs=n_envs,
optimize_memory_usage=optimize_memory_usage,
handle_timeout_termination=handle_timeout_termination,
)
if optimize_memory_usage:
raise NotImplementedError(
"No memory usage optimization implemented for GraphReplayBuffer."
)
if n_envs > 1:
raise NotImplementedError(
"No multiple vectorized environements implemented for GraphReplayBuffer."
)

self._init_observations()

def _init_observations(self):
self.observations = list()
self.next_observations = list()

def add(
self,
obs: list[spaces.GraphInstance],
next_obs: list[spaces.GraphInstance],
action: np.ndarray,
reward: np.ndarray,
done: np.ndarray,
infos: List[Dict[str, Any]],
) -> None:

self._add_obs(obs=obs, next_obs=next_obs)

# Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
action = action.reshape((self.n_envs, self.action_dim))

self.actions[self.pos] = np.array(action)
self.rewards[self.pos] = np.array(reward)
self.dones[self.pos] = np.array(done)

if self.handle_timeout_termination:
self.timeouts[self.pos] = np.array(
[info.get("TimeLimit.truncated", False) for info in infos]
)

self.pos += 1
if self.pos == self.buffer_size:
self.full = True
self.pos = 0

def _add_obs(
self, obs: list[spaces.GraphInstance], next_obs: list[spaces.GraphInstance]
) -> None:
self.observations.append(obs)
self.next_observations.append(next_obs)

def _get_observations_samples(
self, observations: list[spaces.GraphInstance], batch_inds: np.ndarray
) -> thg.data.Data:
return self._graphlist_to_torch(observations, batch_inds=batch_inds)

def _get_samples(
self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None
) -> ReplayBufferSamples:
if env is not None:
raise NotImplementedError(
"observation normalization not yet implemented for graphReplayBuffer."
)
env_indices = 0 # single env
return ReplayBufferSamples(
observations=self._get_observations_samples(
self.observations, batch_inds=batch_inds
),
actions=self.to_torch(self.actions[batch_inds, env_indices, :]),
next_observations=self._get_observations_samples(
self.next_observations, batch_inds=batch_inds
),
# Only use dones that are not due to timeouts
# deactivated by default (timeouts is initialized as an array of False)
dones=self.to_torch(
(
self.dones[batch_inds, env_indices]
* (1 - self.timeouts[batch_inds, env_indices])
).reshape(-1, 1)
),
rewards=self.to_torch(
self._normalize_reward(
self.rewards[batch_inds, env_indices].reshape(-1, 1), env
)
),
)


class DictGraphReplayBuffer(GraphReplayBuffer, DictReplayBuffer):
observations: dict[
str,
Union[
list[spaces.GraphInstance],
np.ndarray,
],
]
next_observations: dict[
str,
Union[
list[spaces.GraphInstance],
np.ndarray,
],
]

def __init__(
self,
buffer_size: int,
observation_space: spaces.Dict,
action_space: spaces.Space,
device: Union[th.device, str] = "auto",
n_envs: int = 1,
optimize_memory_usage: bool = False,
handle_timeout_termination: bool = True,
):
self.is_observation_subspace_graph: dict[str, bool] = {
k: isinstance(space, spaces.Graph)
for k, space in observation_space.spaces.items()
}
super().__init__(
buffer_size,
observation_space,
action_space,
device,
n_envs,
optimize_memory_usage,
handle_timeout_termination,
)

def _init_observations(self):
for k, is_graph in self.is_observation_subspace_graph.items():
if is_graph:
self.observations[k] = list()
self.next_observations[k] = list()

def _add_obs(
self,
obs: dict[str, Union[np.ndarray, list[spaces.GraphInstance]]],
next_obs: dict[str, Union[np.ndarray, list[spaces.GraphInstance]]],
) -> None:
for key in self.observations.keys():
if self.is_observation_subspace_graph[key]:
self.observations[key].append(
[copy_graph_instance(g) for g in obs[key]]
)
self.next_observations[key].append(
[copy_graph_instance(g) for g in next_obs[key]]
)
else:
obs_ = np.array(obs[key])
next_obs_ = np.array(next_obs[key])
# Reshape needed when using multiple envs with discrete observations
# as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
if isinstance(self.observation_space.spaces[key], spaces.Discrete):
obs_ = obs_.reshape((self.n_envs,) + self.obs_shape[key])
next_obs_ = next_obs_.reshape((self.n_envs,) + self.obs_shape[key])
self.observations[key][self.pos] = obs_
self.next_observations[key][self.pos] = next_obs_

def _get_observations_samples(
self,
observations: dict[
str,
Union[
list[spaces.GraphInstance],
np.ndarray,
],
],
batch_inds: np.ndarray,
) -> dict[str, Union[thg.data.Data, th.Tensor]]:
return {
k: self._graphlist_to_torch(obs, batch_inds=batch_inds)
if self.is_observation_subspace_graph[k]
else self.to_torch(obs[batch_inds])
for k, obs in observations.items()
}


T = TypeVar("T")


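For reference, the _graphlist_to_torch helper above batches gymnasium GraphInstance observations into a single torch_geometric Batch. Below is a self-contained sketch of an equivalent conversion; the actual graph_obs_to_thg_data helper is defined elsewhere in this package and is not shown in this diff, so the field mapping here is an assumption based on how it is used above.

import torch as th
from gymnasium import spaces
from torch_geometric.data import Batch, Data

def graph_instance_to_data(g: spaces.GraphInstance) -> Data:
    # Node features -> x, edge links -> edge_index (2 x num_edges), edge features -> edge_attr.
    return Data(
        x=th.as_tensor(g.nodes, dtype=th.float32),
        edge_index=th.as_tensor(g.edge_links, dtype=th.long).t().contiguous(),
        edge_attr=th.as_tensor(g.edges, dtype=th.float32),
    )

space = spaces.Graph(
    node_space=spaces.Box(low=-1.0, high=1.0, shape=(3,)),
    edge_space=spaces.Box(low=-1.0, high=1.0, shape=(1,)),
)
graphs = [space.sample(num_nodes=4) for _ in range(2)]
# Same kind of result as GraphBaseBuffer._graphlist_to_torch: one Batch for the whole minibatch.
batch = Batch.from_data_list([graph_instance_to_data(g) for g in graphs])

The remaining hunk comes from a new file and defines GraphOffPolicyAlgorithm, the off-policy base class needed for the graph DQN.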
@@ -0,0 +1,44 @@
from typing import Optional, Union

from gymnasium import spaces
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv

from .buffers import DictGraphReplayBuffer, GraphReplayBuffer
from .vec_env.dummy_vec_env import wrap_graph_env


class GraphOffPolicyAlgorithm(OffPolicyAlgorithm):
"""Base class for On-Policy algorithms (ex: SAC/TD3) with graph observations."""

def __init__(
self,
policy: Union[str, type[ActorCriticPolicy]],
env: GymEnv,
replay_buffer_class: Optional[type[ReplayBuffer]] = None,
**kwargs,
):

# Use proper default replay buffer class
if replay_buffer_class is None:
if isinstance(env.observation_space, spaces.Graph):
replay_buffer_class = GraphReplayBuffer
elif isinstance(env.observation_space, spaces.Dict):
replay_buffer_class = DictGraphReplayBuffer

# Use proper VecEnv wrapper for env with Graph spaces
env = wrap_graph_env(env)
if env.num_envs > 1:
raise NotImplementedError(
"GraphOnPolicyAlgorithm not implemented for real vectorized environment "
"(ie. with more than 1 wrapped environment)"
)

super().__init__(
policy=policy,
env=env,
replay_buffer_class=replay_buffer_class,
**kwargs,
)
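
The graph DQN class itself is in one of the other changed files and is not shown in this excerpt. As a purely hypothetical sketch (class name, module path and mixin order are assumptions, not code from this commit), it could be assembled from GraphOffPolicyAlgorithm and stable-baselines3's DQN as follows:

from stable_baselines3 import DQN

# Assumed module path: the new file's name is not shown in this diff.
from skdecide.hub.solver.stable_baselines.gnn.common.off_policy_algorithm import (
    GraphOffPolicyAlgorithm,
)


class GraphDQN(GraphOffPolicyAlgorithm, DQN):
    """Hypothetical DQN variant accepting Graph (or Dict-with-Graph) observation spaces."""

    # With this MRO, GraphOffPolicyAlgorithm.__init__ runs first: it picks the
    # graph replay buffer and wraps the env before DQN/OffPolicyAlgorithm set up
    # the rest of the training machinery.

Putting the graph mixin before DQN mirrors how GraphReplayBuffer itself combines ReplayBuffer with GraphBaseBuffer through cooperative multiple inheritance.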