-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Put in common gnn features extractor between ray.rllib and sb3 wrappers
The GNN features extractor used by rallib and sb3 are the same (except the one for sb3 derives from a BaseFeaturesExtractor) and so is the function conveting GraphInstance into torch_geometric.data.Data. We put that in common in hub/solver/utils/gnn. We also remove unused code in ray.rllib gnn code.
- Loading branch information
Showing
10 changed files
with
50 additions
and
251 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
116 changes: 0 additions & 116 deletions
116
skdecide/hub/solver/ray_rllib/gnn/policy/sample_batch.py
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from typing import Optional | ||
|
||
import gymnasium as gym | ||
import torch as th | ||
import torch_geometric as thg | ||
|
||
|
||
def graph_obs_to_thg_data( | ||
obs: gym.spaces.GraphInstance, | ||
device: Optional[th.device] = None, | ||
pin_memory: bool = False, | ||
) -> thg.data.Data: | ||
# Node features | ||
flatten_node_features = obs.nodes.reshape((len(obs.nodes), -1)) | ||
x = th.tensor(flatten_node_features).float() | ||
# Edge features | ||
if obs.edges is None: | ||
edge_attr = None | ||
else: | ||
flatten_edge_features = obs.edges.reshape((len(obs.edges), -1)) | ||
edge_attr = th.tensor(flatten_edge_features).float() | ||
edge_index = th.tensor(obs.edge_links, dtype=th.long).t().contiguous().view(2, -1) | ||
# thg.Data | ||
data = thg.data.Data(x=x, edge_index=edge_index, edge_attr=edge_attr) | ||
# Pin the tensor's memory (for faster transfer to GPU later). | ||
if pin_memory and th.cuda.is_available(): | ||
data.pin_memory() | ||
|
||
return data if device is None else data.to(device) |