Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: Add Graphs as States for torchgfn #210

Draft
wants to merge 31 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
9ae28b2
including Graphs as States for torchgfn
alip67 Nov 6, 2024
de6ab1c
add GraphEnv
younik Nov 7, 2024
24e23e8
add deps and reformat
younik Nov 7, 2024
1f7b220
add test, fix errors, add valid action check
younik Nov 8, 2024
63e4f1c
fix formatting
younik Nov 8, 2024
8034fb2
add GraphAction
younik Nov 14, 2024
d179671
fix batching mechanism
younik Nov 14, 2024
e018f4e
Merge branch 'GFNOrg:master' into graph-states
alip67 Nov 15, 2024
7ff96d5
add support for EXIT action
younik Nov 16, 2024
cf482da
Merge branch 'graph-states' of https://github.com/alip67/torchgfn int…
younik Nov 16, 2024
dacbbf7
add GraphActionPolicyEstimator
younik Nov 19, 2024
98ea448
Merge branch 'GFNOrg:master' into graph-states
alip67 Nov 19, 2024
e74e500
Sampler integration work
younik Nov 22, 2024
a862bb4
Merge branch 'graph-states' of https://github.com/alip67/torchgfn int…
younik Nov 22, 2024
5e64c84
use TensorDict
younik Nov 26, 2024
81f8b71
solve some errors
younik Nov 28, 2024
34781ef
use tensordict in actions
younik Nov 28, 2024
3e584f2
handle sf
younik Dec 2, 2024
d5e438f
remove Data
younik Dec 3, 2024
fba5d50
categorical action type
younik Dec 6, 2024
478bd14
change batching
younik Dec 10, 2024
dd80f28
fix stacking
younik Dec 11, 2024
616551c
fix graph stacking
younik Dec 11, 2024
77611d4
fix test graph env
younik Dec 12, 2024
5874ff6
add ring example
younik Dec 19, 2024
9d42332
remove check edge_features
younik Dec 20, 2024
2d44242
fix GraphStates set
younik Dec 20, 2024
173d4fb
remove debug
younik Dec 20, 2024
7265857
fix add_edge action
younik Dec 20, 2024
2b3208f
fix edge_index after get
younik Dec 20, 2024
b84246f
push updated code
younik Dec 22, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ einops = ">=0.6.1"
numpy = ">=1.21.2"
python = "^3.10"
torch = ">=1.9.0"
torch_geometric = ">=2.6.0"

# dev dependencies.
black = { version = "24.3", optional = true }
Expand Down
107 changes: 106 additions & 1 deletion src/gfn/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
from typing import Optional, Tuple, Union

import torch
from torch_geometric.data import Batch, Data

from gfn.actions import Actions
from gfn.preprocessors import IdentityPreprocessor, Preprocessor
from gfn.states import DiscreteStates, States
from gfn.states import DiscreteStates, GraphStates, States
from gfn.utils.common import set_seed

# Errors
Expand Down Expand Up @@ -559,3 +560,107 @@ def terminating_states(self) -> DiscreteStates:
raise NotImplementedError(
"The environment does not support enumeration of states"
)


class GraphEnv(Env):
"""Base class for graph-based environments."""

def __init__(
self,
s0: Data,
node_feature_dim: int,
edge_feature_dim: int,
action_shape: Tuple,
dummy_action: torch.Tensor,
exit_action: torch.Tensor,
sf: Optional[Data] = None,
device_str: Optional[str] = None,
preprocessor: Optional[Preprocessor] = None,
):
"""Initializes a graph-based environment.

Args:
s0: The initial graph state.
node_feature_dim: The dimension of the node features.
edge_feature_dim: The dimension of the edge features.
action_shape: Tuple representing the shape of the actions.
dummy_action: Tensor of shape "action_shape" representing a dummy action.
exit_action: Tensor of shape "action_shape" representing the exit action.
sf: The final graph state.
device_str: 'cpu' or 'cuda'. Defaults to None, in which case the device is
inferred from s0.
preprocessor: a Preprocessor object that converts raw graph states to a tensor
that can be fed into a neural network. Defaults to None, in which case
the IdentityPreprocessor is used.
"""
self.device = get_device(device_str, default_device=s0.x.device)

if sf is None:
sf = Data(
x=torch.full((s0.num_nodes, node_feature_dim), -float("inf")).to(
self.device
),
edge_attr=torch.full(
(s0.num_edges, edge_feature_dim), -float("inf")
).to(self.device),
edge_index=s0.edge_index,
batch=torch.zeros(s0.num_nodes, dtype=torch.long, device=self.device),
)

super().__init__(
s0=s0,
state_shape=(s0.num_nodes, node_feature_dim),
action_shape=action_shape,
dummy_action=dummy_action,
exit_action=exit_action,
sf=sf,
device_str=device_str,
preprocessor=preprocessor,
)

self.node_feature_dim = node_feature_dim
self.edge_feature_dim = edge_feature_dim
self.GraphStates = self.make_graph_states_class()

def make_graph_states_class(self) -> type[GraphStates]:
env = self

class GraphEnvStates(GraphStates):
s0 = env.s0
sf = env.sf
node_feature_dim = env.node_feature_dim
edge_feature_dim = env.edge_feature_dim
make_random_states_graph = env.make_random_states_graph

return GraphEnvStates

def states_from_tensor(self, tensor: Batch) -> GraphStates:
"""Wraps the supplied Batch in a GraphStates instance."""
return self.GraphStates(tensor)

def states_from_batch_shape(self, batch_shape: int) -> GraphStates:
"""Returns a batch of s0 states with a given batch_shape."""
return self.GraphStates.from_batch_shape(batch_shape)

@abstractmethod
def step(self, states: GraphStates, actions: Actions) -> GraphStates:
"""Function that takes a batch of graph states and actions and returns a batch of next
graph states."""

@abstractmethod
def backward_step(self, states: GraphStates, actions: Actions) -> GraphStates:
"""Function that takes a batch of graph states and actions and returns a batch of previous
graph states."""

@abstractmethod
def is_action_valid(
self,
states: GraphStates,
actions: Actions,
backward: bool = False,
) -> bool:
"""Returns True if the actions are valid in the given graph states."""

@abstractmethod
def make_random_states_graph(self, batch_shape: int) -> Batch:
"""Optional method inherited by all GraphStates instances to emit a random Batch of graphs."""
118 changes: 118 additions & 0 deletions src/gfn/gym/graph_building.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from copy import copy
from typing import Callable, Literal

import torch
from torch_geometric.data import Batch, Data
from torch_geometric.nn import GCNConv

from gfn.actions import Actions
from gfn.env import GraphEnv
from gfn.states import GraphStates


class GraphBuilding(GraphEnv):
def __init__(
self,
num_nodes: int,
node_feature_dim: int,
edge_feature_dim: int,
state_evaluator: Callable[[Batch], torch.Tensor] | None = None,
device_str: Literal["cpu", "cuda"] = "cpu",
):
s0 = Data(x=torch.zeros((num_nodes, node_feature_dim)).to(device_str))
exit_action = torch.tensor(
[-float("inf"), -float("inf")], device=torch.device(device_str)
)
dummy_action = torch.tensor(
[float("inf"), float("inf")], device=torch.device(device_str)
)
if state_evaluator is None:
state_evaluator = GCNConvEvaluator(node_feature_dim)
self.state_evaluator = state_evaluator

super().__init__(
s0=s0,
node_feature_dim=node_feature_dim,
edge_feature_dim=edge_feature_dim,
action_shape=(2,),
dummy_action=dummy_action,
exit_action=exit_action,
device_str=device_str,
)

def step(self, states: GraphStates, actions: Actions) -> GraphStates:
"""Step function for the GraphBuilding environment.

Args:
states: GraphStates object representing the current graph.
actions: Actions indicating which edge to add.

Returns the next graph the new GraphStates.
"""
graphs: Batch = copy.deepcopy(states.data)
assert len(graphs) == len(actions)

for i, act in enumerate(actions.tensor):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note to self - we should evaluate if this can be vectorized.

edge_index = torch.cat([graphs[i].edge_index, act.unsqueeze(1)], dim=1)
graphs[i].edge_index = edge_index

return GraphStates(graphs)

def backward_step(self, states: GraphStates, actions: Actions) -> GraphStates:
"""Backward step function for the GraphBuilding environment.

Args:
states: GraphStates object representing the current graph.
actions: Actions indicating which edge to remove.

Returns the previous graph as a new GraphStates.
"""
graphs: Batch = copy.deepcopy(states.data)
assert len(graphs) == len(actions)

for i, act in enumerate(actions.tensor):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note to self - we should evaluate if this can be vectorized.

edge_index = graphs[i].edge_index
edge_index = edge_index[:, edge_index[1] != act]
graphs[i].edge_index = edge_index

return GraphStates(graphs)

def is_action_valid(
self, states: GraphStates, actions: Actions, backward: bool = False
) -> bool:
for i, act in enumerate(actions.tensor):
if backward and len(states.data[i].edge_index[1]) == 0:
return False
if not backward and torch.any(states.data[i].edge_index[1] == act):
return False
return True

def reward(self, final_states: GraphStates) -> torch.Tensor:
"""The environment's reward given a state.
This or log_reward must be implemented.

Args:
final_states: A batch of final states.

Returns:
torch.Tensor: Tensor of shape "batch_shape" containing the rewards.
"""
return self.state_evaluator(final_states.data).sum(dim=1)

@property
def log_partition(self) -> float:
"Returns the logarithm of the partition function."
raise NotImplementedError

@property
def true_dist_pmf(self) -> torch.Tensor:
"Returns a one-dimensional tensor representing the true distribution."
raise NotImplementedError


class GCNConvEvaluator:
def __init__(self, num_features):
self.net = GCNConv(num_features, 1)

def __call__(self, batch: Batch) -> torch.Tensor:
return self.net(batch.x, batch.edge_index)
141 changes: 141 additions & 0 deletions src/gfn/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Callable, ClassVar, List, Optional, Sequence

import torch
from torch_geometric.data import Batch, Data


class States(ABC):
Expand Down Expand Up @@ -501,3 +502,143 @@ def stack_states(states: List[States]):
) + state_example.batch_shape

return stacked_states


class GraphStates(ABC):
"""
Base class for Graph as a state representation. The `GraphStates` object is a batched collection of
multiple graph objects. The `Batch` object from PyTorch Geometric is used to represent the batch of
graph objects as states.
"""

s0: ClassVar[Data]
sf: ClassVar[Data]
node_feature_dim: ClassVar[int]
edge_feature_dim: ClassVar[int]
make_random_states_graph: Callable = lambda x: (_ for _ in ()).throw(
NotImplementedError(
"The environment does not support initialization of random Graph states."
)
)

def __init__(self, graphs: Batch):
self.data: Batch = graphs
self.batch_shape: int = self.data.num_graphs
self._log_rewards: float = None

@classmethod
def from_batch_shape(
cls, batch_shape: int, random: bool = False, sink: bool = False
) -> GraphStates:
if random and sink:
raise ValueError("Only one of `random` and `sink` should be True.")
if random:
data = cls.make_random_states_graph(batch_shape)
elif sink:
data = cls.make_sink_states_graph(batch_shape)
else:
data = cls.make_initial_states_graph(batch_shape)
return cls(data)

@classmethod
def make_initial_states_graph(cls, batch_shape: int) -> Batch:
data = Batch.from_data_list([cls.s0 for _ in range(batch_shape)])
return data

@classmethod
def make_sink_states_graph(cls, batch_shape: int) -> Batch:
data = Batch.from_data_list([cls.sf for _ in range(batch_shape)])
return data

# @classmethod
# def make_random_states_graph(cls, batch_shape: int) -> Batch:
# data = Batch.from_data_list([cls.make_random_states_graph() for _ in range(batch_shape)])
# return data

def __len__(self):
return self.data.batch_size

def __repr__(self):
return (
f"{self.__class__.__name__} object of batch shape {self.batch_shape} and "
f"node feature dim {self.node_feature_dim} and edge feature dim {self.edge_feature_dim}"
)

def __getitem__(self, index: int | Sequence[int] | slice) -> GraphStates:
if isinstance(index, int):
out = self.__class__(Batch.from_data_list([self.data[index]]))
elif isinstance(index, (Sequence, slice)):
out = self.__class__(Batch.from_data_list(self.data.index_select(index)))
else:
raise NotImplementedError(
"Indexing with type {} is not implemented".format(type(index))
)

if self._log_rewards is not None:
out._log_rewards = self._log_rewards[index]

return out

def __setitem__(self, index: int | Sequence[int], graph: GraphStates):
"""
Set particular states of the Batch
"""
data_list = self.data.to_data_list()
if isinstance(index, int):
assert (
len(graph) == 1
), "GraphStates must have a batch size of 1 for single index assignment"
data_list[index] = graph.data[0]
self.data = Batch.from_data_list(data_list)
elif isinstance(index, Sequence):
assert len(index) == len(
graph
), "Index and GraphState must have the same length"
for i, idx in enumerate(index):
data_list[idx] = graph.data[i]
self.data = Batch.from_data_list(data_list)
elif isinstance(index, slice):
assert index.stop - index.start == len(
graph
), "Index slice and GraphStates must have the same length"
data_list[index] = graph.data.to_data_list()
self.data = Batch.from_data_list(data_list)
else:
raise NotImplementedError(
"Setters with type {} is not implemented".format(type(index))
)

@property
def device(self) -> torch.device:
return self.data.get_example(0).x.device

def to(self, device: torch.device) -> GraphStates:
"""
Moves and/or casts the graph states to the specified device
"""
if self.device != device:
self.data = self.data.to(device)
return self

def clone(self) -> GraphStates:
"""Returns a *detached* clone of the current instance using deepcopy."""
return deepcopy(self)

def extend(self, other: GraphStates):
"""Concatenates to another GraphStates object along the batch dimension"""
self.data = Batch.from_data_list(
self.data.to_data_list() + other.data.to_data_list()
)
if self._log_rewards is not None:
assert other._log_rewards is not None
self._log_rewards = torch.cat(
(self._log_rewards, other._log_rewards), dim=0
)

@property
def log_rewards(self) -> torch.Tensor:
return self._log_rewards

@log_rewards.setter
def log_rewards(self, log_rewards: torch.Tensor) -> None:
self._log_rewards = log_rewards