Add graph-version of sb3 DQN algorithm #458

Merged
merged 1 commit into from
Jan 17, 2025
Merged
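This PR extends the package's graph-observation support from on-policy to off-policy algorithms: the graph-to-tensor conversion shared by all buffers moves into a new `GraphBaseBuffer`, `GraphReplayBuffer` and `DictGraphReplayBuffer` are added alongside the existing rollout buffers, and a new `GraphOffPolicyAlgorithm` base class wires them into stable-baselines3, paving the way for a graph version of DQN.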
266 changes: 234 additions & 32 deletions skdecide/hub/solver/stable_baselines/gnn/common/buffers.py
@@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Optional, TypeVar, Union
from typing import Any, Dict, List, Optional, TypeVar, Union

import numpy as np
import torch as th
@@ -10,9 +10,18 @@
MaskableRolloutBuffer,
MaskableRolloutBufferSamples,
)
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
from stable_baselines3.common.buffers import (
BaseBuffer,
DictReplayBuffer,
DictRolloutBuffer,
ReplayBuffer,
RolloutBuffer,
)
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.type_aliases import RolloutBufferSamples
from stable_baselines3.common.type_aliases import (
ReplayBufferSamples,
RolloutBufferSamples,
)
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import VecNormalize
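
For context, the `graph_obs_to_thg_data` helper used by the buffers below converts a gymnasium `GraphInstance` into a torch_geometric `Data` object. A minimal sketch of such a conversion, assuming the standard `nodes` / `edge_links` / `edges` fields (the actual helper is defined elsewhere in this package and may differ):

```python
# Hedged sketch of a GraphInstance -> torch_geometric Data conversion;
# the real graph_obs_to_thg_data lives elsewhere in this package.
import numpy as np
import torch as th
import torch_geometric as thg
from gymnasium import spaces


def graph_obs_to_thg_data_sketch(
    obs: spaces.GraphInstance, device: th.device
) -> thg.data.Data:
    # Node features: (n_nodes, node_feature_dim)
    x = th.as_tensor(np.asarray(obs.nodes), dtype=th.float32)
    # gymnasium stores connectivity as (n_edges, 2); torch_geometric expects
    # edge_index with shape (2, n_edges).
    edge_index = (
        th.as_tensor(np.asarray(obs.edge_links), dtype=th.long).t().contiguous()
    )
    # Optional edge features: (n_edges, edge_feature_dim)
    edge_attr = (
        None
        if obs.edges is None
        else th.as_tensor(np.asarray(obs.edges), dtype=th.float32)
    )
    return thg.data.Data(x=x, edge_index=edge_index, edge_attr=edge_attr).to(device)
```

`thg.data.Batch.from_data_list`, used by the new `GraphBaseBuffer._graphlist_to_torch` below, then collates a list of such `Data` objects into a single batch that torch_geometric treats as one large disconnected graph.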

@@ -51,26 +60,13 @@ def get_obs_shape(
)


class GraphRolloutBuffer(RolloutBuffer):
"""Rollout buffer used in on-policy algorithms like A2C/PPO with graph observations.

Handles cases where observation space is:
- a Graph space
- a Dict space whose subspaces includes a Graph space

"""

observations: Union[list[spaces.GraphInstance], list[list[spaces.GraphInstance]]]
tensor_names = ["actions", "values", "log_probs", "advantages", "returns"]

class GraphBaseBuffer(BaseBuffer):
def __init__(
self,
buffer_size: int,
observation_space: Union[spaces.Graph, spaces.Dict],
action_space: spaces.Space,
device: Union[th.device, str] = "auto",
gae_lambda: float = 1,
gamma: float = 0.99,
n_envs: int = 1,
):
self.buffer_size = buffer_size
@@ -82,11 +78,29 @@ def __init__(
self.full = False
self.device = get_device(device)
self.n_envs = n_envs
self.gae_lambda = gae_lambda
self.gamma = gamma
self.generator_ready = False

self.reset()
def _graphlist_to_torch(
self, graph_list: list[spaces.GraphInstance], batch_inds: np.ndarray
) -> thg.data.Data:
return thg.data.Batch.from_data_list(
[
graph_obs_to_thg_data(graph_list[idx], device=self.device)
for idx in batch_inds
]
)


class GraphRolloutBuffer(RolloutBuffer, GraphBaseBuffer):
"""Rollout buffer used in on-policy algorithms like A2C/PPO with graph observations.

Handles cases where observation space is:
- a Graph space
- a Dict space whose subspaces includes a Graph space

"""

observations: Union[list[spaces.GraphInstance], list[list[spaces.GraphInstance]]]
tensor_names = ["actions", "values", "log_probs", "advantages", "returns"]

def reset(self) -> None:
assert isinstance(
Expand Down Expand Up @@ -165,16 +179,6 @@ def _get_samples(
def _get_observations_samples(self, batch_inds: np.ndarray) -> thg.data.Data:
return self._graphlist_to_torch(self.observations, batch_inds=batch_inds)

def _graphlist_to_torch(
self, graph_list: list[spaces.GraphInstance], batch_inds: np.ndarray
) -> thg.data.Data:
return thg.data.Batch.from_data_list(
[
graph_obs_to_thg_data(graph_list[idx], device=self.device)
for idx in batch_inds
]
)


class DictGraphRolloutBuffer(GraphRolloutBuffer, DictRolloutBuffer):

@@ -185,7 +189,6 @@ class DictGraphRolloutBuffer(GraphRolloutBuffer, DictRolloutBuffer):
np.ndarray,
],
]
obs_shape: dict[str, tuple[int, ...]]

def __init__(
self,
@@ -295,6 +298,205 @@ class MaskableDictGraphRolloutBuffer(
...


class GraphReplayBuffer(ReplayBuffer, GraphBaseBuffer):
observations: list[spaces.GraphInstance]
next_observations: list[spaces.GraphInstance]

def __init__(
self,
buffer_size: int,
observation_space: Union[spaces.Graph, spaces.Dict],
action_space: spaces.Space,
device: Union[th.device, str] = "auto",
n_envs: int = 1,
optimize_memory_usage: bool = False,
handle_timeout_termination: bool = True,
):
super().__init__(
buffer_size=buffer_size,
observation_space=observation_space,
action_space=action_space,
device=device,
n_envs=n_envs,
optimize_memory_usage=optimize_memory_usage,
handle_timeout_termination=handle_timeout_termination,
)
if optimize_memory_usage:
raise NotImplementedError(
"No memory usage optimization implemented for GraphReplayBuffer."
)
if n_envs > 1:
raise NotImplementedError(
"No multiple vectorized environements implemented for GraphReplayBuffer."
)

self._init_observations()

def _init_observations(self):
self.observations = list()
self.next_observations = list()

def add(
self,
obs: list[spaces.GraphInstance],
next_obs: list[spaces.GraphInstance],
action: np.ndarray,
reward: np.ndarray,
done: np.ndarray,
infos: List[Dict[str, Any]],
) -> None:

self._add_obs(obs=obs, next_obs=next_obs)

# Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
action = action.reshape((self.n_envs, self.action_dim))

self.actions[self.pos] = np.array(action)
self.rewards[self.pos] = np.array(reward)
self.dones[self.pos] = np.array(done)

if self.handle_timeout_termination:
self.timeouts[self.pos] = np.array(
[info.get("TimeLimit.truncated", False) for info in infos]
)

self.pos += 1
if self.pos == self.buffer_size:
self.full = True
self.pos = 0

def _add_obs(
self, obs: list[spaces.GraphInstance], next_obs: list[spaces.GraphInstance]
) -> None:
self.observations.append(obs)
self.next_observations.append(next_obs)

def _get_observations_samples(
self, observations: list[spaces.GraphInstance], batch_inds: np.ndarray
) -> thg.data.Data:
return self._graphlist_to_torch(observations, batch_inds=batch_inds)

def _get_samples(
self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None
) -> ReplayBufferSamples:
if env is not None:
raise NotImplementedError(
"observation normalization not yet implemented for graphReplayBuffer."
)
env_indices = 0 # single env
return ReplayBufferSamples(
observations=self._get_observations_samples(
self.observations, batch_inds=batch_inds
),
actions=self.to_torch(self.actions[batch_inds, env_indices, :]),
next_observations=self._get_observations_samples(
self.next_observations, batch_inds=batch_inds
),
# Only use dones that are not due to timeouts
# deactivated by default (timeouts is initialized as an array of False)
dones=self.to_torch(
(
self.dones[batch_inds, env_indices]
* (1 - self.timeouts[batch_inds, env_indices])
).reshape(-1, 1)
),
rewards=self.to_torch(
self._normalize_reward(
self.rewards[batch_inds, env_indices].reshape(-1, 1), env
)
),
)


class DictGraphReplayBuffer(GraphReplayBuffer, DictReplayBuffer):
observations: dict[
str,
Union[
list[spaces.GraphInstance],
np.ndarray,
],
]
next_observations: dict[
str,
Union[
list[spaces.GraphInstance],
np.ndarray,
],
]

def __init__(
self,
buffer_size: int,
observation_space: spaces.Dict,
action_space: spaces.Space,
device: Union[th.device, str] = "auto",
n_envs: int = 1,
optimize_memory_usage: bool = False,
handle_timeout_termination: bool = True,
):
self.is_observation_subspace_graph: dict[str, bool] = {
k: isinstance(space, spaces.Graph)
for k, space in observation_space.spaces.items()
}
super().__init__(
buffer_size,
observation_space,
action_space,
device,
n_envs,
optimize_memory_usage,
handle_timeout_termination,
)

def _init_observations(self):
for k, is_graph in self.is_observation_subspace_graph.items():
if is_graph:
self.observations[k] = list()
self.next_observations[k] = list()

def _add_obs(
self,
obs: dict[str, Union[np.ndarray, list[spaces.GraphInstance]]],
next_obs: dict[str, Union[np.ndarray, list[spaces.GraphInstance]]],
) -> None:
for key in self.observations.keys():
if self.is_observation_subspace_graph[key]:
self.observations[key].append(
[copy_graph_instance(g) for g in obs[key]]
)
self.next_observations[key].append(
[copy_graph_instance(g) for g in next_obs[key]]
)
else:
obs_ = np.array(obs[key])
next_obs_ = np.array(next_obs[key])
# Reshape needed when using multiple envs with discrete observations
# as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
if isinstance(self.observation_space.spaces[key], spaces.Discrete):
obs_ = obs_.reshape((self.n_envs,) + self.obs_shape[key])
next_obs_ = next_obs_.reshape((self.n_envs,) + self.obs_shape[key])
self.observations[key][self.pos] = obs_
self.next_observations[key][self.pos] = next_obs_

def _get_observations_samples(
self,
observations: dict[
str,
Union[
list[spaces.GraphInstance],
np.ndarray,
],
],
batch_inds: np.ndarray,
) -> dict[str, Union[thg.data.Data, th.Tensor]]:
return {
k: self._graphlist_to_torch(obs, batch_inds=batch_inds)
if self.is_observation_subspace_graph[k]
else self.to_torch(obs[batch_inds])
for k, obs in observations.items()
}
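
Putting the pieces together, a minimal usage sketch for `GraphReplayBuffer` (the spaces and values below are illustrative assumptions, not part of this PR):

```python
# Hedged usage sketch for GraphReplayBuffer; spaces and values are assumptions.
import numpy as np
from gymnasium import spaces

from skdecide.hub.solver.stable_baselines.gnn.common.buffers import GraphReplayBuffer

obs_space = spaces.Graph(
    node_space=spaces.Box(low=-1.0, high=1.0, shape=(3,)),
    edge_space=spaces.Box(low=-1.0, high=1.0, shape=(1,)),
)
action_space = spaces.Discrete(4)

buffer = GraphReplayBuffer(
    buffer_size=1_000,
    observation_space=obs_space,
    action_space=action_space,
    n_envs=1,  # n_envs > 1 raises NotImplementedError (see above)
)

# One transition from a single (non-vectorized) env: obs and next_obs are
# lists of GraphInstance, one entry per env.
buffer.add(
    obs=[obs_space.sample(num_nodes=5)],
    next_obs=[obs_space.sample(num_nodes=5)],
    action=np.array([action_space.sample()]),
    reward=np.array([0.0]),
    done=np.array([False]),
    infos=[{}],
)

batch = buffer.sample(batch_size=1)
# batch.observations and batch.next_observations are thg.data.Batch objects,
# ready to feed a GNN feature extractor.
```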


T = TypeVar("T")


@@ -0,0 +1,44 @@
from typing import Optional, Union

from gymnasium import spaces
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv

from .buffers import DictGraphReplayBuffer, GraphReplayBuffer
from .vec_env.dummy_vec_env import wrap_graph_env


class GraphOffPolicyAlgorithm(OffPolicyAlgorithm):
"""Base class for On-Policy algorithms (ex: SAC/TD3) with graph observations."""

def __init__(
self,
policy: Union[str, type[ActorCriticPolicy]],
env: GymEnv,
replay_buffer_class: Optional[type[ReplayBuffer]] = None,
**kwargs,
):

# Use proper default replay buffer class
if replay_buffer_class is None:
if isinstance(env.observation_space, spaces.Graph):
replay_buffer_class = GraphReplayBuffer
elif isinstance(env.observation_space, spaces.Dict):
replay_buffer_class = DictGraphReplayBuffer

# Use proper VecEnv wrapper for env with Graph spaces
env = wrap_graph_env(env)
if env.num_envs > 1:
raise NotImplementedError(
"GraphOnPolicyAlgorithm not implemented for real vectorized environment "
"(ie. with more than 1 wrapped environment)"
)

super().__init__(
policy=policy,
env=env,
replay_buffer_class=replay_buffer_class,
**kwargs,
)
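
With this base class, the graph version of DQN announced in the PR title can presumably be obtained by composition with the stock algorithm. A hedged sketch (the class name and exact composition are assumptions; the subclass actually added by this PR may differ):

```python
# Hedged sketch: combining GraphOffPolicyAlgorithm with sb3's DQN.
# GraphDQN as written here is an assumption, not necessarily the PR's code.
from stable_baselines3 import DQN


class GraphDQN(GraphOffPolicyAlgorithm, DQN):
    """DQN variant accepting Graph (or Dict-with-Graph) observation spaces.

    With this MRO, GraphOffPolicyAlgorithm.__init__ runs first (choosing
    GraphReplayBuffer or DictGraphReplayBuffer and wrapping the env), then
    delegates the remaining setup to DQN via super().
    """
```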