From 9ae28b2419953b3bbe873cc7913d1d6447f8c240 Mon Sep 17 00:00:00 2001 From: alip67 Date: Wed, 6 Nov 2024 21:34:27 +0900 Subject: [PATCH 01/27] including Graphs as States for torchgfn --- src/gfn/states.py | 128 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 127 insertions(+), 1 deletion(-) diff --git a/src/gfn/states.py b/src/gfn/states.py index c95ac91d..a4a15f7b 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -3,9 +3,10 @@ from abc import ABC from copy import deepcopy from math import prod -from typing import Callable, ClassVar, List, Optional, Sequence +from typing import Callable, ClassVar, List, Optional, Sequence, cast, Tuple import torch +from torch_geometric.data import Batch, Data class States(ABC): @@ -501,3 +502,128 @@ def stack_states(states: List[States]): ) + state_example.batch_shape return stacked_states + +class GraphStates(ABC): + """ + Base class for Graph as a state representation. The `GraphStates` object is a batched collection of + multiple graph objects. The `Batch` object from PyTorch Geometric is used to represent the batch of + graph objects as states. + """ + + + s0: ClassVar[Data] + sf: ClassVar[Data] + node_feature_dim: ClassVar[int] + edge_feature_dim: ClassVar[int] + make_random_states_graph: Callable = lambda x: (_ for _ in ()).throw( + NotImplementedError( + "The environment does not support initialization of random Graph states." + ) + ) + + def __init__(self, graphs: Batch): + self.data: Batch = graphs + self.batch_shape: int = self.data.num_graphs + self._log_rewards: float = None + + @classmethod + def from_batch_shape(cls, batch_shape: int, random: bool = False, sink: bool=False) -> GraphStates: + if random and sink: + raise ValueError("Only one of `random` and `sink` should be True.") + if random: + data = cls.make_random_states_graph(batch_shape) + elif sink: + data = cls.make_sink_states_graph(batch_shape) + else: + data = cls.make_initial_states_graph(batch_shape) + return cls(data) + + @classmethod + def make_initial_states_graph(cls, batch_shape: int) -> Batch: + data = Batch.from_data_list([cls.s0 for _ in range(batch_shape)]) + return data + + @classmethod + def make_sink_states_graph(cls, batch_shape: int) -> Batch: + data = Batch.from_data_list([cls.sf for _ in range(batch_shape)]) + return data + + # @classmethod + # def make_random_states_graph(cls, batch_shape: int) -> Batch: + # data = Batch.from_data_list([cls.make_random_states_graph() for _ in range(batch_shape)]) + # return data + + def __len__(self): + return self.data.batch_size + + def __repr__(self): + return (f"{self.__class__.__name__} object of batch shape {self.batch_shape} and " + f"node feature dim {self.node_feature_dim} and edge feature dim {self.edge_feature_dim}") + + def __getitem__(self, index: int | Sequence[int] | slice) -> GraphStates: + if isinstance(index, int): + out = self.__class__(Batch.from_data_list([self.data[index]])) + elif isinstance(index, (Sequence, slice)): + out = self.__class__(Batch.from_data_list(self.data.index_select(index))) + else: + raise NotImplementedError("Indexing with type {} is not implemented".format(type(index))) + + if self._log_rewards is not None: + out._log_rewards = self._log_rewards[index] + + return out + + def __setitem__(self, index: int | Sequence[int], graph: GraphStates): + """ + Set particular states of the Batch + """ + data_list = self.data.to_data_list() + if isinstance(index, int): + assert len(graph) == 1, "GraphStates must have a batch size of 1 for single index assignment" + 
data_list[index] = graph.data[0] + self.data = Batch.from_data_list(data_list) + elif isinstance(index, Sequence): + assert len(index) == len(graph), "Index and GraphState must have the same length" + for i, idx in enumerate(index): + data_list[idx] = graph.data[i] + self.data = Batch.from_data_list(data_list) + elif isinstance(index, slice): + assert index.stop - index.start == len(graph), "Index slice and GraphStates must have the same length" + data_list[index] = graph.data.to_data_list() + self.data = Batch.from_data_list(data_list) + else: + raise NotImplementedError("Setters with type {} is not implemented".format(type(index))) + + @property + def device(self) -> torch.device: + return self.data.get_example(0).x.device + + def to(self, device: torch.device) -> GraphStates: + """ + Moves and/or casts the graph states to the specified device + """ + if self.device != device: + self.data = self.data.to(device) + return self + + def clone(self) -> GraphStates: + """Returns a *detached* clone of the current instance using deepcopy.""" + return deepcopy(self) + + def extend(self, other: GraphStates): + """Concatenates to another GraphStates object along the batch dimension""" + self.data = Batch.from_data_list(self.data.to_data_list() + other.data.to_data_list()) + if self._log_rewards is not None: + assert other._log_rewards is not None + self._log_rewards = torch.cat( + (self._log_rewards, other._log_rewards), dim=0 + ) + + + @property + def log_rewards(self) -> torch.Tensor: + return self._log_rewards + + @log_rewards.setter + def log_rewards(self, log_rewards: torch.Tensor) -> None: + self._log_rewards = log_rewards \ No newline at end of file From de6ab1c022a45b5e8462bf283ea30eec19c8e337 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Thu, 7 Nov 2024 21:50:10 +0100 Subject: [PATCH 02/27] add GraphEnv --- src/gfn/env.py | 107 +++++++++++++++++++++++++++++- src/gfn/gym/graph_building.py | 118 ++++++++++++++++++++++++++++++++++ src/gfn/states.py | 41 ++++++++---- 3 files changed, 252 insertions(+), 14 deletions(-) create mode 100644 src/gfn/gym/graph_building.py diff --git a/src/gfn/env.py b/src/gfn/env.py index 7a60e8ec..517fa97c 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -2,10 +2,11 @@ from typing import Optional, Tuple, Union import torch +from torch_geometric.data import Batch, Data from gfn.actions import Actions from gfn.preprocessors import IdentityPreprocessor, Preprocessor -from gfn.states import DiscreteStates, States +from gfn.states import DiscreteStates, GraphStates, States from gfn.utils.common import set_seed # Errors @@ -559,3 +560,107 @@ def terminating_states(self) -> DiscreteStates: raise NotImplementedError( "The environment does not support enumeration of states" ) + + +class GraphEnv(Env): + """Base class for graph-based environments.""" + + def __init__( + self, + s0: Data, + node_feature_dim: int, + edge_feature_dim: int, + action_shape: Tuple, + dummy_action: torch.Tensor, + exit_action: torch.Tensor, + sf: Optional[Data] = None, + device_str: Optional[str] = None, + preprocessor: Optional[Preprocessor] = None, + ): + """Initializes a graph-based environment. + + Args: + s0: The initial graph state. + node_feature_dim: The dimension of the node features. + edge_feature_dim: The dimension of the edge features. + action_shape: Tuple representing the shape of the actions. + dummy_action: Tensor of shape "action_shape" representing a dummy action. + exit_action: Tensor of shape "action_shape" representing the exit action. + sf: The final graph state. 
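+                Defaults to None, in which case a sink state whose node and edge
+                features are filled with -inf is built from s0.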
+ device_str: 'cpu' or 'cuda'. Defaults to None, in which case the device is + inferred from s0. + preprocessor: a Preprocessor object that converts raw graph states to a tensor + that can be fed into a neural network. Defaults to None, in which case + the IdentityPreprocessor is used. + """ + self.device = get_device(device_str, default_device=s0.x.device) + + if sf is None: + sf = Data( + x=torch.full((s0.num_nodes, node_feature_dim), -float("inf")).to( + self.device + ), + edge_attr=torch.full( + (s0.num_edges, edge_feature_dim), -float("inf") + ).to(self.device), + edge_index=s0.edge_index, + batch=torch.zeros(s0.num_nodes, dtype=torch.long, device=self.device), + ) + + super().__init__( + s0=s0, + state_shape=(s0.num_nodes, node_feature_dim), + action_shape=action_shape, + dummy_action=dummy_action, + exit_action=exit_action, + sf=sf, + device_str=device_str, + preprocessor=preprocessor, + ) + + self.node_feature_dim = node_feature_dim + self.edge_feature_dim = edge_feature_dim + self.GraphStates = self.make_graph_states_class() + + def make_graph_states_class(self) -> type[GraphStates]: + env = self + + class GraphEnvStates(GraphStates): + s0 = env.s0 + sf = env.sf + node_feature_dim = env.node_feature_dim + edge_feature_dim = env.edge_feature_dim + make_random_states_graph = env.make_random_states_graph + + return GraphEnvStates + + def states_from_tensor(self, tensor: Batch) -> GraphStates: + """Wraps the supplied Batch in a GraphStates instance.""" + return self.GraphStates(tensor) + + def states_from_batch_shape(self, batch_shape: int) -> GraphStates: + """Returns a batch of s0 states with a given batch_shape.""" + return self.GraphStates.from_batch_shape(batch_shape) + + @abstractmethod + def step(self, states: GraphStates, actions: Actions) -> GraphStates: + """Function that takes a batch of graph states and actions and returns a batch of next + graph states.""" + + @abstractmethod + def backward_step(self, states: GraphStates, actions: Actions) -> GraphStates: + """Function that takes a batch of graph states and actions and returns a batch of previous + graph states.""" + + @abstractmethod + def is_action_valid( + self, + states: GraphStates, + actions: Actions, + backward: bool = False, + ) -> bool: + """Returns True if the actions are valid in the given graph states.""" + + @abstractmethod + def make_random_states_graph(self, batch_shape: int) -> Batch: + """Optional method inherited by all GraphStates instances to emit a random Batch of graphs.""" diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py new file mode 100644 index 00000000..99848870 --- /dev/null +++ b/src/gfn/gym/graph_building.py @@ -0,0 +1,118 @@ +from copy import copy +from typing import Callable, Literal, Tuple + +import torch +from gfn.actions import Actions +from torch_geometric.data import Data, Batch +from torch_geometric.nn import GCNConv +from gfn.env import GraphEnv +from gfn.states import GraphStates + + +class GraphBuilding(GraphEnv): + + def __init__(self, + num_nodes: int, + node_feature_dim: int, + edge_feature_dim: int, + state_evaluator: Callable[[Batch], torch.Tensor] | None = None, + device_str: Literal["cpu", "cuda"] = "cpu" + ): + s0 = Data(x=torch.zeros((num_nodes, node_feature_dim)).to(device_str)) + exit_action = torch.tensor( + [-float("inf"), -float("inf")], device=torch.device(device_str) + ) + dummy_action = torch.tensor( + [float("inf"), float("inf")], device=torch.device(device_str) + ) + if state_evaluator is None: + state_evaluator = 
GCNConvEvaluator(node_feature_dim) + self.state_evaluator = state_evaluator + + super().__init__( + s0=s0, + node_feature_dim=node_feature_dim, + edge_feature_dim=edge_feature_dim, + action_shape=(2,), + dummy_action=dummy_action, + exit_action=exit_action, + device_str=device_str, + ) + + + def step(self, states: GraphStates, actions: Actions) -> GraphStates: + """Step function for the GraphBuilding environment. + + Args: + states: GraphStates object representing the current graph. + actions: Actions indicating which edge to add. + + Returns the next graph the new GraphStates. + """ + graphs: Batch = copy.deepcopy(states.data) + assert len(graphs) == len(actions) + + for i, act in enumerate(actions.tensor): + edge_index = torch.cat([graphs[i].edge_index, act.unsqueeze(1)], dim=1) + graphs[i].edge_index = edge_index + + return GraphStates(graphs) + + def backward_step(self, states: GraphStates, actions: Actions) -> GraphStates: + """Backward step function for the GraphBuilding environment. + + Args: + states: GraphStates object representing the current graph. + actions: Actions indicating which edge to remove. + + Returns the previous graph as a new GraphStates. + """ + graphs: Batch = copy.deepcopy(states.data) + assert len(graphs) == len(actions) + + for i, act in enumerate(actions.tensor): + edge_index = graphs[i].edge_index + edge_index = edge_index[:, edge_index[1] != act] + graphs[i].edge_index = edge_index + + return GraphStates(graphs) + + def is_action_valid( + self, states: GraphStates, actions: Actions, backward: bool = False + ) -> bool: + for i, act in enumerate(actions.tensor): + if backward and len(states.data[i].edge_index[1]) == 0: + return False + if not backward and torch.any(states.data[i].edge_index[1] == act): + return False + return True + + def reward(self, final_states: GraphStates) -> torch.Tensor: + """The environment's reward given a state. + This or log_reward must be implemented. + + Args: + final_states: A batch of final states. + + Returns: + torch.Tensor: Tensor of shape "batch_shape" containing the rewards. + """ + return self.state_evaluator(final_states.data).sum(dim=1) + + @property + def log_partition(self) -> float: + "Returns the logarithm of the partition function." + raise NotImplementedError + + @property + def true_dist_pmf(self) -> torch.Tensor: + "Returns a one-dimensional tensor representing the true distribution." + raise NotImplementedError + + +class GCNConvEvaluator: + def __init__(self, num_features): + self.net = GCNConv(num_features, 1) + + def __call__(self, batch: Batch) -> torch.Tensor: + return self.net(batch.x, batch.edge_index) \ No newline at end of file diff --git a/src/gfn/states.py b/src/gfn/states.py index a4a15f7b..8e6d513f 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -3,7 +3,7 @@ from abc import ABC from copy import deepcopy from math import prod -from typing import Callable, ClassVar, List, Optional, Sequence, cast, Tuple +from typing import Callable, ClassVar, List, Optional, Sequence import torch from torch_geometric.data import Batch, Data @@ -503,6 +503,7 @@ def stack_states(states: List[States]): return stacked_states + class GraphStates(ABC): """ Base class for Graph as a state representation. The `GraphStates` object is a batched collection of @@ -510,7 +511,6 @@ class GraphStates(ABC): graph objects as states. 
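+    Indexing, assignment and extend all operate on the wrapped Batch object.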
""" - s0: ClassVar[Data] sf: ClassVar[Data] node_feature_dim: ClassVar[int] @@ -527,7 +527,9 @@ def __init__(self, graphs: Batch): self._log_rewards: float = None @classmethod - def from_batch_shape(cls, batch_shape: int, random: bool = False, sink: bool=False) -> GraphStates: + def from_batch_shape( + cls, batch_shape: int, random: bool = False, sink: bool = False + ) -> GraphStates: if random and sink: raise ValueError("Only one of `random` and `sink` should be True.") if random: @@ -557,8 +559,10 @@ def __len__(self): return self.data.batch_size def __repr__(self): - return (f"{self.__class__.__name__} object of batch shape {self.batch_shape} and " - f"node feature dim {self.node_feature_dim} and edge feature dim {self.edge_feature_dim}") + return ( + f"{self.__class__.__name__} object of batch shape {self.batch_shape} and " + f"node feature dim {self.node_feature_dim} and edge feature dim {self.edge_feature_dim}" + ) def __getitem__(self, index: int | Sequence[int] | slice) -> GraphStates: if isinstance(index, int): @@ -566,7 +570,9 @@ def __getitem__(self, index: int | Sequence[int] | slice) -> GraphStates: elif isinstance(index, (Sequence, slice)): out = self.__class__(Batch.from_data_list(self.data.index_select(index))) else: - raise NotImplementedError("Indexing with type {} is not implemented".format(type(index))) + raise NotImplementedError( + "Indexing with type {} is not implemented".format(type(index)) + ) if self._log_rewards is not None: out._log_rewards = self._log_rewards[index] @@ -579,20 +585,28 @@ def __setitem__(self, index: int | Sequence[int], graph: GraphStates): """ data_list = self.data.to_data_list() if isinstance(index, int): - assert len(graph) == 1, "GraphStates must have a batch size of 1 for single index assignment" + assert ( + len(graph) == 1 + ), "GraphStates must have a batch size of 1 for single index assignment" data_list[index] = graph.data[0] self.data = Batch.from_data_list(data_list) elif isinstance(index, Sequence): - assert len(index) == len(graph), "Index and GraphState must have the same length" + assert len(index) == len( + graph + ), "Index and GraphState must have the same length" for i, idx in enumerate(index): data_list[idx] = graph.data[i] self.data = Batch.from_data_list(data_list) elif isinstance(index, slice): - assert index.stop - index.start == len(graph), "Index slice and GraphStates must have the same length" + assert index.stop - index.start == len( + graph + ), "Index slice and GraphStates must have the same length" data_list[index] = graph.data.to_data_list() self.data = Batch.from_data_list(data_list) else: - raise NotImplementedError("Setters with type {} is not implemented".format(type(index))) + raise NotImplementedError( + "Setters with type {} is not implemented".format(type(index)) + ) @property def device(self) -> torch.device: @@ -612,18 +626,19 @@ def clone(self) -> GraphStates: def extend(self, other: GraphStates): """Concatenates to another GraphStates object along the batch dimension""" - self.data = Batch.from_data_list(self.data.to_data_list() + other.data.to_data_list()) + self.data = Batch.from_data_list( + self.data.to_data_list() + other.data.to_data_list() + ) if self._log_rewards is not None: assert other._log_rewards is not None self._log_rewards = torch.cat( (self._log_rewards, other._log_rewards), dim=0 ) - @property def log_rewards(self) -> torch.Tensor: return self._log_rewards @log_rewards.setter def log_rewards(self, log_rewards: torch.Tensor) -> None: - self._log_rewards = log_rewards \ No newline at 
end of file + self._log_rewards = log_rewards From 24e23e8a08a172982b2cd1cd5dc3ed762af69e5a Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Thu, 7 Nov 2024 23:43:15 +0100 Subject: [PATCH 03/27] add deps and reformat --- pyproject.toml | 1 + src/gfn/gym/graph_building.py | 24 ++++++++++++------------ 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0523821a..53f7f768 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ einops = ">=0.6.1" numpy = ">=1.21.2" python = "^3.10" torch = ">=1.9.0" +torch_geometric = ">=2.6.0" # dev dependencies. black = { version = "24.3", optional = true } diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index 99848870..39f6d8d0 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -1,22 +1,23 @@ from copy import copy -from typing import Callable, Literal, Tuple +from typing import Callable, Literal import torch -from gfn.actions import Actions -from torch_geometric.data import Data, Batch +from torch_geometric.data import Batch, Data from torch_geometric.nn import GCNConv + +from gfn.actions import Actions from gfn.env import GraphEnv from gfn.states import GraphStates class GraphBuilding(GraphEnv): - - def __init__(self, + def __init__( + self, num_nodes: int, node_feature_dim: int, edge_feature_dim: int, state_evaluator: Callable[[Batch], torch.Tensor] | None = None, - device_str: Literal["cpu", "cuda"] = "cpu" + device_str: Literal["cpu", "cuda"] = "cpu", ): s0 = Data(x=torch.zeros((num_nodes, node_feature_dim)).to(device_str)) exit_action = torch.tensor( @@ -32,14 +33,13 @@ def __init__(self, super().__init__( s0=s0, node_feature_dim=node_feature_dim, - edge_feature_dim=edge_feature_dim, + edge_feature_dim=edge_feature_dim, action_shape=(2,), dummy_action=dummy_action, exit_action=exit_action, device_str=device_str, ) - def step(self, states: GraphStates, actions: Actions) -> GraphStates: """Step function for the GraphBuilding environment. @@ -55,7 +55,7 @@ def step(self, states: GraphStates, actions: Actions) -> GraphStates: for i, act in enumerate(actions.tensor): edge_index = torch.cat([graphs[i].edge_index, act.unsqueeze(1)], dim=1) graphs[i].edge_index = edge_index - + return GraphStates(graphs) def backward_step(self, states: GraphStates, actions: Actions) -> GraphStates: @@ -74,7 +74,7 @@ def backward_step(self, states: GraphStates, actions: Actions) -> GraphStates: edge_index = graphs[i].edge_index edge_index = edge_index[:, edge_index[1] != act] graphs[i].edge_index = edge_index - + return GraphStates(graphs) def is_action_valid( @@ -86,7 +86,7 @@ def is_action_valid( if not backward and torch.any(states.data[i].edge_index[1] == act): return False return True - + def reward(self, final_states: GraphStates) -> torch.Tensor: """The environment's reward given a state. This or log_reward must be implemented. 
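
The state_evaluator argument accepts any Callable[[Batch], torch.Tensor], so task-specific
scoring can be plugged in without subclassing; GCNConvEvaluator below is only the default,
which scores every node with a single-output GCN layer. A minimal sketch of an alternative
evaluator with the same per-node output shape (uniform_evaluator is a hypothetical
stand-in, not part of this series):

    import torch
    from torch_geometric.data import Batch

    from gfn.gym.graph_building import GraphBuilding

    def uniform_evaluator(batch: Batch) -> torch.Tensor:
        # Hypothetical evaluator: assign every node the same unit score,
        # mirroring GCNConvEvaluator's (num_nodes, 1) output shape.
        return torch.ones(batch.num_nodes, 1)

    env = GraphBuilding(num_nodes=4, node_feature_dim=8, edge_feature_dim=8,
                        state_evaluator=uniform_evaluator)
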
@@ -115,4 +115,4 @@ def __init__(self, num_features): self.net = GCNConv(num_features, 1) def __call__(self, batch: Batch) -> torch.Tensor: - return self.net(batch.x, batch.edge_index) \ No newline at end of file + return self.net(batch.x, batch.edge_index) From 1f7b220a22f2e77eb721d5f7b57caed4d8166b30 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Fri, 8 Nov 2024 15:10:33 +0100 Subject: [PATCH 04/27] add test, fix errors, add valid action check --- src/gfn/env.py | 67 ++++++++++++----------------------- src/gfn/gym/graph_building.py | 62 +++++++++++++++++++++----------- src/gfn/states.py | 50 +++++++++++++++++--------- testing/test_environments.py | 38 +++++++++++++++++++- 4 files changed, 134 insertions(+), 83 deletions(-) diff --git a/src/gfn/env.py b/src/gfn/env.py index 517fa97c..3d61a06a 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -594,35 +594,33 @@ def __init__( the IdentityPreprocessor is used. """ self.device = get_device(device_str, default_device=s0.x.device) - + self.s0 = s0.to(self.device) + + self.node_feature_dim = node_feature_dim + self.edge_feature_dim = edge_feature_dim + self.state_shape = (s0.num_nodes, self.node_feature_dim) + assert s0.x.shape == self.state_shape + if sf is None: sf = Data( - x=torch.full((s0.num_nodes, node_feature_dim), -float("inf")).to( - self.device - ), - edge_attr=torch.full( - (s0.num_edges, edge_feature_dim), -float("inf") - ).to(self.device), + x=torch.full(self.state_shape, -float("inf")), + edge_attr=torch.full((s0.num_edges, edge_feature_dim), -float("inf")), edge_index=s0.edge_index, - batch=torch.zeros(s0.num_nodes, dtype=torch.long, device=self.device), - ) + ).to(self.device) + self.sf: torch.Tensor = sf + assert self.sf.x.shape == self.state_shape - super().__init__( - s0=s0, - state_shape=(s0.num_nodes, node_feature_dim), - action_shape=action_shape, - dummy_action=dummy_action, - exit_action=exit_action, - sf=sf, - device_str=device_str, - preprocessor=preprocessor, - ) + self.action_shape = action_shape + self.dummy_action = dummy_action + self.exit_action = exit_action - self.node_feature_dim = node_feature_dim - self.edge_feature_dim = edge_feature_dim - self.GraphStates = self.make_graph_states_class() + self.States = self.make_states_class() + self.Actions = self.make_actions_class() - def make_graph_states_class(self) -> type[GraphStates]: + self.preprocessor = preprocessor + self.is_discrete = False + + def make_states_class(self) -> type[GraphStates]: env = self class GraphEnvStates(GraphStates): @@ -630,18 +628,10 @@ class GraphEnvStates(GraphStates): sf = env.sf node_feature_dim = env.node_feature_dim edge_feature_dim = env.edge_feature_dim - make_random_states_graph = env.make_random_states_graph + make_random_states_graph = env.make_random_states_tensor return GraphEnvStates - def states_from_tensor(self, tensor: Batch) -> GraphStates: - """Wraps the supplied Batch in a GraphStates instance.""" - return self.GraphStates(tensor) - - def states_from_batch_shape(self, batch_shape: int) -> GraphStates: - """Returns a batch of s0 states with a given batch_shape.""" - return self.GraphStates.from_batch_shape(batch_shape) - @abstractmethod def step(self, states: GraphStates, actions: Actions) -> GraphStates: """Function that takes a batch of graph states and actions and returns a batch of next @@ -651,16 +641,3 @@ def step(self, states: GraphStates, actions: Actions) -> GraphStates: def backward_step(self, states: GraphStates, actions: Actions) -> GraphStates: """Function that takes a batch of graph states and 
actions and returns a batch of previous graph states.""" - - @abstractmethod - def is_action_valid( - self, - states: GraphStates, - actions: Actions, - backward: bool = False, - ) -> bool: - """Returns True if the actions are valid in the given graph states.""" - - @abstractmethod - def make_random_states_graph(self, batch_shape: int) -> Batch: - """Optional method inherited by all GraphStates instances to emit a random Batch of graphs.""" diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index 39f6d8d0..611a28fd 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -1,12 +1,12 @@ -from copy import copy -from typing import Callable, Literal +from copy import deepcopy +from typing import Callable, Literal, Tuple import torch from torch_geometric.data import Batch, Data from torch_geometric.nn import GCNConv from gfn.actions import Actions -from gfn.env import GraphEnv +from gfn.env import GraphEnv, NonValidActionsError from gfn.states import GraphStates @@ -19,7 +19,10 @@ def __init__( state_evaluator: Callable[[Batch], torch.Tensor] | None = None, device_str: Literal["cpu", "cuda"] = "cpu", ): - s0 = Data(x=torch.zeros((num_nodes, node_feature_dim)).to(device_str)) + s0 = Data( + x=torch.zeros((num_nodes, node_feature_dim)), + edge_index=torch.zeros((2, 0), dtype=torch.long), + ).to(device_str) exit_action = torch.tensor( [-float("inf"), -float("inf")], device=torch.device(device_str) ) @@ -49,14 +52,14 @@ def step(self, states: GraphStates, actions: Actions) -> GraphStates: Returns the next graph the new GraphStates. """ - graphs: Batch = copy.deepcopy(states.data) + if not self.is_action_valid(states, actions): + raise NonValidActionsError("Invalid action.") + graphs: Batch = deepcopy(states.data) assert len(graphs) == len(actions) - for i, act in enumerate(actions.tensor): - edge_index = torch.cat([graphs[i].edge_index, act.unsqueeze(1)], dim=1) - graphs[i].edge_index = edge_index - - return GraphStates(graphs) + edge_index = torch.cat([graphs.edge_index, actions.tensor.T], dim=1) + graphs.edge_index = edge_index + return self.States(graphs) def backward_step(self, states: GraphStates, actions: Actions) -> GraphStates: """Backward step function for the GraphBuilding environment. @@ -67,7 +70,9 @@ def backward_step(self, states: GraphStates, actions: Actions) -> GraphStates: Returns the previous graph as a new GraphStates. 
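+            Raises NonValidActionsError if an edge to remove is not present in the graph.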
""" - graphs: Batch = copy.deepcopy(states.data) + if not self.is_action_valid(states, actions, backward=True): + raise NonValidActionsError("Invalid action.") + graphs: Batch = deepcopy(states.data) assert len(graphs) == len(actions) for i, act in enumerate(actions.tensor): @@ -75,17 +80,29 @@ def backward_step(self, states: GraphStates, actions: Actions) -> GraphStates: edge_index = edge_index[:, edge_index[1] != act] graphs[i].edge_index = edge_index - return GraphStates(graphs) + return self.States(graphs) def is_action_valid( self, states: GraphStates, actions: Actions, backward: bool = False ) -> bool: - for i, act in enumerate(actions.tensor): - if backward and len(states.data[i].edge_index[1]) == 0: - return False - if not backward and torch.any(states.data[i].edge_index[1] == act): - return False - return True + current_edges = states.data.edge_index + new_edges = actions.tensor + + if torch.any(new_edges[:, 0] == new_edges[:, 1]): + return False + if current_edges.shape[1] == 0: + return not backward + + if backward: + some_edges_not_exist = torch.any( + torch.all(current_edges[:, None, :] != new_edges.T[:, :, None], dim=0) + ) + return not some_edges_not_exist + else: + some_edges_exist = torch.any( + torch.all(current_edges[:, None, :] == new_edges.T[:, :, None], dim=0) + ) + return not some_edges_exist def reward(self, final_states: GraphStates) -> torch.Tensor: """The environment's reward given a state. @@ -97,7 +114,9 @@ def reward(self, final_states: GraphStates) -> torch.Tensor: Returns: torch.Tensor: Tensor of shape "batch_shape" containing the rewards. """ - return self.state_evaluator(final_states.data).sum(dim=1) + per_node_rew = self.state_evaluator(final_states.data) + node_batch_idx = final_states.data.batch + return torch.bincount(node_batch_idx, weights=per_node_rew) @property def log_partition(self) -> float: @@ -109,10 +128,13 @@ def true_dist_pmf(self) -> torch.Tensor: "Returns a one-dimensional tensor representing the true distribution." 
raise NotImplementedError + def make_random_states_tensor(self, batch_shape: Tuple) -> GraphStates: + """Generates random states tensor of shape (*batch_shape, num_nodes, node_feature_dim).""" + return self.States.from_batch_shape(batch_shape) class GCNConvEvaluator: def __init__(self, num_features): self.net = GCNConv(num_features, 1) def __call__(self, batch: Batch) -> torch.Tensor: - return self.net(batch.x, batch.edge_index) + return self.net(batch.x, batch.edge_index).squeeze(-1) diff --git a/src/gfn/states.py b/src/gfn/states.py index 8e6d513f..6a94fe84 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -3,7 +3,7 @@ from abc import ABC from copy import deepcopy from math import prod -from typing import Callable, ClassVar, List, Optional, Sequence +from typing import Callable, ClassVar, List, Optional, Sequence, Tuple import torch from torch_geometric.data import Batch, Data @@ -523,7 +523,8 @@ class GraphStates(ABC): def __init__(self, graphs: Batch): self.data: Batch = graphs - self.batch_shape: int = self.data.num_graphs + self.batch_shape: int = len(self.data) + self.state_shape = (self.data.get_example(0).num_nodes, self.node_feature_dim) self._log_rewards: float = None @classmethod @@ -541,19 +542,41 @@ def from_batch_shape( return cls(data) @classmethod - def make_initial_states_graph(cls, batch_shape: int) -> Batch: + def make_initial_states_graph(cls, batch_shape: int | Tuple) -> Batch: + if isinstance(batch_shape, Tuple) and len(batch_shape) > 1: + raise NotImplementedError("Batch shape with more than one dimension is not supported") + if isinstance(batch_shape, Tuple): + batch_shape = batch_shape[0] + data = Batch.from_data_list([cls.s0 for _ in range(batch_shape)]) return data @classmethod - def make_sink_states_graph(cls, batch_shape: int) -> Batch: + def make_sink_states_graph(cls, batch_shape: Tuple) -> Batch: + if isinstance(batch_shape, Tuple) and len(batch_shape) > 1: + raise NotImplementedError("Batch shape with more than one dimension is not supported") + if isinstance(batch_shape, Tuple): + batch_shape = batch_shape[0] + data = Batch.from_data_list([cls.sf for _ in range(batch_shape)]) return data - # @classmethod - # def make_random_states_graph(cls, batch_shape: int) -> Batch: - # data = Batch.from_data_list([cls.make_random_states_graph() for _ in range(batch_shape)]) - # return data + @classmethod + def make_random_states_graph(cls, batch_shape: int) -> Batch: + if isinstance(batch_shape, Tuple) and len(batch_shape) > 1: + raise NotImplementedError("Batch shape with more than one dimension is not supported") + if isinstance(batch_shape, Tuple): + batch_shape = batch_shape[0] + + data_list = [] + for _ in range(batch_shape): + data = Data( + x=torch.rand(cls.s0.num_nodes, cls.node_feature_dim), + edge_attr=torch.rand(cls.s0.num_edges, cls.edge_feature_dim), + edge_index=cls.s0.edge_index, # TODO: make it random + ) + data_list.append(data) + return Batch.from_data_list(data_list) def __len__(self): return self.data.batch_size @@ -564,15 +587,8 @@ def __repr__(self): f"node feature dim {self.node_feature_dim} and edge feature dim {self.edge_feature_dim}" ) - def __getitem__(self, index: int | Sequence[int] | slice) -> GraphStates: - if isinstance(index, int): - out = self.__class__(Batch.from_data_list([self.data[index]])) - elif isinstance(index, (Sequence, slice)): - out = self.__class__(Batch.from_data_list(self.data.index_select(index))) - else: - raise NotImplementedError( - "Indexing with type {} is not implemented".format(type(index)) - ) + def 
__getitem__(self, index: int | Sequence[int] | slice | torch.Tensor) -> GraphStates: + out = self.__class__(Batch(self.data[index])) if self._log_rewards is not None: out._log_rewards = self._log_rewards[index] diff --git a/testing/test_environments.py b/testing/test_environments.py index 5dbd4cc6..2786d61a 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -4,6 +4,7 @@ from gfn.env import NonValidActionsError from gfn.gym import Box, DiscreteEBM, HyperGrid +from gfn.gym.graph_building import GraphBuilding # Utilities. @@ -273,7 +274,7 @@ def test_states_getitem(ndim: int, env_name: str): states = env.reset(batch_shape=ND_BATCH_SHAPE, random=True) # Boolean selector to index batch elements. - selections = torch.randint(0, 2, ND_BATCH_SHAPE, dtype=torch.bool) + selections = torch.randint(0, 2,ND_BATCH_SHAPE, dtype=torch.bool) n_selections = int(torch.sum(selections)) selected_states = states[selections] @@ -316,3 +317,38 @@ def test_get_grid(): # State indices of the grid are ordered from 0:HEIGHT**2. assert (env.get_states_indices(grid).ravel() == torch.arange(HEIGHT**2)).all() + + +def test_graph_env(): + NUM_NODES = 4 + FEATURE_DIM = 8 + BATCH_SIZE = 3 + + env = GraphBuilding(num_nodes=NUM_NODES, node_feature_dim=FEATURE_DIM, edge_feature_dim=FEATURE_DIM) + states = env.reset(batch_shape=BATCH_SIZE) + assert states.batch_shape == BATCH_SIZE + assert states.state_shape == (NUM_NODES, FEATURE_DIM) + + actions_traj = torch.tensor([ + [[0, 1], [1, 2], [2, 3]], + [[0, 2], [1, 3], [2, 4]], + [[0, 3], [1, 4], [2, 5]], + [[0, 4], [1, 5], [2, 6]], + [[0, 5], [1, 6], [2, 7]], + ], dtype=torch.long) + + for action_tensor in actions_traj: + actions = env.actions_from_tensor(action_tensor) + states = env.step(states, actions) + + invalid_actions = torch.tensor([[0, 0], [1, 1], [2, 2]]) + actions = env.actions_from_tensor(invalid_actions) + with pytest.raises(NonValidActionsError): + states = env.step(states, actions) + invalid_actions = torch.tensor(actions_traj[0]) + actions = env.actions_from_tensor(invalid_actions) + with pytest.raises(NonValidActionsError): + states = env.step(states, actions) + + expected_rewards = torch.zeros(BATCH_SIZE) + assert (env.reward(states) == expected_rewards).all() \ No newline at end of file From 63e4f1cb0eee6c17464d7844aedf9bfa31c33aa7 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Fri, 8 Nov 2024 15:13:28 +0100 Subject: [PATCH 05/27] fix formatting --- src/gfn/env.py | 4 ++-- src/gfn/gym/graph_building.py | 3 ++- src/gfn/states.py | 21 ++++++++++++--------- testing/test_environments.py | 29 +++++++++++++++++------------ 4 files changed, 33 insertions(+), 24 deletions(-) diff --git a/src/gfn/env.py b/src/gfn/env.py index 3d61a06a..3c86a3de 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -595,12 +595,12 @@ def __init__( """ self.device = get_device(device_str, default_device=s0.x.device) self.s0 = s0.to(self.device) - + self.node_feature_dim = node_feature_dim self.edge_feature_dim = edge_feature_dim self.state_shape = (s0.num_nodes, self.node_feature_dim) assert s0.x.shape == self.state_shape - + if sf is None: sf = Data( x=torch.full(self.state_shape, -float("inf")), diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index 611a28fd..8b9dcc59 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -97,7 +97,7 @@ def is_action_valid( some_edges_not_exist = torch.any( torch.all(current_edges[:, None, :] != new_edges.T[:, :, None], dim=0) ) - return not some_edges_not_exist + return 
not some_edges_not_exist else: some_edges_exist = torch.any( torch.all(current_edges[:, None, :] == new_edges.T[:, :, None], dim=0) @@ -132,6 +132,7 @@ def make_random_states_tensor(self, batch_shape: Tuple) -> GraphStates: """Generates random states tensor of shape (*batch_shape, num_nodes, node_feature_dim).""" return self.States.from_batch_shape(batch_shape) + class GCNConvEvaluator: def __init__(self, num_features): self.net = GCNConv(num_features, 1) diff --git a/src/gfn/states.py b/src/gfn/states.py index 6a94fe84..513f814b 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -515,11 +515,6 @@ class GraphStates(ABC): sf: ClassVar[Data] node_feature_dim: ClassVar[int] edge_feature_dim: ClassVar[int] - make_random_states_graph: Callable = lambda x: (_ for _ in ()).throw( - NotImplementedError( - "The environment does not support initialization of random Graph states." - ) - ) def __init__(self, graphs: Batch): self.data: Batch = graphs @@ -544,7 +539,9 @@ def from_batch_shape( @classmethod def make_initial_states_graph(cls, batch_shape: int | Tuple) -> Batch: if isinstance(batch_shape, Tuple) and len(batch_shape) > 1: - raise NotImplementedError("Batch shape with more than one dimension is not supported") + raise NotImplementedError( + "Batch shape with more than one dimension is not supported" + ) if isinstance(batch_shape, Tuple): batch_shape = batch_shape[0] @@ -554,7 +551,9 @@ def make_initial_states_graph(cls, batch_shape: int | Tuple) -> Batch: @classmethod def make_sink_states_graph(cls, batch_shape: Tuple) -> Batch: if isinstance(batch_shape, Tuple) and len(batch_shape) > 1: - raise NotImplementedError("Batch shape with more than one dimension is not supported") + raise NotImplementedError( + "Batch shape with more than one dimension is not supported" + ) if isinstance(batch_shape, Tuple): batch_shape = batch_shape[0] @@ -564,7 +563,9 @@ def make_sink_states_graph(cls, batch_shape: Tuple) -> Batch: @classmethod def make_random_states_graph(cls, batch_shape: int) -> Batch: if isinstance(batch_shape, Tuple) and len(batch_shape) > 1: - raise NotImplementedError("Batch shape with more than one dimension is not supported") + raise NotImplementedError( + "Batch shape with more than one dimension is not supported" + ) if isinstance(batch_shape, Tuple): batch_shape = batch_shape[0] @@ -587,7 +588,9 @@ def __repr__(self): f"node feature dim {self.node_feature_dim} and edge feature dim {self.edge_feature_dim}" ) - def __getitem__(self, index: int | Sequence[int] | slice | torch.Tensor) -> GraphStates: + def __getitem__( + self, index: int | Sequence[int] | slice | torch.Tensor + ) -> GraphStates: out = self.__class__(Batch(self.data[index])) if self._log_rewards is not None: diff --git a/testing/test_environments.py b/testing/test_environments.py index 2786d61a..002e5ea4 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -274,7 +274,7 @@ def test_states_getitem(ndim: int, env_name: str): states = env.reset(batch_shape=ND_BATCH_SHAPE, random=True) # Boolean selector to index batch elements. 
- selections = torch.randint(0, 2,ND_BATCH_SHAPE, dtype=torch.bool) + selections = torch.randint(0, 2, ND_BATCH_SHAPE, dtype=torch.bool) n_selections = int(torch.sum(selections)) selected_states = states[selections] @@ -324,31 +324,36 @@ def test_graph_env(): FEATURE_DIM = 8 BATCH_SIZE = 3 - env = GraphBuilding(num_nodes=NUM_NODES, node_feature_dim=FEATURE_DIM, edge_feature_dim=FEATURE_DIM) + env = GraphBuilding( + num_nodes=NUM_NODES, node_feature_dim=FEATURE_DIM, edge_feature_dim=FEATURE_DIM + ) states = env.reset(batch_shape=BATCH_SIZE) assert states.batch_shape == BATCH_SIZE assert states.state_shape == (NUM_NODES, FEATURE_DIM) - actions_traj = torch.tensor([ - [[0, 1], [1, 2], [2, 3]], - [[0, 2], [1, 3], [2, 4]], - [[0, 3], [1, 4], [2, 5]], - [[0, 4], [1, 5], [2, 6]], - [[0, 5], [1, 6], [2, 7]], - ], dtype=torch.long) + actions_traj = torch.tensor( + [ + [[0, 1], [1, 2], [2, 3]], + [[0, 2], [1, 3], [2, 4]], + [[0, 3], [1, 4], [2, 5]], + [[0, 4], [1, 5], [2, 6]], + [[0, 5], [1, 6], [2, 7]], + ], + dtype=torch.long, + ) for action_tensor in actions_traj: actions = env.actions_from_tensor(action_tensor) states = env.step(states, actions) - invalid_actions = torch.tensor([[0, 0], [1, 1], [2, 2]]) + invalid_actions = torch.tensor([[0, 0], [1, 1], [2, 2]]) actions = env.actions_from_tensor(invalid_actions) with pytest.raises(NonValidActionsError): states = env.step(states, actions) - invalid_actions = torch.tensor(actions_traj[0]) + invalid_actions = torch.tensor(actions_traj[0]) actions = env.actions_from_tensor(invalid_actions) with pytest.raises(NonValidActionsError): states = env.step(states, actions) expected_rewards = torch.zeros(BATCH_SIZE) - assert (env.reward(states) == expected_rewards).all() \ No newline at end of file + assert (env.reward(states) == expected_rewards).all() From 8034fb2bdcdb1bff2a5cc44cf777164de51b7736 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Thu, 14 Nov 2024 13:42:25 +0100 Subject: [PATCH 06/27] add GraphAction --- src/gfn/actions.py | 86 ++++++++++++++++++++++++- src/gfn/env.py | 38 ++++++----- src/gfn/gym/graph_building.py | 116 +++++++++++++++++++++------------- src/gfn/states.py | 6 +- testing/test_environments.py | 112 +++++++++++++++++++++++++------- 5 files changed, 268 insertions(+), 90 deletions(-) diff --git a/src/gfn/actions.py b/src/gfn/actions.py index 2006b018..e6a1e67f 100644 --- a/src/gfn/actions.py +++ b/src/gfn/actions.py @@ -1,8 +1,9 @@ from __future__ import annotations # This allows to use the class name in type hints from abc import ABC +import enum from math import prod -from typing import ClassVar, Sequence +from typing import ClassVar, Optional, Sequence import torch @@ -168,3 +169,86 @@ def is_exit(self) -> torch.Tensor: *self.batch_shape, *((1,) * len(self.__class__.action_shape)) ) return self.compare(exit_actions_tensor) + + +class GraphActionType(enum.Enum): + EXIT = enum.auto() + ADD_NODE = enum.auto() + ADD_EDGE = enum.auto() + + +class GraphActions: + + nodes_features_dim: ClassVar[int] # Dim size of the features tensor. + edge_features_dim: ClassVar[int] # Dim size of the edge features tensor. + + def __init__(self, action_type: GraphActionType, features: torch.Tensor, edge_index: Optional[torch.Tensor] = None): + """Initializes a GraphAction object. + + Args: + action: a GraphActionType indicating the type of action. 
+            features: a tensor of shape (*batch_shape, feature_shape) representing the features of the nodes or of the edges, depending on the action type.
+            edge_index: a tensor of shape (*batch_shape, 2) representing the edge to add.
+                This must be defined if and only if the action type is GraphActionType.ADD_EDGE.
+        """
+        self.action_type = action_type
+        if self.action_type == GraphActionType.ADD_NODE:
+            assert features.shape[-1] == self.nodes_features_dim
+            assert edge_index is None
+        elif self.action_type == GraphActionType.ADD_EDGE:
+            assert features.shape[-1] == self.edge_features_dim
+            assert edge_index is not None
+            assert edge_index.shape[-1] == 2
+
+
+        self.features = features
+        self.edge_index = edge_index
+        self.batch_shape = tuple(self.features.shape[:-1])
+
+    def __repr__(self):
+        return f"""GraphActions object of type {self.action_type} and features of shape {self.features.shape}."""
+
+    @property
+    def device(self) -> torch.device:
+        """Returns the device of the features tensor."""
+        return self.features.device
+
+    def __len__(self) -> int:
+        """Returns the number of actions in the batch."""
+        return prod(self.batch_shape)
+
+    def __getitem__(self, index: int | Sequence[int] | Sequence[bool]) -> GraphActions:
+        """Get particular actions of the batch."""
+        features = self.features[index]
+        edge_index = self.edge_index[index] if self.edge_index is not None else None
+        return GraphActions(self.action_type, features, edge_index)
+
+    def __setitem__(self, index: int | Sequence[int] | Sequence[bool], action: GraphActions) -> None:
+        """Set particular actions of the batch."""
+        assert self.action_type == action.action_type
+        self.features[index] = action.features
+        if self.edge_index is not None:
+            self.edge_index[index] = action.edge_index
+
+    def compare(self, other: GraphActions) -> torch.Tensor:
+        """Compares the actions to another GraphActions object.
+
+        Args:
+            other: GraphActions object to compare.
+
+        Returns: boolean tensor of shape batch_shape indicating whether the actions are equal.
+        """
+        if self.action_type != other.action_type:
+            return torch.zeros(self.batch_shape, dtype=torch.bool, device=self.device)
+        out = torch.all(self.features == other.features, dim=-1)
+        if self.edge_index is not None:
+            out &= torch.all(self.edge_index == other.edge_index, dim=-1)
+        return out
+
+    @property
+    def is_exit(self) -> torch.Tensor:
+        """Returns a boolean tensor of shape `batch_shape` indicating whether the actions are exit actions."""
+        return torch.full(self.batch_shape, self.action_type == GraphActionType.Exit, dtype=torch.bool, device=self.device)
+
+
+
diff --git a/src/gfn/env.py b/src/gfn/env.py
index 3c86a3de..8780c069 100644
--- a/src/gfn/env.py
+++ b/src/gfn/env.py
@@ -4,7 +4,7 @@
 import torch
 from torch_geometric.data import Batch, Data
 
-from gfn.actions import Actions
+from gfn.actions import Actions, GraphActions
 from gfn.preprocessors import IdentityPreprocessor, Preprocessor
 from gfn.states import DiscreteStates, GraphStates, States
 from gfn.utils.common import set_seed
@@ -570,9 +570,6 @@ def __init__(
         s0: Data,
         node_feature_dim: int,
         edge_feature_dim: int,
-        action_shape: Tuple,
-        dummy_action: torch.Tensor,
-        exit_action: torch.Tensor,
         sf: Optional[Data] = None,
         device_str: Optional[str] = None,
         preprocessor: Optional[Preprocessor] = None,
@@ -593,26 +590,12 @@ def __init__(
                 the IdentityPreprocessor is used.
""" - self.device = get_device(device_str, default_device=s0.x.device) - self.s0 = s0.to(self.device) + self.s0 = s0.to(device_str) self.node_feature_dim = node_feature_dim self.edge_feature_dim = edge_feature_dim - self.state_shape = (s0.num_nodes, self.node_feature_dim) - assert s0.x.shape == self.state_shape - - if sf is None: - sf = Data( - x=torch.full(self.state_shape, -float("inf")), - edge_attr=torch.full((s0.num_edges, edge_feature_dim), -float("inf")), - edge_index=s0.edge_index, - ).to(self.device) - self.sf: torch.Tensor = sf - assert self.sf.x.shape == self.state_shape - self.action_shape = action_shape - self.dummy_action = dummy_action - self.exit_action = exit_action + self.sf = sf self.States = self.make_states_class() self.Actions = self.make_actions_class() @@ -632,6 +615,21 @@ class GraphEnvStates(GraphStates): return GraphEnvStates + def make_actions_class(self) -> type[GraphActions]: + """The default Actions class factory for all Environments. + + Returns a class that inherits from Actions and implements assumed methods. + The make_actions_class method should be overwritten to achieve more + environment-specific Actions functionality. + """ + env = self + + class DefaultGraphAction(GraphActions): + nodes_features_dim = env.node_feature_dim + edge_features_dim = env.edge_feature_dim + + return DefaultGraphAction + @abstractmethod def step(self, states: GraphStates, actions: Actions) -> GraphStates: """Function that takes a batch of graph states and actions and returns a batch of next diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index 8b9dcc59..6d017d38 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -5,7 +5,7 @@ from torch_geometric.data import Batch, Data from torch_geometric.nn import GCNConv -from gfn.actions import Actions +from gfn.actions import GraphActions, GraphActionType from gfn.env import GraphEnv, NonValidActionsError from gfn.states import GraphStates @@ -13,22 +13,13 @@ class GraphBuilding(GraphEnv): def __init__( self, - num_nodes: int, node_feature_dim: int, edge_feature_dim: int, state_evaluator: Callable[[Batch], torch.Tensor] | None = None, device_str: Literal["cpu", "cuda"] = "cpu", ): - s0 = Data( - x=torch.zeros((num_nodes, node_feature_dim)), - edge_index=torch.zeros((2, 0), dtype=torch.long), - ).to(device_str) - exit_action = torch.tensor( - [-float("inf"), -float("inf")], device=torch.device(device_str) - ) - dummy_action = torch.tensor( - [float("inf"), float("inf")], device=torch.device(device_str) - ) + s0 = Data().to(device_str) + if state_evaluator is None: state_evaluator = GCNConvEvaluator(node_feature_dim) self.state_evaluator = state_evaluator @@ -37,13 +28,10 @@ def __init__( s0=s0, node_feature_dim=node_feature_dim, edge_feature_dim=edge_feature_dim, - action_shape=(2,), - dummy_action=dummy_action, - exit_action=exit_action, device_str=device_str, ) - def step(self, states: GraphStates, actions: Actions) -> GraphStates: + def step(self, states: GraphStates, actions: GraphActions) -> GraphStates: """Step function for the GraphBuilding environment. 
Args: @@ -57,11 +45,26 @@ def step(self, states: GraphStates, actions: Actions) -> GraphStates: graphs: Batch = deepcopy(states.data) assert len(graphs) == len(actions) - edge_index = torch.cat([graphs.edge_index, actions.tensor.T], dim=1) - graphs.edge_index = edge_index + if actions.action_type == GraphActionType.ADD_NODE: + if graphs.x is None: + graphs.x = actions.features[:, None, :] + else: + graphs.x = torch.cat([graphs.x, actions.features[:, None, :]], dim=1) + + if actions.action_type == GraphActionType.ADD_EDGE: + assert actions.edge_index is not None + if graphs.edge_attr is None: + graphs.edge_attr = actions.features[:, None, :] + assert graphs.edge_index is None + graphs.edge_index = actions.edge_index[:, :, None] + else: + graphs.edge_attr = torch.cat([graphs.edge_attr, actions.features[:, None, :]], dim=1) + graphs.edge_index = torch.cat([graphs.edge_index, actions.edge_index[:, :, None]], dim=2) + return self.States(graphs) - def backward_step(self, states: GraphStates, actions: Actions) -> GraphStates: + + def backward_step(self, states: GraphStates, actions: GraphActions) -> GraphStates: """Backward step function for the GraphBuilding environment. Args: @@ -75,34 +78,59 @@ def backward_step(self, states: GraphStates, actions: Actions) -> GraphStates: graphs: Batch = deepcopy(states.data) assert len(graphs) == len(actions) - for i, act in enumerate(actions.tensor): - edge_index = graphs[i].edge_index - edge_index = edge_index[:, edge_index[1] != act] - graphs[i].edge_index = edge_index + if actions.action_type == GraphActionType.ADD_NODE: + assert graphs.x is not None + is_equal = torch.all(graphs.x == actions.features[:, None], dim=-1) + assert torch.all(torch.sum(is_equal, dim=-1) == 1) + graphs.x = graphs.x[~is_equal].reshape(states.data.batch_size, -1, self.node_feature_dim) + + elif actions.action_type == GraphActionType.ADD_EDGE: + assert actions.edge_index is not None + is_equal = torch.all(graphs.edge_index == actions.edge_index[:, :, None], dim=1) + assert torch.all(torch.sum(is_equal, dim=-1) == 1) + graphs.edge_attr = graphs.edge_attr[~is_equal].reshape(states.data.batch_size, -1, self.edge_feature_dim) + edge_index = graphs.edge_index.permute(0, 2, 1)[~is_equal] + graphs.edge_index = edge_index.reshape(states.data.batch_size, -1, 2).permute(0, 2, 1) return self.States(graphs) def is_action_valid( - self, states: GraphStates, actions: Actions, backward: bool = False + self, states: GraphStates, actions: GraphActions, backward: bool = False ) -> bool: - current_edges = states.data.edge_index - new_edges = actions.tensor - - if torch.any(new_edges[:, 0] == new_edges[:, 1]): - return False - if current_edges.shape[1] == 0: - return not backward + if actions.action_type == GraphActionType.ADD_NODE: + if actions.edge_index is not None: + return False + if states.data.x is None: + return not backward + + equal_nodes_per_batch = torch.sum( + torch.all(states.data.x == actions.features[:, None], dim=-1), + dim=-1 + ) - if backward: - some_edges_not_exist = torch.any( - torch.all(current_edges[:, None, :] != new_edges.T[:, :, None], dim=0) + if backward: # TODO: check if no edge are connected? 
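+                # Backward: every node to remove must match exactly one existing
+                # node; forward: the new node must not match any existing node.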
+ return torch.all(equal_nodes_per_batch == 1) + return torch.all(equal_nodes_per_batch == 0) + + if actions.action_type == GraphActionType.ADD_EDGE: + assert actions.edge_index is not None + if torch.any(actions.edge_index[:, 0] == actions.edge_index[:, 1]): + return False + if states.data.edge_index is None: + return not backward + + equal_edges_per_batch_attr = torch.sum( + torch.all(states.data.edge_attr == actions.features[:, None], dim=-1), + dim=-1 ) - return not some_edges_not_exist - else: - some_edges_exist = torch.any( - torch.all(current_edges[:, None, :] == new_edges.T[:, :, None], dim=0) + equal_edges_per_batch_index = torch.sum( + torch.all(states.data.edge_index == actions.edge_index[:, :, None], dim=1), + dim=-1 ) - return not some_edges_exist + if backward: + return torch.all(equal_edges_per_batch_attr == 1) and torch.all(equal_edges_per_batch_index == 1) + return torch.all(equal_edges_per_batch_attr == 0) and torch.all(equal_edges_per_batch_index == 0) + def reward(self, final_states: GraphStates) -> torch.Tensor: """The environment's reward given a state. @@ -114,9 +142,7 @@ def reward(self, final_states: GraphStates) -> torch.Tensor: Returns: torch.Tensor: Tensor of shape "batch_shape" containing the rewards. """ - per_node_rew = self.state_evaluator(final_states.data) - node_batch_idx = final_states.data.batch - return torch.bincount(node_batch_idx, weights=per_node_rew) + return self.state_evaluator(final_states.data) @property def log_partition(self) -> float: @@ -138,4 +164,8 @@ def __init__(self, num_features): self.net = GCNConv(num_features, 1) def __call__(self, batch: Batch) -> torch.Tensor: - return self.net(batch.x, batch.edge_index).squeeze(-1) + out = torch.empty(len(batch), device=batch.x.device) + for i in range(len(batch)): # looks like net doesn't work with batch + out[i] = self.net(batch.x[i], batch.edge_index[i]).mean() + + return out diff --git a/src/gfn/states.py b/src/gfn/states.py index 513f814b..503514c4 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -512,14 +512,13 @@ class GraphStates(ABC): """ s0: ClassVar[Data] - sf: ClassVar[Data] + sf: ClassVar[Optional[Data]] node_feature_dim: ClassVar[int] edge_feature_dim: ClassVar[int] def __init__(self, graphs: Batch): self.data: Batch = graphs self.batch_shape: int = len(self.data) - self.state_shape = (self.data.get_example(0).num_nodes, self.node_feature_dim) self._log_rewards: float = None @classmethod @@ -550,6 +549,9 @@ def make_initial_states_graph(cls, batch_shape: int | Tuple) -> Batch: @classmethod def make_sink_states_graph(cls, batch_shape: Tuple) -> Batch: + if cls.sf is None: + raise NotImplementedError("Sink state is not defined") + if isinstance(batch_shape, Tuple) and len(batch_shape) > 1: raise NotImplementedError( "Batch shape with more than one dimension is not supported" diff --git a/testing/test_environments.py b/testing/test_environments.py index 002e5ea4..cccc2bc1 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -2,6 +2,7 @@ import pytest import torch +from gfn.actions import GraphActionType from gfn.env import NonValidActionsError from gfn.gym import Box, DiscreteEBM, HyperGrid from gfn.gym.graph_building import GraphBuilding @@ -320,40 +321,103 @@ def test_get_grid(): def test_graph_env(): - NUM_NODES = 4 FEATURE_DIM = 8 BATCH_SIZE = 3 + NUM_NODES = 5 - env = GraphBuilding( - num_nodes=NUM_NODES, node_feature_dim=FEATURE_DIM, edge_feature_dim=FEATURE_DIM - ) + env = GraphBuilding(node_feature_dim=FEATURE_DIM, 
edge_feature_dim=FEATURE_DIM) states = env.reset(batch_shape=BATCH_SIZE) assert states.batch_shape == BATCH_SIZE - assert states.state_shape == (NUM_NODES, FEATURE_DIM) - - actions_traj = torch.tensor( - [ - [[0, 1], [1, 2], [2, 3]], - [[0, 2], [1, 3], [2, 4]], - [[0, 3], [1, 4], [2, 5]], - [[0, 4], [1, 5], [2, 6]], - [[0, 5], [1, 6], [2, 7]], - ], - dtype=torch.long, - ) + action_cls = env.make_actions_class() - for action_tensor in actions_traj: - actions = env.actions_from_tensor(action_tensor) + with pytest.raises(NonValidActionsError): + actions = action_cls( + GraphActionType.ADD_EDGE, + torch.rand((BATCH_SIZE, FEATURE_DIM)), + torch.randint(0, 10, (BATCH_SIZE, 2), dtype=torch.long) + ) + states = env.step(states, actions) + + for _ in range(NUM_NODES): + actions = action_cls( + GraphActionType.ADD_NODE, + torch.rand((BATCH_SIZE, FEATURE_DIM)), + ) states = env.step(states, actions) + + assert states.data.x.shape == (BATCH_SIZE, NUM_NODES, FEATURE_DIM) - invalid_actions = torch.tensor([[0, 0], [1, 1], [2, 2]]) - actions = env.actions_from_tensor(invalid_actions) with pytest.raises(NonValidActionsError): + actions = action_cls( + GraphActionType.ADD_NODE, + states.data.x[:, 0], + ) states = env.step(states, actions) - invalid_actions = torch.tensor(actions_traj[0]) - actions = env.actions_from_tensor(invalid_actions) + with pytest.raises(NonValidActionsError): + edge_index = torch.randint(0, 3, (BATCH_SIZE,), dtype=torch.long) + actions = action_cls( + GraphActionType.ADD_EDGE, + torch.rand((BATCH_SIZE, FEATURE_DIM)), + torch.stack([edge_index, edge_index], dim=1) + ) states = env.step(states, actions) - expected_rewards = torch.zeros(BATCH_SIZE) - assert (env.reward(states) == expected_rewards).all() + for i in range(NUM_NODES - 1): + edge_index = torch.tensor([[i, i + 1]] * BATCH_SIZE) + actions = action_cls( + GraphActionType.ADD_EDGE, + torch.rand((BATCH_SIZE, FEATURE_DIM)), + edge_index + ) + states = env.step(states, actions) + + with pytest.raises(NonValidActionsError): + edge_index = torch.tensor([[0, 1]] * BATCH_SIZE) + actions = action_cls( + GraphActionType.ADD_EDGE, + torch.rand((BATCH_SIZE, FEATURE_DIM)), + edge_index + ) + states = env.step(states, actions) + + env.reward(states) + + # with pytest.raises(NonValidActionsError): + # actions = action_cls( + # GraphActionType.ADD_NODE, + # states.data.x[:, 0], + # ) + # states = env.backward_step(states, actions) + + for i in reversed(range(states.data.edge_attr.shape[1])): + actions = action_cls( + GraphActionType.ADD_EDGE, + states.data.edge_attr[:, i], + states.data.edge_index[:, :, i] + ) + states = env.backward_step(states, actions) + + with pytest.raises(NonValidActionsError): + actions = action_cls( + GraphActionType.ADD_EDGE, + torch.rand((BATCH_SIZE, FEATURE_DIM)), + torch.randint(0, 10, (BATCH_SIZE, 2), dtype=torch.long) + ) + states = env.backward_step(states, actions) + + for i in reversed(range(NUM_NODES)): + actions = action_cls( + GraphActionType.ADD_NODE, + states.data.x[:, i], + ) + states = env.backward_step(states, actions) + + assert states.data.x.shape == (BATCH_SIZE, 0, FEATURE_DIM) + + with pytest.raises(NonValidActionsError): + actions = action_cls( + GraphActionType.ADD_NODE, + torch.rand((BATCH_SIZE, FEATURE_DIM)), + ) + states = env.backward_step(states, actions) \ No newline at end of file From d17967155258fd33ac1e669ecc5c80e5afdb7d1f Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Thu, 14 Nov 2024 18:31:27 +0100 Subject: [PATCH 07/27] fix batching mechanism --- src/gfn/actions.py | 12 +++--- 
src/gfn/gym/graph_building.py | 79 +++++++++++++++++++---------------- testing/test_environments.py | 31 ++++++++------ 3 files changed, 67 insertions(+), 55 deletions(-) diff --git a/src/gfn/actions.py b/src/gfn/actions.py index e6a1e67f..4628590a 100644 --- a/src/gfn/actions.py +++ b/src/gfn/actions.py @@ -192,18 +192,18 @@ def __init__(self, action_type: GraphActionType, features: torch.Tensor, edge_in This must defined if and only if the action type is GraphActionType.AddEdge. """ self.action_type = action_type + batch_dim, features_dim = features.shape if self.action_type == GraphActionType.ADD_NODE: - assert features.shape[-1] == self.nodes_features_dim + assert features_dim == self.nodes_features_dim assert edge_index is None elif self.action_type == GraphActionType.ADD_EDGE: - assert features.shape[-1] == self.edge_features_dim + assert features_dim == self.edge_features_dim assert edge_index is not None - assert edge_index.shape[-1] == 2 - + assert edge_index.shape == (2, batch_dim) self.features = features self.edge_index = edge_index - self.batch_shape = tuple(self.features.shape[:-1]) + self.batch_shape = (batch_dim,) def __repr__(self): return f"""GraphAction object of type {self.action_type} and features of shape {self.features.shape}.""" @@ -248,7 +248,7 @@ def compare(self, other: GraphActions) -> torch.Tensor: @property def is_exit(self) -> torch.Tensor: """Returns a boolean tensor of shape `batch_shape` indicating whether the actions are exit actions.""" - return torch.full(self.batch_shape, self.action_type == GraphActionType.Exit, dtype=torch.bool, device=self.device) + return torch.full(self.batch_shape, self.action_type == GraphActionType.EXIT, dtype=torch.bool, device=self.device) diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index 6d017d38..b433340b 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -47,19 +47,19 @@ def step(self, states: GraphStates, actions: GraphActions) -> GraphStates: if actions.action_type == GraphActionType.ADD_NODE: if graphs.x is None: - graphs.x = actions.features[:, None, :] + graphs.x = actions.features else: - graphs.x = torch.cat([graphs.x, actions.features[:, None, :]], dim=1) + graphs.x = torch.cat([graphs.x, actions.features]) if actions.action_type == GraphActionType.ADD_EDGE: assert actions.edge_index is not None if graphs.edge_attr is None: - graphs.edge_attr = actions.features[:, None, :] + graphs.edge_attr = actions.features assert graphs.edge_index is None - graphs.edge_index = actions.edge_index[:, :, None] + graphs.edge_index = actions.edge_index else: - graphs.edge_attr = torch.cat([graphs.edge_attr, actions.features[:, None, :]], dim=1) - graphs.edge_index = torch.cat([graphs.edge_index, actions.edge_index[:, :, None]], dim=2) + graphs.edge_attr = torch.cat([graphs.edge_attr, actions.features]) + graphs.edge_index = torch.cat([graphs.edge_index, actions.edge_index], dim=1) return self.States(graphs) @@ -80,17 +80,17 @@ def backward_step(self, states: GraphStates, actions: GraphActions) -> GraphStat if actions.action_type == GraphActionType.ADD_NODE: assert graphs.x is not None - is_equal = torch.all(graphs.x == actions.features[:, None], dim=-1) - assert torch.all(torch.sum(is_equal, dim=-1) == 1) - graphs.x = graphs.x[~is_equal].reshape(states.data.batch_size, -1, self.node_feature_dim) - + is_equal = torch.any( + torch.all(graphs.x[:, None] == actions.features, dim=-1), + dim=-1 + ) + graphs.x = graphs.x[~is_equal] elif actions.action_type == 
GraphActionType.ADD_EDGE: assert actions.edge_index is not None - is_equal = torch.all(graphs.edge_index == actions.edge_index[:, :, None], dim=1) - assert torch.all(torch.sum(is_equal, dim=-1) == 1) - graphs.edge_attr = graphs.edge_attr[~is_equal].reshape(states.data.batch_size, -1, self.edge_feature_dim) - edge_index = graphs.edge_index.permute(0, 2, 1)[~is_equal] - graphs.edge_index = edge_index.reshape(states.data.batch_size, -1, 2).permute(0, 2, 1) + is_equal = torch.all(graphs.edge_index[:, None] == actions.edge_index[:, :, None], dim=0) + is_equal = torch.any(is_equal, dim=0) + graphs.edge_attr = graphs.edge_attr[~is_equal] + graphs.edge_index = graphs.edge_index[:, ~is_equal] return self.States(graphs) @@ -103,10 +103,10 @@ def is_action_valid( if states.data.x is None: return not backward - equal_nodes_per_batch = torch.sum( - torch.all(states.data.x == actions.features[:, None], dim=-1), - dim=-1 - ) + equal_nodes_per_batch = torch.all( + states.data.x == actions.features[:, None], dim=-1 + ).reshape(states.data.batch_size, -1) + equal_nodes_per_batch = torch.sum(equal_nodes_per_batch, dim=-1) if backward: # TODO: check if no edge are connected? return torch.all(equal_nodes_per_batch == 1) @@ -114,19 +114,28 @@ def is_action_valid( if actions.action_type == GraphActionType.ADD_EDGE: assert actions.edge_index is not None - if torch.any(actions.edge_index[:, 0] == actions.edge_index[:, 1]): + if torch.any(actions.edge_index[0] == actions.edge_index[1]): return False - if states.data.edge_index is None: - return not backward - - equal_edges_per_batch_attr = torch.sum( - torch.all(states.data.edge_attr == actions.features[:, None], dim=-1), - dim=-1 - ) - equal_edges_per_batch_index = torch.sum( - torch.all(states.data.edge_index == actions.edge_index[:, :, None], dim=1), - dim=-1 - ) + if states.data.num_nodes is None or states.data.num_nodes == 0: + return False + if torch.any(actions.edge_index > states.data.num_nodes): + return False + + batch_idx = actions.edge_index % actions.batch_shape[0] + if torch.any(batch_idx != torch.arange(actions.batch_shape[0])): + return False + if states.data.edge_attr is None: + return True + + equal_edges_per_batch_attr = torch.all( + states.data.edge_attr == actions.features[:, None], dim=-1 + ).reshape(states.data.batch_size, -1) + equal_edges_per_batch_attr = torch.sum(equal_edges_per_batch_attr, dim=-1) + + equal_edges_per_batch_index = torch.all( + states.data.edge_index[:, None] == actions.edge_index[:, :, None], dim=0 + ).reshape(states.data.batch_size, -1) + equal_edges_per_batch_index = torch.sum(equal_edges_per_batch_index, dim=-1) if backward: return torch.all(equal_edges_per_batch_attr == 1) and torch.all(equal_edges_per_batch_index == 1) return torch.all(equal_edges_per_batch_attr == 0) and torch.all(equal_edges_per_batch_index == 0) @@ -164,8 +173,6 @@ def __init__(self, num_features): self.net = GCNConv(num_features, 1) def __call__(self, batch: Batch) -> torch.Tensor: - out = torch.empty(len(batch), device=batch.x.device) - for i in range(len(batch)): # looks like net doesn't work with batch - out[i] = self.net(batch.x[i], batch.edge_index[i]).mean() - - return out + out = self.net(batch.x, batch.edge_index) + out = out.reshape(batch.batch_size, -1) + return out.mean(-1) \ No newline at end of file diff --git a/testing/test_environments.py b/testing/test_environments.py index cccc2bc1..55dda0d5 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -334,7 +334,7 @@ def test_graph_env(): actions = 
action_cls( GraphActionType.ADD_EDGE, torch.rand((BATCH_SIZE, FEATURE_DIM)), - torch.randint(0, 10, (BATCH_SIZE, 2), dtype=torch.long) + torch.randint(0, 10, (2, BATCH_SIZE), dtype=torch.long) ) states = env.step(states, actions) @@ -345,12 +345,13 @@ def test_graph_env(): ) states = env.step(states, actions) - assert states.data.x.shape == (BATCH_SIZE, NUM_NODES, FEATURE_DIM) + assert states.data.x.shape == (BATCH_SIZE * NUM_NODES, FEATURE_DIM) with pytest.raises(NonValidActionsError): + first_node_mask = torch.arange(len(states.data.x)) // BATCH_SIZE == 0 actions = action_cls( GraphActionType.ADD_NODE, - states.data.x[:, 0], + states.data.x[first_node_mask], ) states = env.step(states, actions) @@ -359,16 +360,17 @@ def test_graph_env(): actions = action_cls( GraphActionType.ADD_EDGE, torch.rand((BATCH_SIZE, FEATURE_DIM)), - torch.stack([edge_index, edge_index], dim=1) + torch.stack([edge_index, edge_index]) ) states = env.step(states, actions) for i in range(NUM_NODES - 1): - edge_index = torch.tensor([[i, i + 1]] * BATCH_SIZE) + node_is = torch.arange(i * BATCH_SIZE, (i + 1) * BATCH_SIZE) + node_js = torch.arange((i + 1) * BATCH_SIZE, (i + 2) * BATCH_SIZE) actions = action_cls( GraphActionType.ADD_EDGE, torch.rand((BATCH_SIZE, FEATURE_DIM)), - edge_index + torch.stack([node_is, node_js]) ) states = env.step(states, actions) @@ -377,7 +379,7 @@ def test_graph_env(): actions = action_cls( GraphActionType.ADD_EDGE, torch.rand((BATCH_SIZE, FEATURE_DIM)), - edge_index + edge_index.T ) states = env.step(states, actions) @@ -390,11 +392,13 @@ def test_graph_env(): # ) # states = env.backward_step(states, actions) - for i in reversed(range(states.data.edge_attr.shape[1])): + num_edges_per_batch = states.data.edge_attr.shape[0] // BATCH_SIZE + for i in reversed(range(num_edges_per_batch)): + edge_idx = torch.arange(i * BATCH_SIZE, (i + 1) * BATCH_SIZE) actions = action_cls( GraphActionType.ADD_EDGE, - states.data.edge_attr[:, i], - states.data.edge_index[:, :, i] + states.data.edge_attr[edge_idx], + states.data.edge_index[:, edge_idx] ) states = env.backward_step(states, actions) @@ -402,18 +406,19 @@ def test_graph_env(): actions = action_cls( GraphActionType.ADD_EDGE, torch.rand((BATCH_SIZE, FEATURE_DIM)), - torch.randint(0, 10, (BATCH_SIZE, 2), dtype=torch.long) + torch.randint(0, 10, (2, BATCH_SIZE), dtype=torch.long) ) states = env.backward_step(states, actions) for i in reversed(range(NUM_NODES)): + edge_idx = torch.arange(i * BATCH_SIZE, (i + 1) * BATCH_SIZE) actions = action_cls( GraphActionType.ADD_NODE, - states.data.x[:, i], + states.data.x[edge_idx], ) states = env.backward_step(states, actions) - assert states.data.x.shape == (BATCH_SIZE, 0, FEATURE_DIM) + assert states.data.x.shape == (0, FEATURE_DIM) with pytest.raises(NonValidActionsError): actions = action_cls( From 7ff96d5d8e0484e28813bd242cf6c07987d92355 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Sat, 16 Nov 2024 11:26:06 +0100 Subject: [PATCH 08/27] add support for EXIT action --- src/gfn/actions.py | 59 ++++++++++++++++++++++------------- src/gfn/gym/graph_building.py | 11 +++++-- testing/test_environments.py | 5 ++- 3 files changed, 49 insertions(+), 26 deletions(-) diff --git a/src/gfn/actions.py b/src/gfn/actions.py index 4628590a..7207c1c5 100644 --- a/src/gfn/actions.py +++ b/src/gfn/actions.py @@ -179,31 +179,38 @@ class GraphActionType(enum.Enum): class GraphActions: - nodes_features_dim: ClassVar[int] # Dim size of the features tensor. - edge_features_dim: ClassVar[int] # Dim size of the edge features tensor. 
+    nodes_features_dim: ClassVar[int]
+    edge_features_dim: ClassVar[int]

-    def __init__(self, action_type: GraphActionType, features: torch.Tensor, edge_index: Optional[torch.Tensor] = None):
+    def __init__(self, action_type: GraphActionType, features: Optional[torch.Tensor] = None, edge_index: Optional[torch.Tensor] = None):
         """Initializes a GraphAction object.

         Args:
             action_type: a GraphActionType indicating the type of action.
-            features: a tensor of shape (*batch_shape, feature_shape) representing the features of the nodes or of the edges, depending on the action type
-            edge_index: an tensor of shape (*batch_shape, 2) representing the edge to add.
+            features: a tensor of shape (batch_shape, feature_shape) representing the features of the nodes or of the edges, depending on the action type.
+                In case of EXIT action, this can be None.
+            edge_index: a tensor of shape (2, batch_shape) representing the edge to add.
                 This must be defined if and only if the action type is GraphActionType.AddEdge.
         """
         self.action_type = action_type
-        batch_dim, features_dim = features.shape
-        if self.action_type == GraphActionType.ADD_NODE:
-            assert features_dim == self.nodes_features_dim
+        if self.action_type == GraphActionType.EXIT:
+            assert features is None
             assert edge_index is None
-        elif self.action_type == GraphActionType.ADD_EDGE:
-            assert features_dim == self.edge_features_dim
-            assert edge_index is not None
-            assert edge_index.shape == (2, batch_dim)
-
-        self.features = features
-        self.edge_index = edge_index
-        self.batch_shape = (batch_dim,)
+            self.features = None
+            self.edge_index = None
+        else:
+            assert features is not None
+            batch_dim, features_dim = features.shape
+            if self.action_type == GraphActionType.ADD_NODE:
+                assert features_dim == self.nodes_features_dim
+                assert edge_index is None
+            elif self.action_type == GraphActionType.ADD_EDGE:
+                assert features_dim == self.edge_features_dim
+                assert edge_index is not None
+                assert edge_index.shape == (2, batch_dim)
+
+            self.features = features
+            self.edge_index = edge_index

     def __repr__(self):
         return f"""GraphAction object of type {self.action_type} and features of shape {self.features.shape}."""
@@ -215,19 +222,26 @@ def device(self) -> torch.device:

     def __len__(self) -> int:
         """Returns the number of actions in the batch."""
-        return prod(self.batch_shape)
+        if self.action_type == GraphActionType.EXIT:
+            raise ValueError("Cannot get the length of exit actions.")
+        else:
+            assert self.features is not None
+            return self.features.shape[0]

     def __getitem__(self, index: int | Sequence[int] | Sequence[bool]) -> GraphActions:
         """Get particular actions of the batch."""
-        features = self.features[index]
+        features = self.features[index] if self.features is not None else None
         edge_index = self.edge_index[index] if self.edge_index is not None else None
         return GraphActions(self.action_type, features, edge_index)

     def __setitem__(self, index: int | Sequence[int] | Sequence[bool], action: GraphActions) -> None:
         """Set particular actions of the batch."""
         assert self.action_type == action.action_type
-        self.features[index] = action.features
-        if self.edge_index is not None:
+        if self.action_type != GraphActionType.EXIT:
+            assert self.features is not None
+            self.features[index] = action.features
+        if self.action_type == GraphActionType.ADD_EDGE:
+            assert self.edge_index is not None
             self.edge_index[index] = action.edge_index

     def compare(self, other: GraphActions) -> torch.Tensor:
@@ -239,7 +253,8 @@ def compare(self, other: GraphActions) -> torch.Tensor:

         Returns: boolean tensor
of shape batch_shape indicating whether the actions are equal. """ if self.action_type != other.action_type: - return torch.zeros(self.batch_shape, dtype=torch.bool, device=self.device) + len_ = self.features.shape[0] if self.features is not None else 1 + return torch.zeros(len_, dtype=torch.bool, device=self.device) out = torch.all(self.features == other.features, dim=-1) if self.edge_index is not None: out &= torch.all(self.edge_index == other.edge_index, dim=-1) @@ -248,7 +263,7 @@ def compare(self, other: GraphActions) -> torch.Tensor: @property def is_exit(self) -> torch.Tensor: """Returns a boolean tensor of shape `batch_shape` indicating whether the actions are exit actions.""" - return torch.full(self.batch_shape, self.action_type == GraphActionType.EXIT, dtype=torch.bool, device=self.device) + return torch.full((1,), self.action_type == GraphActionType.EXIT, dtype=torch.bool, device=self.device) diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index b433340b..2d060759 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -43,15 +43,16 @@ def step(self, states: GraphStates, actions: GraphActions) -> GraphStates: if not self.is_action_valid(states, actions): raise NonValidActionsError("Invalid action.") graphs: Batch = deepcopy(states.data) - assert len(graphs) == len(actions) if actions.action_type == GraphActionType.ADD_NODE: + assert len(graphs) == len(actions) if graphs.x is None: graphs.x = actions.features else: graphs.x = torch.cat([graphs.x, actions.features]) if actions.action_type == GraphActionType.ADD_EDGE: + assert len(graphs) == len(actions) assert actions.edge_index is not None if graphs.edge_attr is None: graphs.edge_attr = actions.features @@ -97,6 +98,9 @@ def backward_step(self, states: GraphStates, actions: GraphActions) -> GraphStat def is_action_valid( self, states: GraphStates, actions: GraphActions, backward: bool = False ) -> bool: + if actions.action_type == GraphActionType.EXIT: + return True # TODO: what are the conditions for exit action? 
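Editor's note: the TODO above is still open at this point in the series. As a hedged sketch (not part of the patch), one condition consistent with the action masks a later patch in this series introduces would be to allow EXIT only once the graph is non-empty:

    # Hypothetical refinement of the EXIT branch -- an assumption, not the
    # author's code. EXIT is only meaningful once at least one node exists,
    # which is the same condition the GraphStates masks encode later on.
    if actions.action_type == GraphActionType.EXIT:
        return states.data.x is not None and states.data.x.shape[0] > 0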
+ if actions.action_type == GraphActionType.ADD_NODE: if actions.edge_index is not None: return False @@ -121,8 +125,9 @@ def is_action_valid( if torch.any(actions.edge_index > states.data.num_nodes): return False - batch_idx = actions.edge_index % actions.batch_shape[0] - if torch.any(batch_idx != torch.arange(actions.batch_shape[0])): + batch_dim = actions.features.shape[0] + batch_idx = actions.edge_index % batch_dim + if torch.any(batch_idx != torch.arange(batch_dim)): return False if states.data.edge_attr is None: return True diff --git a/testing/test_environments.py b/testing/test_environments.py index 55dda0d5..e157f1bb 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -383,12 +383,15 @@ def test_graph_env(): ) states = env.step(states, actions) + actions = action_cls(GraphActionType.EXIT) + states = env.step(states, actions) env.reward(states) # with pytest.raises(NonValidActionsError): + # node_idx = torch.arange(0, BATCH_SIZE) # actions = action_cls( # GraphActionType.ADD_NODE, - # states.data.x[:, 0], + # states.data.x[node_idxs], # ) # states = env.backward_step(states, actions) From dacbbf746b0f31bce7930adf56a15e63cfeaed38 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Tue, 19 Nov 2024 12:40:13 +0100 Subject: [PATCH 09/27] add GraphActionPolicyEstimator --- src/gfn/actions.py | 2 +- src/gfn/modules.py | 165 +++++++++++++++------- src/gfn/states.py | 12 ++ src/gfn/utils/distributions.py | 19 ++- testing/test_samplers_and_trajectories.py | 12 +- 5 files changed, 155 insertions(+), 55 deletions(-) diff --git a/src/gfn/actions.py b/src/gfn/actions.py index 7207c1c5..0bdb5529 100644 --- a/src/gfn/actions.py +++ b/src/gfn/actions.py @@ -172,9 +172,9 @@ def is_exit(self) -> torch.Tensor: class GraphActionType(enum.Enum): - EXIT = enum.auto() ADD_NODE = enum.auto() ADD_EDGE = enum.auto() + EXIT = enum.auto() class GraphActions: diff --git a/src/gfn/modules.py b/src/gfn/modules.py index 2eabf53d..4083d169 100644 --- a/src/gfn/modules.py +++ b/src/gfn/modules.py @@ -1,13 +1,14 @@ from abc import ABC, abstractmethod -from typing import Any +from typing import Any, Dict import torch import torch.nn as nn -from torch.distributions import Categorical, Distribution +from torch.distributions import Categorical, Distribution, Normal +from gfn.actions import GraphActionType from gfn.preprocessors import IdentityPreprocessor, Preprocessor -from gfn.states import DiscreteStates, States -from gfn.utils.distributions import UnsqueezedCategorical +from gfn.states import DiscreteStates, GraphStates, States +from gfn.utils.distributions import ComposedDistribution, UnsqueezedCategorical REDUCTION_FXNS = { "mean": torch.mean, @@ -90,32 +91,11 @@ def forward(self, input: States | torch.Tensor) -> torch.Tensor: """ if isinstance(input, States): input = self.preprocessor(input) - - out = self.module(input) - - if not self._output_dim_is_checked: - self.check_output_dim(out) - self._output_dim_is_checked = True - - return out + return self.module(input) def __repr__(self): return f"{self.__class__.__name__} module" - @property - @abstractmethod - def expected_output_dim(self) -> int: - """Expected output dimension of the module.""" - - def check_output_dim(self, module_output: torch.Tensor) -> None: - """Check that the output of the module has the correct shape. 
Raises an error if not.""" - assert module_output.dtype == torch.float - if module_output.shape[-1] != self.expected_output_dim(): - raise ValueError( - f"{self.__class__.__name__} output dimension should be {self.expected_output_dim()}" - + f" but is {module_output.shape[-1]}." - ) - def to_probability_distribution( self, states: States, @@ -192,9 +172,6 @@ def __init__( ) self.reduction_fxn = REDUCTION_FXNS[reduction] - def expected_output_dim(self) -> int: - return 1 - def forward(self, input: States | torch.Tensor) -> torch.Tensor: """Forward pass of the module. @@ -212,10 +189,7 @@ def forward(self, input: States | torch.Tensor) -> torch.Tensor: if out.shape[-1] != 1: out = self.reduction_fxn(out, -1) - if not self._output_dim_is_checked: - self.check_output_dim(out) - self._output_dim_is_checked = True - + assert out.shape[-1] == 1 return out @@ -250,12 +224,19 @@ def __init__( """ super().__init__(module, preprocessor, is_backward=is_backward) self.n_actions = n_actions + self.expected_output_dim = self.n_actions - int(self.is_backward) - def expected_output_dim(self) -> int: - if self.is_backward: - return self.n_actions - 1 - else: - return self.n_actions + def forward(self, states: DiscreteStates) -> torch.Tensor: + """Forward pass of the module. + + Args: + states: The input states. + + Returns the output of the module, as a tensor of shape (*batch_shape, output_dim). + """ + out = super().forward(states) + assert out.shape[-1] == self.expected_output_dim + return out def to_probability_distribution( self, @@ -279,7 +260,7 @@ def to_probability_distribution( on policy. epsilon: with probability epsilon, a random action is chosen. Does nothing if set to 0.0 (default), in which case it's on policy.""" - self.check_output_dim(module_output) + assert module_output.shape[-1] == self.expected_output_dim masks = states.backward_masks if self.is_backward else states.forward_masks logits = module_output @@ -296,7 +277,7 @@ def to_probability_distribution( # LogEdgeFlows are greedy, as are most P_B. else: - return UnsqueezedCategorical(logits=logits) + return UnsqueezedCategorical(logits=logits) class ConditionalDiscretePolicyEstimator(DiscretePolicyEstimator): @@ -362,11 +343,7 @@ def forward(self, states: States, conditioning: torch.Tensor) -> torch.Tensor: Returns the output of the module, as a tensor of shape (*batch_shape, output_dim). """ out = self._forward_trunk(states, conditioning) - - if not self._output_dim_is_checked: - self.check_output_dim(out) - self._output_dim_is_checked = True - + assert out.shape[-1] == self.expected_output_dim return out @@ -450,15 +427,9 @@ def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor: if out.shape[-1] != 1: out = self.reduction_fxn(out, -1) - if not self._output_dim_is_checked: - self.check_output_dim(out) - self._output_dim_is_checked = True - + assert out.shape[-1] == 1 return out - def expected_output_dim(self) -> int: - return 1 - def to_probability_distribution( self, states: States, @@ -466,3 +437,93 @@ def to_probability_distribution( **policy_kwargs: Any, ) -> Distribution: raise NotImplementedError + + +class GraphActionPolicyEstimator(GFNModule): + r"""Container for forward and backward policy estimators for graph environments. + + $s \mapsto (P_F(s' \mid s))_{s' \in Children(s)}$. + + or + + $s \mapsto (P_B(s' \mid s))_{s' \in Parents(s)}$. + + Attributes: + temperature: scalar to divide the logits by before softmax. + sf_bias: scalar to subtract from the exit action logit before dividing by + temperature. 
+        epsilon: with probability epsilon, a random action is chosen.
+    """
+
+    def __init__(
+        self,
+        module: nn.ModuleDict,
+        preprocessor: Preprocessor | None = None,
+        is_backward: bool = False,
+    ):
+        """Initializes an estimator for P_F for graph environments.
+
+        Args:
+            is_backward: if False, then this is a forward policy, else backward policy.
+        """
+        super().__init__(module, preprocessor, is_backward)
+        assert isinstance(self.module, nn.ModuleDict)
+        assert self.module.keys() == {"action_type", "edge_index", "features"}
+
+    def forward(self, states: GraphStates) -> Dict[str, torch.Tensor]:
+        """Forward pass of the module.
+
+        Args:
+            states: The input graph states.
+
+        Returns the module outputs, one tensor per action component.
+        """
+        action_type_logits = self.module["action_type"](states)
+        edge_index_logits = self.module["edge_index"](states)
+        features = self.module["features"](states)
+
+        assert action_type_logits == len(GraphActionType)
+        assert edge_index_logits.shape[-1] == 2
+        return {
+            "action_type": action_type_logits,
+            "edge_index": edge_index_logits,
+            "features": features
+        }
+
+    def to_probability_distribution(
+        self,
+        states: GraphStates,
+        module_output: Dict[str, torch.Tensor],
+        temperature: float = 1.0,
+        epsilon: float = 0.0,
+    ) -> ComposedDistribution:
+        """Returns a probability distribution given a batch of states and module output.
+
+        We handle off-policyness using these kwargs.
+
+        Args:
+            states: The states to use.
+            module_output: The output of the module as a tensor of shape (*batch_shape, output_dim).
+            temperature: scalar to divide the logits by before softmax. Does nothing
+                if set to 1.0 (default), in which case it's on policy.
+            epsilon: with probability epsilon, a random action is chosen. Does nothing
+                if set to 0.0 (default), in which case it's on policy."""
+
+        dists = {}
+
+        action_type_logits = module_output["action_type"]
+        action_type_masks = states.backward_masks if self.is_backward else states.forward_masks
+        action_type_logits[~action_type_masks] = -float("inf")
+        action_type_probs = torch.softmax(action_type_logits / temperature, dim=-1)
+        uniform_dist_probs = action_type_masks.float() / action_type_masks.sum(dim=-1, keepdim=True)
+        action_type_probs = (1 - epsilon) * action_type_probs + epsilon * uniform_dist_probs
+
+        edge_index_logits = module_output["edge_index"]
+        edge_index_probs = torch.softmax(edge_index_logits / temperature, dim=-1)
+        uniform_dist_probs = torch.ones_like(edge_index_probs) / edge_index_probs.shape[-1]
+        edge_index_probs = (1 - epsilon) * edge_index_probs + epsilon * uniform_dist_probs
+
+        dists["action_type"] = Categorical(probs=action_type_probs)
+        dists["features"] = Normal(module_output["features"], temperature)
+        dists["edge_index"] = Categorical(probs=edge_index_probs)
+        return ComposedDistribution(dists=dists)
diff --git a/src/gfn/states.py b/src/gfn/states.py
index 503514c4..5d63a41b 100644
--- a/src/gfn/states.py
+++ b/src/gfn/states.py
@@ -8,6 +8,8 @@
 import torch
 from torch_geometric.data import Batch, Data

+from gfn.actions import GraphActionType
+

 class States(ABC):
     """Base class for states, seen as nodes of the DAG.
@@ -521,6 +523,16 @@ def __init__(self, graphs: Batch):
         self.batch_shape: int = len(self.data)
         self._log_rewards: float = None

+        # TODO logic repeated from env.is_valid_action
+        self.forward_masks = torch.ones((self.batch_shape, 3), dtype=torch.bool)
+        self.forward_masks[:, GraphActionType.ADD_EDGE.value] = self.data.x.shape[0] > 0
+        self.forward_masks[:, GraphActionType.EXIT.value] = self.data.x.shape[0] > 0
+
+        self.backward_masks = torch.ones((self.batch_shape, 3), dtype=torch.bool)
+        self.backward_masks[:, GraphActionType.ADD_NODE.value] = self.data.x.shape[0] > 0
+        self.backward_masks[:, GraphActionType.ADD_EDGE.value] = self.data.edge_attr.shape[0] > 0
+        self.backward_masks[:, GraphActionType.EXIT.value] = self.data.x.shape[0] > 0
+
     @classmethod
     def from_batch_shape(
         cls, batch_shape: int, random: bool = False, sink: bool = False
diff --git a/src/gfn/utils/distributions.py b/src/gfn/utils/distributions.py
index f4948d0d..5cfb9cc8 100644
--- a/src/gfn/utils/distributions.py
+++ b/src/gfn/utils/distributions.py
@@ -1,5 +1,6 @@
+from typing import Dict
 import torch
-from torch.distributions import Categorical
+from torch.distributions import Distribution, Categorical


 class UnsqueezedCategorical(Categorical):
@@ -39,3 +40,19 @@ def log_prob(self, sample: torch.Tensor) -> torch.Tensor:
         """
         assert sample.shape[-1] == 1
         return super().log_prob(sample.squeeze(-1))
+
+
+class ComposedDistribution(Distribution):
+    """A composition of independent distributions, one per action component."""
+
+    def __init__(self, dists: Dict[str, Distribution]):
+        """Initializes the composed distribution.
+
+        Args:
+            dists: A dictionary of distributions.
+        """
+        super().__init__()
+        self.dists = dists
+
+    def sample(self, sample_shape: torch.Size) -> Dict[str, torch.Tensor]:
+        return {k: v.sample(sample_shape) for k, v in self.dists.items()}
\ No newline at end of file
diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py
index 65199552..02014cb0 100644
--- a/testing/test_samplers_and_trajectories.py
+++ b/testing/test_samplers_and_trajectories.py
@@ -5,10 +5,12 @@
 from gfn.containers import Trajectories
 from gfn.containers.replay_buffer import ReplayBuffer
 from gfn.gym import Box, DiscreteEBM, HyperGrid
+from gfn.gym.graph_building import GraphBuilding
 from gfn.gym.helpers.box_utils import BoxPBEstimator, BoxPBMLP, BoxPFEstimator, BoxPFMLP
-from gfn.modules import DiscretePolicyEstimator, GFNModule
+from gfn.modules import DiscretePolicyEstimator, GFNModule, GraphActionPolicyEstimator
 from gfn.samplers import Sampler
 from gfn.utils.modules import MLP
+from torch_geometric.nn import GCNConv


 def trajectory_sampling_with_return(
@@ -214,3 +216,11 @@ def test_replay_buffer(
             replay_buffer.add(training_objects)
         except Exception as e:
             raise ValueError(f"Error while testing {env_name}") from e
+
+
+def test_graph_building():
+    node_feature_dim = 8
+    env = GraphBuilding(node_feature_dim=node_feature_dim, edge_feature_dim=4)
+    graph_net = GCNConv(node_feature_dim, 1)
+
+    GraphActionPolicyEstimator(module=graph_net)
From e74e5000e4636f8fc3960a5f60323420d34784fa Mon Sep 17 00:00:00 2001
From: Omar Younis
Date: Fri, 22 Nov 2024 15:05:54 +0100
Subject: [PATCH 10/27] Sampler integration work

---
 src/gfn/actions.py                        | 24 +++++---
 src/gfn/env.py                            | 25 +++++---
 src/gfn/gym/graph_building.py             | 12 +++-
 src/gfn/modules.py                        | 25 ++++----
 src/gfn/samplers.py                       | 18 +++---
 src/gfn/states.py                         | 69 +++++++------------
 src/gfn/utils/distributions.py            | 11 +++-
 testing/test_samplers_and_trajectories.py | 72 +++++++++++++++++++--
 8 files
changed, 171 insertions(+), 85 deletions(-) diff --git a/src/gfn/actions.py b/src/gfn/actions.py index 0bdb5529..03a1bfdc 100644 --- a/src/gfn/actions.py +++ b/src/gfn/actions.py @@ -171,10 +171,10 @@ def is_exit(self) -> torch.Tensor: return self.compare(exit_actions_tensor) -class GraphActionType(enum.Enum): - ADD_NODE = enum.auto() - ADD_EDGE = enum.auto() - EXIT = enum.auto() +class GraphActionType(enum.IntEnum): + ADD_NODE = 0 + ADD_EDGE = 1 + EXIT = 2 class GraphActions: @@ -182,7 +182,7 @@ class GraphActions: nodes_features_dim: ClassVar[int] edge_features_dim: ClassVar[int] - def __init__(self, action_type: GraphActionType, features: Optional[torch.Tensor] = None, edge_index: Optional[torch.Tensor] = None): + def __init__(self, action_type: torch.Tensor, features: Optional[torch.Tensor] = None, edge_index: Optional[torch.Tensor] = None): """Initializes a GraphAction object. Args: @@ -192,7 +192,9 @@ def __init__(self, action_type: GraphActionType, features: Optional[torch.Tensor edge_index: an tensor of shape (batch_shape, 2) representing the edge to add. This must defined if and only if the action type is GraphActionType.AddEdge. """ - self.action_type = action_type + self.batch_shape = action_type.shape + assert torch.all(action_type == action_type[0]) + self.action_type = action_type[0] if self.action_type == GraphActionType.EXIT: assert features is None assert edge_index is None @@ -201,9 +203,9 @@ def __init__(self, action_type: GraphActionType, features: Optional[torch.Tensor else: assert features is not None batch_dim, features_dim = features.shape + assert (batch_dim,) == self.batch_shape if self.action_type == GraphActionType.ADD_NODE: assert features_dim == self.nodes_features_dim - assert edge_index is None elif self.action_type == GraphActionType.ADD_EDGE: assert features_dim == self.edge_features_dim assert edge_index is not None @@ -265,5 +267,13 @@ def is_exit(self) -> torch.Tensor: """Returns a boolean tensor of shape `batch_shape` indicating whether the actions are exit actions.""" return torch.full((1,), self.action_type == GraphActionType.EXIT, dtype=torch.bool, device=self.device) + @classmethod + def make_dummy_actions(cls, batch_shape: tuple[int]) -> GraphActions: # TODO: remove make_dummy_actions + """Creates an Actions object of dummy actions with the given batch shape.""" + return GraphActions( + action_type=torch.full(batch_shape, fill_value=GraphActionType.EXIT), + features=None, + edge_index=None + ) diff --git a/src/gfn/env.py b/src/gfn/env.py index 8780c069..8db703b3 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Optional, Tuple, Union +from typing import Dict, Optional, Tuple, Union import torch from torch_geometric.data import Batch, Data @@ -568,8 +568,8 @@ class GraphEnv(Env): def __init__( self, s0: Data, - node_feature_dim: int, - edge_feature_dim: int, + # node_feature_dim: int, + # edge_feature_dim: int, sf: Optional[Data] = None, device_str: Optional[str] = None, preprocessor: Optional[Preprocessor] = None, @@ -592,8 +592,8 @@ def __init__( """ self.s0 = s0.to(device_str) - self.node_feature_dim = node_feature_dim - self.edge_feature_dim = edge_feature_dim + self.node_feature_dim = s0.x.shape[1] + self.edge_feature_dim = s0.edge_attr.shape[1] self.sf = sf @@ -609,8 +609,8 @@ def make_states_class(self) -> type[GraphStates]: class GraphEnvStates(GraphStates): s0 = env.s0 sf = env.sf - node_feature_dim = env.node_feature_dim - edge_feature_dim = env.edge_feature_dim + # 
node_feature_dim = env.node_feature_dim
+            # edge_feature_dim = env.edge_feature_dim
             make_random_states_graph = env.make_random_states_tensor

         return GraphEnvStates
@@ -630,6 +630,17 @@ class DefaultGraphAction(GraphActions):

         return DefaultGraphAction

+    def actions_from_tensor(self, tensor: Dict[str, torch.Tensor]):
+        """Wraps the supplied dictionary of tensors in a GraphActions instance.
+
+        Args:
+            tensor: A dictionary of tensors representing the action components.
+
+        Returns:
+            GraphActions: An instance of GraphActions.
+        """
+        return self.Actions(**tensor)
+
     @abstractmethod
     def step(self, states: GraphStates, actions: Actions) -> GraphStates:
         """Function that takes a batch of graph states and actions and returns a batch of next
diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py
index 2d060759..e33fafea 100644
--- a/src/gfn/gym/graph_building.py
+++ b/src/gfn/gym/graph_building.py
@@ -18,7 +18,14 @@ def __init__(
         state_evaluator: Callable[[Batch], torch.Tensor] | None = None,
         device_str: Literal["cpu", "cuda"] = "cpu",
     ):
-        s0 = Data().to(device_str)
+        s0 = Data(
+            x=torch.zeros((0, node_feature_dim), dtype=torch.float32),
+            edge_attr=torch.zeros((0, edge_feature_dim), dtype=torch.float32),
+            edge_index=torch.zeros((2, 0), dtype=torch.long),
+        ).to(device_str)
+        sf = Data(
+            x=torch.ones((1, node_feature_dim), dtype=torch.float32) * float('inf'),
+        ).to(device_str)

         if state_evaluator is None:
             state_evaluator = GCNConvEvaluator(node_feature_dim)
@@ -26,8 +33,7 @@

         super().__init__(
             s0=s0,
-            node_feature_dim=node_feature_dim,
-            edge_feature_dim=edge_feature_dim,
+            sf=sf,
             device_str=device_str,
         )
diff --git a/src/gfn/modules.py b/src/gfn/modules.py
index 4083d169..adbcc1c3 100644
--- a/src/gfn/modules.py
+++ b/src/gfn/modules.py
@@ -458,7 +458,7 @@ class GraphActionPolicyEstimator(GFNModule):
     def __init__(
         self,
         module: nn.ModuleDict,
-        preprocessor: Preprocessor | None = None,
+        # preprocessor: Preprocessor | None = None,
         is_backward: bool = False,
     ):
         """Initializes an estimator for P_F for graph environments.

         Args:
             is_backward: if False, then this is a forward policy, else backward policy.
         """
-        super().__init__(module, preprocessor, is_backward)
-        assert isinstance(self.module, nn.ModuleDict)
-        assert self.module.keys() == {"action_type", "edge_index", "features"}
+        #super().__init__(module, preprocessor, is_backward)
+        nn.Module.__init__(self)
+        assert isinstance(module, nn.ModuleDict)
+        assert module.keys() == {"action_type", "edge_index", "features"}
+        self.module = module
+        self.is_backward = is_backward

     def forward(self, states: GraphStates) -> Dict[str, torch.Tensor]:
         """Forward pass of the module.
@@ -482,8 +485,7 @@ def forward(self, states: GraphStates) -> Dict[str, torch.Tensor]: edge_index_logits = self.module["edge_index"](states) features = self.module["features"](states) - assert action_type_logits == len(GraphActionType) - assert edge_index_logits.shape[-1] == 2 + assert action_type_logits.shape[-1] == len(GraphActionType) return { "action_type": action_type_logits, "edge_index": edge_index_logits, @@ -517,13 +519,14 @@ def to_probability_distribution( action_type_probs = torch.softmax(action_type_logits / temperature, dim=-1) uniform_dist_probs = action_type_masks.float() / action_type_masks.sum(dim=-1, keepdim=True) action_type_probs = (1 - epsilon) * action_type_probs + epsilon * uniform_dist_probs + dists["action_type"] = Categorical(probs=action_type_probs) edge_index_logits = module_output["edge_index"] - edge_index_probs = torch.softmax(edge_index_logits / temperature, dim=-1) - uniform_dist_probs = torch.ones_like(edge_index_probs) / edge_index_probs.shape[-1] - edge_index_probs = (1 - epsilon) * edge_index_probs + epsilon * uniform_dist_probs + if edge_index_logits.shape[-1] != 0: + edge_index_probs = torch.softmax(edge_index_logits / temperature, dim=-1) + uniform_dist_probs = torch.ones_like(edge_index_probs) / edge_index_probs.shape[-1] + edge_index_probs = (1 - epsilon) * edge_index_probs + epsilon * uniform_dist_probs + dists["edge_index"] = UnsqueezedCategorical(probs=edge_index_probs) - dists["action_type"] = Categorical(probs=action_type_probs) dists["features"] = Normal(module_output["features"], temperature) - dists["edge_index"] = Categorical(probs=edge_index_probs) return ComposedDistribution(dists=dists) diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 819620f0..8a8a112c 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -7,7 +7,7 @@ from gfn.containers import Trajectories from gfn.env import Env from gfn.modules import GFNModule -from gfn.states import States, stack_states +from gfn.states import States from gfn.utils.handlers import ( has_conditioning_exception_handler, no_conditioning_exception_handler, @@ -147,7 +147,7 @@ def sample_trajectories( if conditioning is not None: assert states.batch_shape == conditioning.shape[: len(states.batch_shape)] - device = states.tensor.device + device = states.device dones = ( states.is_initial_state @@ -155,8 +155,8 @@ def sample_trajectories( else states.is_sink_state ) - trajectories_states: List[States] = [deepcopy(states)] - trajectories_actions: List[torch.Tensor] = [] + trajectories_states: States = deepcopy(states) + trajectories_actions: Optional[Actions] = None trajectories_logprobs: List[torch.Tensor] = [] trajectories_dones = torch.zeros( n_trajectories, dtype=torch.long, device=device @@ -206,7 +206,11 @@ def sample_trajectories( if save_logprobs: # When off_policy, actions_log_probs are None. 
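Editor's note: the samplers hunks around here replace the old append-then-stack accumulation with incremental `extend` calls, because batched graphs have no fixed tensor shape to stack along. A hedged sketch of what each `extend` amounts to for two PyG batches (`merge_batches` is an illustrative name, not part of the patch):

    from torch_geometric.data import Batch

    def merge_batches(a: Batch, b: Batch) -> Batch:
        # Concatenate two batched graph collections along the batch dimension;
        # GraphStates.extend performs essentially this merge.
        return Batch.from_data_list(a.to_data_list() + b.to_data_list())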
log_probs[~dones] = actions_log_probs - trajectories_actions.append(actions) + + if trajectories_actions is None: + trajectories_actions = actions + else: + trajectories_actions.extend(actions) trajectories_logprobs.append(log_probs) if self.estimator.is_backward: @@ -239,10 +243,8 @@ def sample_trajectories( states = new_states dones = dones | new_dones - trajectories_states.append(deepcopy(states)) + trajectories_states.extend(deepcopy(states)) - trajectories_states = stack_states(trajectories_states) - trajectories_actions = env.Actions.stack(trajectories_actions) trajectories_logprobs = ( torch.stack(trajectories_logprobs, dim=0) if save_logprobs else None ) diff --git a/src/gfn/states.py b/src/gfn/states.py index 5d63a41b..d1cbed9b 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -5,6 +5,7 @@ from math import prod from typing import Callable, ClassVar, List, Optional, Sequence, Tuple +import numpy as np import torch from torch_geometric.data import Batch, Data @@ -478,34 +479,6 @@ def init_forward_masks(self, set_ones: bool = True): self.forward_masks = torch.zeros(shape).bool() -def stack_states(states: List[States]): - """Given a list of states, stacks them along a new dimension (0).""" - state_example = states[0] # We assume all elems of `states` are the same. - - stacked_states = state_example.from_batch_shape((0, 0)) # Empty. - stacked_states.tensor = torch.stack([s.tensor for s in states], dim=0) - if state_example._log_rewards: - stacked_states._log_rewards = torch.stack( - [s._log_rewards for s in states], dim=0 - ) - - # We are dealing with a list of DiscretrStates instances. - if hasattr(state_example, "forward_masks"): - stacked_states.forward_masks = torch.stack( - [s.forward_masks for s in states], dim=0 - ) - stacked_states.backward_masks = torch.stack( - [s.backward_masks for s in states], dim=0 - ) - - # Adds the trajectory dimension. - stacked_states.batch_shape = ( - stacked_states.tensor.shape[0], - ) + state_example.batch_shape - - return stacked_states - - class GraphStates(ABC): """ Base class for Graph as a state representation. 
The `GraphStates` object is a batched collection of @@ -515,8 +488,6 @@ class GraphStates(ABC): s0: ClassVar[Data] sf: ClassVar[Optional[Data]] - node_feature_dim: ClassVar[int] - edge_feature_dim: ClassVar[int] def __init__(self, graphs: Batch): self.data: Batch = graphs @@ -524,14 +495,15 @@ def __init__(self, graphs: Batch): self._log_rewards: float = None # TODO logic repeated from env.is_valid_action + not_empty = self.data.x is not None and self.data.x.shape[0] > 0 self.forward_masks = torch.ones((self.batch_shape, 3), dtype=torch.bool) - self.forward_masks[:, GraphActionType.ADD_EDGE.value] = self.data.x.shape[0] > 0 - self.forward_masks[:, GraphActionType.EXIT.value] = self.data.x.shape[0] > 0 - + self.forward_masks[:, GraphActionType.ADD_EDGE] = not_empty + self.forward_masks[:, GraphActionType.EXIT] = not_empty + self.backward_masks = torch.ones((self.batch_shape, 3), dtype=torch.bool) - self.backward_masks[:, GraphActionType.ADD_NODE.value] = self.data.x.shape[0] > 0 - self.backward_masks[:, GraphActionType.ADD_EDGE.value] = self.data.edge_attr.shape[0] > 0 - self.backward_masks[:, GraphActionType.EXIT.value] = self.data.x.shape[0] > 0 + self.backward_masks[:, GraphActionType.ADD_NODE] = not_empty + self.backward_masks[:, GraphActionType.ADD_EDGE] = not_empty and self.data.edge_attr.shape[0] > 0 + self.backward_masks[:, GraphActionType.EXIT] = not_empty @classmethod def from_batch_shape( @@ -586,8 +558,8 @@ def make_random_states_graph(cls, batch_shape: int) -> Batch: data_list = [] for _ in range(batch_shape): data = Data( - x=torch.rand(cls.s0.num_nodes, cls.node_feature_dim), - edge_attr=torch.rand(cls.s0.num_edges, cls.edge_feature_dim), + x=torch.rand(cls.s0.num_nodes, cls.s0.x.shape[1]), + edge_attr=torch.rand(cls.s0.num_edges, cls.s0.edge_attr.shape[1]), edge_index=cls.s0.edge_index, # TODO: make it random ) data_list.append(data) @@ -599,13 +571,18 @@ def __len__(self): def __repr__(self): return ( f"{self.__class__.__name__} object of batch shape {self.batch_shape} and " - f"node feature dim {self.node_feature_dim} and edge feature dim {self.edge_feature_dim}" + f"node feature dim {self.s0.x.shape[1]} and edge feature dim {self.s0.edge_attr.shape[1]}" ) def __getitem__( self, index: int | Sequence[int] | slice | torch.Tensor ) -> GraphStates: - out = self.__class__(Batch(self.data[index])) + idxs = np.arange(len(self.data))[index] + data = [] + for i in idxs: + data.append(self.data.get_example(i)) + + out = GraphStates(Batch.from_data_list(data)) if self._log_rewards is not None: out._log_rewards = self._log_rewards[index] @@ -643,7 +620,10 @@ def __setitem__(self, index: int | Sequence[int], graph: GraphStates): @property def device(self) -> torch.device: - return self.data.get_example(0).x.device + sample = self.data.get_example(0).x + if sample is not None: + return sample.device + return torch.device("cuda" if torch.cuda.is_available() else "cpu") def to(self, device: torch.device) -> GraphStates: """ @@ -675,3 +655,10 @@ def log_rewards(self) -> torch.Tensor: @log_rewards.setter def log_rewards(self, log_rewards: torch.Tensor) -> None: self._log_rewards = log_rewards + + @property + def is_sink_state(self) -> torch.Tensor: + batch_dim = len(self.data.ptr) - 1 + if len(self.data.x) == 0: + return torch.zeros(batch_dim, dtype=torch.bool) + return torch.all(self.data.x == self.sf.x, dim=-1).reshape(batch_dim,) diff --git a/src/gfn/utils/distributions.py b/src/gfn/utils/distributions.py index 5cfb9cc8..a44727b3 100644 --- a/src/gfn/utils/distributions.py +++ 
b/src/gfn/utils/distributions.py @@ -54,5 +54,12 @@ def __init__(self, dists: Dict[str, Distribution]): super().__init__() self.dists = dists - def sample(self, sample_shape: torch.Size) -> Dict[str, torch.Tensor]: - return {k: v.sample(sample_shape) for k, v in self.dists.items()} \ No newline at end of file + def sample(self, sample_shape=torch.Size()) -> Dict[str, torch.Tensor]: + return {k: v.sample(sample_shape) for k, v in self.dists.items()} + + def log_prob(self, sample: Dict[str, torch.Tensor]) -> torch.Tensor: + log_probs = [ + v.log_prob(sample[k]).reshape(sample[k].shape[0], -1).sum(dim=-1) + for k, v in self.dists.items() + ] + return sum(log_probs) \ No newline at end of file diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py index 02014cb0..598a7437 100644 --- a/testing/test_samplers_and_trajectories.py +++ b/testing/test_samplers_and_trajectories.py @@ -1,7 +1,12 @@ from typing import Literal, Tuple import pytest +import torch +from torch import nn +from torch_geometric.nn import GCNConv +from torch_geometric.data import Batch +from gfn.actions import GraphActionType from gfn.containers import Trajectories from gfn.containers.replay_buffer import ReplayBuffer from gfn.gym import Box, DiscreteEBM, HyperGrid @@ -9,9 +14,8 @@ from gfn.gym.helpers.box_utils import BoxPBEstimator, BoxPBMLP, BoxPFEstimator, BoxPFMLP from gfn.modules import DiscretePolicyEstimator, GFNModule, GraphActionPolicyEstimator from gfn.samplers import Sampler +from gfn.states import GraphStates from gfn.utils.modules import MLP -from torch_geometric.nn import GCNConv - def trajectory_sampling_with_return( env_name: str, @@ -218,9 +222,65 @@ def test_replay_buffer( raise ValueError(f"Error while testing {env_name}") from e +# ------ GRAPH TESTS ------ + + +class ActionTypeNet(nn.Module): + def __init__(self, feature_dim: int): + super().__init__() + self.conv = GCNConv(feature_dim, len(GraphActionType)) + + def forward(self, states: GraphStates) -> torch.Tensor: + if len(states.data.x) == 0: + out = torch.zeros((len(states), len(GraphActionType))) + out[:, GraphActionType.ADD_NODE] = 1 + return out + + x = self.conv(states.data.x, states.data.edge_index) + return torch.mean(x, dim=0) + +class FeaturesNet(nn.Module): + def __init__(self, feature_dim: int): + super().__init__() + self.feature_dim = feature_dim + self.conv = GCNConv(feature_dim, feature_dim) + + def forward(self, states: GraphStates) -> torch.Tensor: + if len(states.data.x) == 0: + return torch.zeros((len(states), self.feature_dim)) + x = self.conv(states.data.x, states.data.edge_index) + x = x.reshape(len(states), -1, x.shape[-1]).mean(dim=0) + return x + +class EdgeIndexNet(nn.Module): + def __init__(self, feature_dim: int): + super().__init__() + self.conv = GCNConv(feature_dim, 8) + + def forward(self, states: GraphStates) -> torch.Tensor: + x = self.conv(states.data.x, states.data.edge_index) + return torch.einsum("nf,mf->nm", x, x) + def test_graph_building(): - node_feature_dim = 8 - env = GraphBuilding(node_feature_dim=node_feature_dim, edge_feature_dim=4) - graph_net = GCNConv(node_feature_dim, 1) + feature_dim = 8 + env = GraphBuilding(node_feature_dim=feature_dim, edge_feature_dim=feature_dim) + + action_type_net = ActionTypeNet(feature_dim) + features_net = FeaturesNet(feature_dim) + edge_index = EdgeIndexNet(feature_dim) + module = nn.ModuleDict({ + "action_type": action_type_net, + "features": features_net, + "edge_index": edge_index + }) + + pf_estimator = 
GraphActionPolicyEstimator(module=module) + + sampler = Sampler(estimator=pf_estimator) + trajectories = sampler.sample_trajectories( + env, + n=5, + save_logprobs=True, + save_estimator_outputs=True, + ) - GraphActionPolicyEstimator(module=graph_net) From 5e64c84b8d3779d291c0a42100c8b14d055e4ce3 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Tue, 26 Nov 2024 11:48:39 +0100 Subject: [PATCH 11/27] use TensorDict --- src/gfn/actions.py | 45 ++++++++------- src/gfn/gym/graph_building.py | 37 ++++++------ src/gfn/modules.py | 48 ++++++++-------- src/gfn/samplers.py | 12 ++-- src/gfn/states.py | 14 +++-- src/gfn/utils/distributions.py | 7 ++- testing/test_environments.py | 22 ++++---- testing/test_samplers_and_trajectories.py | 69 ++++++++++------------- 8 files changed, 130 insertions(+), 124 deletions(-) diff --git a/src/gfn/actions.py b/src/gfn/actions.py index 03a1bfdc..49ac974c 100644 --- a/src/gfn/actions.py +++ b/src/gfn/actions.py @@ -1,7 +1,7 @@ from __future__ import annotations # This allows to use the class name in type hints -from abc import ABC import enum +from abc import ABC from math import prod from typing import ClassVar, Optional, Sequence @@ -178,13 +178,17 @@ class GraphActionType(enum.IntEnum): class GraphActions: - nodes_features_dim: ClassVar[int] edge_features_dim: ClassVar[int] - def __init__(self, action_type: torch.Tensor, features: Optional[torch.Tensor] = None, edge_index: Optional[torch.Tensor] = None): + def __init__( + self, + action_type: torch.Tensor, + features: Optional[torch.Tensor] = None, + edge_index: Optional[torch.Tensor] = None, + ): """Initializes a GraphAction object. - + Args: action: a GraphActionType indicating the type of action. features: a tensor of shape (batch_shape, feature_shape) representing the features of the nodes or of the edges, depending on the action type. @@ -210,7 +214,7 @@ def __init__(self, action_type: torch.Tensor, features: Optional[torch.Tensor] = assert features_dim == self.edge_features_dim assert edge_index is not None assert edge_index.shape == (2, batch_dim) - + self.features = features self.edge_index = edge_index @@ -236,16 +240,14 @@ def __getitem__(self, index: int | Sequence[int] | Sequence[bool]) -> GraphActio edge_index = self.edge_index[index] if self.edge_index is not None else None return GraphActions(self.action_type, features, edge_index) - def __setitem__(self, index: int | Sequence[int] | Sequence[bool], action: GraphActions) -> None: + def __setitem__( + self, index: int | Sequence[int] | Sequence[bool], action: GraphActions + ) -> None: """Set particular actions of the batch.""" - assert self.action_type == action.action_type - if self.action_type != GraphActionType.EXIT: - assert self.features is not None - self.features[index] = action.features - if self.action_type == GraphActionType.ADD_EDGE: - assert self.edge_index is not None - self.edge_index[index] = action.edge_index - + self.action_type[index] = action.action_type + self.features[index] = action.features + self.edge_index[index] = action.edge_index + def compare(self, other: GraphActions) -> torch.Tensor: """Compares the actions to another GraphAction object. 
@@ -265,15 +267,20 @@ def compare(self, other: GraphActions) -> torch.Tensor: @property def is_exit(self) -> torch.Tensor: """Returns a boolean tensor of shape `batch_shape` indicating whether the actions are exit actions.""" - return torch.full((1,), self.action_type == GraphActionType.EXIT, dtype=torch.bool, device=self.device) + return torch.full( + (1,), + self.action_type == GraphActionType.EXIT, + dtype=torch.bool, + device=self.device, + ) @classmethod - def make_dummy_actions(cls, batch_shape: tuple[int]) -> GraphActions: # TODO: remove make_dummy_actions + def make_dummy_actions( + cls, batch_shape: tuple[int] + ) -> GraphActions: # TODO: remove make_dummy_actions """Creates an Actions object of dummy actions with the given batch shape.""" return GraphActions( action_type=torch.full(batch_shape, fill_value=GraphActionType.EXIT), features=None, - edge_index=None + edge_index=None, ) - - diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index e33fafea..f8a4f624 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -24,7 +24,7 @@ def __init__( edge_index=torch.zeros((2, 0), dtype=torch.long), ).to(device_str) sf = Data( - x=torch.ones((1, node_feature_dim), dtype=torch.float32) * float('inf'), + x=torch.ones((1, node_feature_dim), dtype=torch.float32) * float("inf"), ).to(device_str) if state_evaluator is None: @@ -66,11 +66,12 @@ def step(self, states: GraphStates, actions: GraphActions) -> GraphStates: graphs.edge_index = actions.edge_index else: graphs.edge_attr = torch.cat([graphs.edge_attr, actions.features]) - graphs.edge_index = torch.cat([graphs.edge_index, actions.edge_index], dim=1) + graphs.edge_index = torch.cat( + [graphs.edge_index, actions.edge_index], dim=1 + ) return self.States(graphs) - def backward_step(self, states: GraphStates, actions: GraphActions) -> GraphStates: """Backward step function for the GraphBuilding environment. @@ -88,13 +89,14 @@ def backward_step(self, states: GraphStates, actions: GraphActions) -> GraphStat if actions.action_type == GraphActionType.ADD_NODE: assert graphs.x is not None is_equal = torch.any( - torch.all(graphs.x[:, None] == actions.features, dim=-1), - dim=-1 + torch.all(graphs.x[:, None] == actions.features, dim=-1), dim=-1 ) graphs.x = graphs.x[~is_equal] elif actions.action_type == GraphActionType.ADD_EDGE: assert actions.edge_index is not None - is_equal = torch.all(graphs.edge_index[:, None] == actions.edge_index[:, :, None], dim=0) + is_equal = torch.all( + graphs.edge_index[:, None] == actions.edge_index[:, :, None], dim=0 + ) is_equal = torch.any(is_equal, dim=0) graphs.edge_attr = graphs.edge_attr[~is_equal] graphs.edge_index = graphs.edge_index[:, ~is_equal] @@ -106,13 +108,13 @@ def is_action_valid( ) -> bool: if actions.action_type == GraphActionType.EXIT: return True # TODO: what are the conditions for exit action? - + if actions.action_type == GraphActionType.ADD_NODE: if actions.edge_index is not None: return False if states.data.x is None: return not backward - + equal_nodes_per_batch = torch.all( states.data.x == actions.features[:, None], dim=-1 ).reshape(states.data.batch_size, -1) @@ -121,7 +123,7 @@ def is_action_valid( if backward: # TODO: check if no edge are connected? 
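Editor's note: on the TODO just above — a hedged sketch of the missing check, assuming a node may only be removed backward once no edge references it (`candidate_nodes` is a hypothetical tensor of the flat indices of the nodes being removed):

    # Hypothetical sketch -- not part of the patch.
    node_in_use = torch.isin(candidate_nodes, states.data.edge_index.flatten())
    if torch.any(node_in_use):
        return False  # a node slated for removal still has an incident edge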
return torch.all(equal_nodes_per_batch == 1) return torch.all(equal_nodes_per_batch == 0) - + if actions.action_type == GraphActionType.ADD_EDGE: assert actions.edge_index is not None if torch.any(actions.edge_index[0] == actions.edge_index[1]): @@ -130,9 +132,9 @@ def is_action_valid( return False if torch.any(actions.edge_index > states.data.num_nodes): return False - + batch_dim = actions.features.shape[0] - batch_idx = actions.edge_index % batch_dim + batch_idx = actions.edge_index % batch_dim if torch.any(batch_idx != torch.arange(batch_dim)): return False if states.data.edge_attr is None: @@ -142,15 +144,18 @@ def is_action_valid( states.data.edge_attr == actions.features[:, None], dim=-1 ).reshape(states.data.batch_size, -1) equal_edges_per_batch_attr = torch.sum(equal_edges_per_batch_attr, dim=-1) - + equal_edges_per_batch_index = torch.all( states.data.edge_index[:, None] == actions.edge_index[:, :, None], dim=0 ).reshape(states.data.batch_size, -1) equal_edges_per_batch_index = torch.sum(equal_edges_per_batch_index, dim=-1) if backward: - return torch.all(equal_edges_per_batch_attr == 1) and torch.all(equal_edges_per_batch_index == 1) - return torch.all(equal_edges_per_batch_attr == 0) and torch.all(equal_edges_per_batch_index == 0) - + return torch.all(equal_edges_per_batch_attr == 1) and torch.all( + equal_edges_per_batch_index == 1 + ) + return torch.all(equal_edges_per_batch_attr == 0) and torch.all( + equal_edges_per_batch_index == 0 + ) def reward(self, final_states: GraphStates) -> torch.Tensor: """The environment's reward given a state. @@ -186,4 +191,4 @@ def __init__(self, num_features): def __call__(self, batch: Batch) -> torch.Tensor: out = self.net(batch.x, batch.edge_index) out = out.reshape(batch.batch_size, -1) - return out.mean(-1) \ No newline at end of file + return out.mean(-1) diff --git a/src/gfn/modules.py b/src/gfn/modules.py index adbcc1c3..90087e5a 100644 --- a/src/gfn/modules.py +++ b/src/gfn/modules.py @@ -1,8 +1,9 @@ -from abc import ABC, abstractmethod +from abc import ABC from typing import Any, Dict import torch import torch.nn as nn +from tensordict import TensorDict from torch.distributions import Categorical, Distribution, Normal from gfn.actions import GraphActionType @@ -277,7 +278,7 @@ def to_probability_distribution( # LogEdgeFlows are greedy, as are most P_B. else: - return UnsqueezedCategorical(logits=logits) + return UnsqueezedCategorical(logits=logits) class ConditionalDiscretePolicyEstimator(DiscretePolicyEstimator): @@ -457,7 +458,7 @@ class GraphActionPolicyEstimator(GFNModule): def __init__( self, - module: nn.ModuleDict, + module: nn.Module, # preprocessor: Preprocessor | None = None, is_backward: bool = False, ): @@ -466,14 +467,12 @@ def __init__( Args: is_backward: if False, then this is a forward policy, else backward policy. """ - #super().__init__(module, preprocessor, is_backward) + # super().__init__(module, preprocessor, is_backward) nn.Module.__init__(self) - assert isinstance(module, nn.ModuleDict) - assert module.keys() == {"action_type", "edge_index", "features"} self.module = module self.is_backward = is_backward - - def forward(self, states: GraphStates) -> Dict[str, torch.Tensor]: + + def forward(self, states: GraphStates) -> TensorDict: """Forward pass of the module. Args: @@ -481,16 +480,7 @@ def forward(self, states: GraphStates) -> Dict[str, torch.Tensor]: Returns the . 
""" - action_type_logits = self.module["action_type"](states) - edge_index_logits = self.module["edge_index"](states) - features = self.module["features"](states) - - assert action_type_logits.shape[-1] == len(GraphActionType) - return { - "action_type": action_type_logits, - "edge_index": edge_index_logits, - "features": features - } + return self.module(states) def to_probability_distribution( self, @@ -510,22 +500,32 @@ def to_probability_distribution( if set to 1.0 (default), in which case it's on policy. epsilon: with probability epsilon, a random action is chosen. Does nothing if set to 0.0 (default), in which case it's on policy.""" - + dists = {} action_type_logits = module_output["action_type"] - action_type_masks = states.backward_masks if self.is_backward else states.forward_masks + action_type_masks = ( + states.backward_masks if self.is_backward else states.forward_masks + ) action_type_logits[~action_type_masks] = -float("inf") action_type_probs = torch.softmax(action_type_logits / temperature, dim=-1) - uniform_dist_probs = action_type_masks.float() / action_type_masks.sum(dim=-1, keepdim=True) - action_type_probs = (1 - epsilon) * action_type_probs + epsilon * uniform_dist_probs + uniform_dist_probs = action_type_masks.float() / action_type_masks.sum( + dim=-1, keepdim=True + ) + action_type_probs = ( + 1 - epsilon + ) * action_type_probs + epsilon * uniform_dist_probs dists["action_type"] = Categorical(probs=action_type_probs) edge_index_logits = module_output["edge_index"] if edge_index_logits.shape[-1] != 0: edge_index_probs = torch.softmax(edge_index_logits / temperature, dim=-1) - uniform_dist_probs = torch.ones_like(edge_index_probs) / edge_index_probs.shape[-1] - edge_index_probs = (1 - epsilon) * edge_index_probs + epsilon * uniform_dist_probs + uniform_dist_probs = ( + torch.ones_like(edge_index_probs) / edge_index_probs.shape[-1] + ) + edge_index_probs = ( + 1 - epsilon + ) * edge_index_probs + epsilon * uniform_dist_probs dists["edge_index"] = UnsqueezedCategorical(probs=edge_index_probs) dists["features"] = Normal(module_output["features"], temperature) diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 8a8a112c..4b706a39 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -193,12 +193,12 @@ def sample_trajectories( if estimator_outputs is not None: # Place estimator outputs into a stackable tensor. Note that this # will be replaced with torch.nested.nested_tensor in the future. - estimator_outputs_padded = torch.full( - (n_trajectories,) + estimator_outputs.shape[1:], + estimator_outputs_padded = torch.full_like( + estimator_outputs.expand( + (n_trajectories,) + estimator_outputs.shape[1:] + ), fill_value=-float("inf"), - dtype=torch.float, - device=device, - ) + ).clone() # TODO: inefficient estimator_outputs_padded[~dones] = estimator_outputs all_estimator_outputs.append(estimator_outputs_padded) @@ -206,7 +206,7 @@ def sample_trajectories( if save_logprobs: # When off_policy, actions_log_probs are None. 
log_probs[~dones] = actions_log_probs - + if trajectories_actions is None: trajectories_actions = actions else: diff --git a/src/gfn/states.py b/src/gfn/states.py index d1cbed9b..b6c9bfb7 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -3,7 +3,7 @@ from abc import ABC from copy import deepcopy from math import prod -from typing import Callable, ClassVar, List, Optional, Sequence, Tuple +from typing import Callable, ClassVar, Optional, Sequence, Tuple import numpy as np import torch @@ -499,10 +499,12 @@ def __init__(self, graphs: Batch): self.forward_masks = torch.ones((self.batch_shape, 3), dtype=torch.bool) self.forward_masks[:, GraphActionType.ADD_EDGE] = not_empty self.forward_masks[:, GraphActionType.EXIT] = not_empty - + self.backward_masks = torch.ones((self.batch_shape, 3), dtype=torch.bool) self.backward_masks[:, GraphActionType.ADD_NODE] = not_empty - self.backward_masks[:, GraphActionType.ADD_EDGE] = not_empty and self.data.edge_attr.shape[0] > 0 + self.backward_masks[:, GraphActionType.ADD_EDGE] = ( + not_empty and self.data.edge_attr.shape[0] > 0 + ) self.backward_masks[:, GraphActionType.EXIT] = not_empty @classmethod @@ -655,10 +657,12 @@ def log_rewards(self) -> torch.Tensor: @log_rewards.setter def log_rewards(self, log_rewards: torch.Tensor) -> None: self._log_rewards = log_rewards - + @property def is_sink_state(self) -> torch.Tensor: batch_dim = len(self.data.ptr) - 1 if len(self.data.x) == 0: return torch.zeros(batch_dim, dtype=torch.bool) - return torch.all(self.data.x == self.sf.x, dim=-1).reshape(batch_dim,) + return torch.all(self.data.x == self.sf.x, dim=-1).reshape( + batch_dim, + ) diff --git a/src/gfn/utils/distributions.py b/src/gfn/utils/distributions.py index a44727b3..421a04e9 100644 --- a/src/gfn/utils/distributions.py +++ b/src/gfn/utils/distributions.py @@ -1,6 +1,7 @@ from typing import Dict + import torch -from torch.distributions import Distribution, Categorical +from torch.distributions import Categorical, Distribution class UnsqueezedCategorical(Categorical): @@ -56,10 +57,10 @@ def __init__(self, dists: Dict[str, Distribution]): def sample(self, sample_shape=torch.Size()) -> Dict[str, torch.Tensor]: return {k: v.sample(sample_shape) for k, v in self.dists.items()} - + def log_prob(self, sample: Dict[str, torch.Tensor]) -> torch.Tensor: log_probs = [ v.log_prob(sample[k]).reshape(sample[k].shape[0], -1).sum(dim=-1) for k, v in self.dists.items() ] - return sum(log_probs) \ No newline at end of file + return sum(log_probs) diff --git a/testing/test_environments.py b/testing/test_environments.py index e157f1bb..80198fd2 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -334,7 +334,7 @@ def test_graph_env(): actions = action_cls( GraphActionType.ADD_EDGE, torch.rand((BATCH_SIZE, FEATURE_DIM)), - torch.randint(0, 10, (2, BATCH_SIZE), dtype=torch.long) + torch.randint(0, 10, (2, BATCH_SIZE), dtype=torch.long), ) states = env.step(states, actions) @@ -344,7 +344,7 @@ def test_graph_env(): torch.rand((BATCH_SIZE, FEATURE_DIM)), ) states = env.step(states, actions) - + assert states.data.x.shape == (BATCH_SIZE * NUM_NODES, FEATURE_DIM) with pytest.raises(NonValidActionsError): @@ -360,17 +360,17 @@ def test_graph_env(): actions = action_cls( GraphActionType.ADD_EDGE, torch.rand((BATCH_SIZE, FEATURE_DIM)), - torch.stack([edge_index, edge_index]) + torch.stack([edge_index, edge_index]), ) states = env.step(states, actions) for i in range(NUM_NODES - 1): - node_is = torch.arange(i * BATCH_SIZE, (i + 1) * BATCH_SIZE) + 
node_is = torch.arange(i * BATCH_SIZE, (i + 1) * BATCH_SIZE) node_js = torch.arange((i + 1) * BATCH_SIZE, (i + 2) * BATCH_SIZE) actions = action_cls( GraphActionType.ADD_EDGE, torch.rand((BATCH_SIZE, FEATURE_DIM)), - torch.stack([node_is, node_js]) + torch.stack([node_is, node_js]), ) states = env.step(states, actions) @@ -379,7 +379,7 @@ def test_graph_env(): actions = action_cls( GraphActionType.ADD_EDGE, torch.rand((BATCH_SIZE, FEATURE_DIM)), - edge_index.T + edge_index.T, ) states = env.step(states, actions) @@ -401,15 +401,15 @@ def test_graph_env(): actions = action_cls( GraphActionType.ADD_EDGE, states.data.edge_attr[edge_idx], - states.data.edge_index[:, edge_idx] + states.data.edge_index[:, edge_idx], ) states = env.backward_step(states, actions) - + with pytest.raises(NonValidActionsError): actions = action_cls( GraphActionType.ADD_EDGE, torch.rand((BATCH_SIZE, FEATURE_DIM)), - torch.randint(0, 10, (2, BATCH_SIZE), dtype=torch.long) + torch.randint(0, 10, (2, BATCH_SIZE), dtype=torch.long), ) states = env.backward_step(states, actions) @@ -420,7 +420,7 @@ def test_graph_env(): states.data.x[edge_idx], ) states = env.backward_step(states, actions) - + assert states.data.x.shape == (0, FEATURE_DIM) with pytest.raises(NonValidActionsError): @@ -428,4 +428,4 @@ def test_graph_env(): GraphActionType.ADD_NODE, torch.rand((BATCH_SIZE, FEATURE_DIM)), ) - states = env.backward_step(states, actions) \ No newline at end of file + states = env.backward_step(states, actions) diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py index 598a7437..302f31b3 100644 --- a/testing/test_samplers_and_trajectories.py +++ b/testing/test_samplers_and_trajectories.py @@ -2,9 +2,9 @@ import pytest import torch +from tensordict import TensorDict from torch import nn from torch_geometric.nn import GCNConv -from torch_geometric.data import Batch from gfn.actions import GraphActionType from gfn.containers import Trajectories @@ -17,6 +17,7 @@ from gfn.states import GraphStates from gfn.utils.modules import MLP + def trajectory_sampling_with_return( env_name: str, preprocessor_name: Literal["KHot", "OneHot", "Identity", "Enum"], @@ -225,55 +226,44 @@ def test_replay_buffer( # ------ GRAPH TESTS ------ -class ActionTypeNet(nn.Module): - def __init__(self, feature_dim: int): - super().__init__() - self.conv = GCNConv(feature_dim, len(GraphActionType)) - - def forward(self, states: GraphStates) -> torch.Tensor: - if len(states.data.x) == 0: - out = torch.zeros((len(states), len(GraphActionType))) - out[:, GraphActionType.ADD_NODE] = 1 - return out - - x = self.conv(states.data.x, states.data.edge_index) - return torch.mean(x, dim=0) - -class FeaturesNet(nn.Module): +class GraphActionNet(nn.Module): def __init__(self, feature_dim: int): super().__init__() self.feature_dim = feature_dim - self.conv = GCNConv(feature_dim, feature_dim) + self.action_type_conv = GCNConv(feature_dim, len(GraphActionType)) + self.features_conv = GCNConv(feature_dim, feature_dim) + self.edge_index_conv = GCNConv(feature_dim, 8) - def forward(self, states: GraphStates) -> torch.Tensor: + def forward(self, states: GraphStates) -> TensorDict: if len(states.data.x) == 0: - return torch.zeros((len(states), self.feature_dim)) - x = self.conv(states.data.x, states.data.edge_index) - x = x.reshape(len(states), -1, x.shape[-1]).mean(dim=0) - return x - -class EdgeIndexNet(nn.Module): - def __init__(self, feature_dim: int): - super().__init__() - self.conv = GCNConv(feature_dim, 8) + action_type = 
torch.zeros((len(states), len(GraphActionType))) + action_type[:, GraphActionType.ADD_NODE] = 1 + features = torch.zeros((len(states), self.feature_dim)) + else: + action_type = self.action_type_conv(states.data.x, states.data.edge_index) + action_type = torch.mean(action_type, dim=0) + features = self.features_conv(states.data.x, states.data.edge_index) + features = features.reshape(len(states), -1, features.shape[-1]).mean(dim=0) + + edge_index = self.edge_index_conv(states.data.x, states.data.edge_index) + edge_index = edge_index.reshape(states.batch_shape, -1, 8) + edge_index = torch.einsum("bnf,bmf->bnm", edge_index, edge_index) + + return TensorDict( + { + "action_type": action_type, + "features": features, + "edge_index": edge_index, + }, + batch_size=states.batch_shape, + ) - def forward(self, states: GraphStates) -> torch.Tensor: - x = self.conv(states.data.x, states.data.edge_index) - return torch.einsum("nf,mf->nm", x, x) def test_graph_building(): feature_dim = 8 env = GraphBuilding(node_feature_dim=feature_dim, edge_feature_dim=feature_dim) - action_type_net = ActionTypeNet(feature_dim) - features_net = FeaturesNet(feature_dim) - edge_index = EdgeIndexNet(feature_dim) - module = nn.ModuleDict({ - "action_type": action_type_net, - "features": features_net, - "edge_index": edge_index - }) - + module = GraphActionNet(feature_dim) pf_estimator = GraphActionPolicyEstimator(module=module) sampler = Sampler(estimator=pf_estimator) @@ -283,4 +273,3 @@ def test_graph_building(): save_logprobs=True, save_estimator_outputs=True, ) - From 81f8b7142ebfa2495aaa7e30e034ca39217314fd Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Thu, 28 Nov 2024 19:40:13 +0100 Subject: [PATCH 12/27] solve some errors --- src/gfn/actions.py | 69 +++++++----------- src/gfn/env.py | 24 ++---- src/gfn/gym/graph_building.py | 89 ++++++++++++----------- src/gfn/states.py | 21 ++++-- src/gfn/utils/distributions.py | 2 +- testing/test_environments.py | 2 +- testing/test_samplers_and_trajectories.py | 5 +- 7 files changed, 99 insertions(+), 113 deletions(-) diff --git a/src/gfn/actions.py b/src/gfn/actions.py index 49ac974c..89119fa1 100644 --- a/src/gfn/actions.py +++ b/src/gfn/actions.py @@ -177,9 +177,8 @@ class GraphActionType(enum.IntEnum): EXIT = 2 -class GraphActions: - nodes_features_dim: ClassVar[int] - edge_features_dim: ClassVar[int] +class GraphActions(Actions): + features_dim: ClassVar[int] def __init__( self, @@ -197,26 +196,20 @@ def __init__( This must defined if and only if the action type is GraphActionType.AddEdge. 
""" self.batch_shape = action_type.shape - assert torch.all(action_type == action_type[0]) - self.action_type = action_type[0] - if self.action_type == GraphActionType.EXIT: - assert features is None - assert edge_index is None - self.features = None - self.edge_index = None - else: - assert features is not None - batch_dim, features_dim = features.shape - assert (batch_dim,) == self.batch_shape - if self.action_type == GraphActionType.ADD_NODE: - assert features_dim == self.nodes_features_dim - elif self.action_type == GraphActionType.ADD_EDGE: - assert features_dim == self.edge_features_dim - assert edge_index is not None - assert edge_index.shape == (2, batch_dim) - - self.features = features - self.edge_index = edge_index + self.action_type = action_type + + if features is None: + assert torch.all(action_type == GraphActionType.EXIT) + features = torch.zeros((*self.batch_shape, self.features_dim)) + if edge_index is None: + assert torch.all(action_type != GraphActionType.ADD_EDGE) + edge_index = torch.zeros((2, *self.batch_shape)) + + batch_dim, _ = features.shape + assert (batch_dim,) == self.batch_shape + assert edge_index.shape == (2, batch_dim) + self.features = features + self.edge_index = edge_index def __repr__(self): return f"""GraphAction object of type {self.action_type} and features of shape {self.features.shape}.""" @@ -228,17 +221,14 @@ def device(self) -> torch.device: def __len__(self) -> int: """Returns the number of actions in the batch.""" - if self.action_type == GraphActionType.EXIT: - raise ValueError("Cannot get the length of exit actions.") - else: - assert self.features is not None - return self.features.shape[0] + return prod(self.batch_shape) def __getitem__(self, index: int | Sequence[int] | Sequence[bool]) -> GraphActions: """Get particular actions of the batch.""" - features = self.features[index] if self.features is not None else None - edge_index = self.edge_index[index] if self.edge_index is not None else None - return GraphActions(self.action_type, features, edge_index) + action_type = self.action_type[index] + features = self.features[index] + edge_index = self.edge_index[:, index] + return GraphActions(action_type, features, edge_index) def __setitem__( self, index: int | Sequence[int] | Sequence[bool], action: GraphActions @@ -246,7 +236,7 @@ def __setitem__( """Set particular actions of the batch.""" self.action_type[index] = action.action_type self.features[index] = action.features - self.edge_index[index] = action.edge_index + self.edge_index[:, index] = action.edge_index def compare(self, other: GraphActions) -> torch.Tensor: """Compares the actions to another GraphAction object. 
@@ -267,20 +257,15 @@ def compare(self, other: GraphActions) -> torch.Tensor: @property def is_exit(self) -> torch.Tensor: """Returns a boolean tensor of shape `batch_shape` indicating whether the actions are exit actions.""" - return torch.full( - (1,), - self.action_type == GraphActionType.EXIT, - dtype=torch.bool, - device=self.device, - ) + return self.action_type == GraphActionType.EXIT @classmethod def make_dummy_actions( cls, batch_shape: tuple[int] - ) -> GraphActions: # TODO: remove make_dummy_actions + ) -> GraphActions: """Creates an Actions object of dummy actions with the given batch shape.""" - return GraphActions( + return cls( action_type=torch.full(batch_shape, fill_value=GraphActionType.EXIT), - features=None, - edge_index=None, + #features=torch.zeros((*batch_shape, 0, cls.nodes_features_dim)), + #edge_index=torch.zeros((2, *batch_shape, 0)), ) diff --git a/src/gfn/env.py b/src/gfn/env.py index 8db703b3..d47e4cfa 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -264,10 +264,12 @@ def _step( not_done_actions = actions[~new_sink_states_idx] new_not_done_states_tensor = self.step(not_done_states, not_done_actions) - if not isinstance(new_not_done_states_tensor, torch.Tensor): - raise Exception( - "User implemented env.step function *must* return a torch.Tensor!" - ) + + # TODO: uncomment (change Data to TensorDict) + # if not isinstance(new_not_done_states_tensor, torch.Tensor): + # raise Exception( + # "User implemented env.step function *must* return a torch.Tensor!" + # ) new_states.tensor[~new_sink_states_idx] = new_not_done_states_tensor @@ -568,8 +570,6 @@ class GraphEnv(Env): def __init__( self, s0: Data, - # node_feature_dim: int, - # edge_feature_dim: int, sf: Optional[Data] = None, device_str: Optional[str] = None, preprocessor: Optional[Preprocessor] = None, @@ -578,8 +578,6 @@ def __init__( Args: s0: The initial graph state. - node_feature_dim: The dimension of the node features. - edge_feature_dim: The dimension of the edge features. action_shape: Tuple representing the shape of the actions. dummy_action: Tensor of shape "action_shape" representing a dummy action. exit_action: Tensor of shape "action_shape" representing the exit action. @@ -591,10 +589,7 @@ def __init__( the IdentityPreprocessor is used. 
""" self.s0 = s0.to(device_str) - - self.node_feature_dim = s0.x.shape[1] - self.edge_feature_dim = s0.edge_attr.shape[1] - + self.features_dim = s0.x.shape[1] self.sf = sf self.States = self.make_states_class() @@ -609,8 +604,6 @@ def make_states_class(self) -> type[GraphStates]: class GraphEnvStates(GraphStates): s0 = env.s0 sf = env.sf - # node_feature_dim = env.node_feature_dim - # edge_feature_dim = env.edge_feature_dim make_random_states_graph = env.make_random_states_tensor return GraphEnvStates @@ -625,8 +618,7 @@ def make_actions_class(self) -> type[GraphActions]: env = self class DefaultGraphAction(GraphActions): - nodes_features_dim = env.node_feature_dim - edge_features_dim = env.edge_feature_dim + features_dim = env.features_dim return DefaultGraphAction diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index f8a4f624..67c6cbae 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -13,22 +13,21 @@ class GraphBuilding(GraphEnv): def __init__( self, - node_feature_dim: int, - edge_feature_dim: int, + feature_dim: int, state_evaluator: Callable[[Batch], torch.Tensor] | None = None, device_str: Literal["cpu", "cuda"] = "cpu", ): s0 = Data( - x=torch.zeros((0, node_feature_dim), dtype=torch.float32), - edge_attr=torch.zeros((0, edge_feature_dim), dtype=torch.float32), + x=torch.zeros((0, feature_dim), dtype=torch.float32), + edge_attr=torch.zeros((0, feature_dim), dtype=torch.float32), edge_index=torch.zeros((2, 0), dtype=torch.long), ).to(device_str) sf = Data( - x=torch.ones((1, node_feature_dim), dtype=torch.float32) * float("inf"), + x=torch.ones((1, feature_dim), dtype=torch.float32) * float("inf"), ).to(device_str) if state_evaluator is None: - state_evaluator = GCNConvEvaluator(node_feature_dim) + state_evaluator = GCNConvEvaluator(feature_dim) self.state_evaluator = state_evaluator super().__init__( @@ -37,7 +36,7 @@ def __init__( device_str=device_str, ) - def step(self, states: GraphStates, actions: GraphActions) -> GraphStates: + def step(self, states: GraphStates, actions: GraphActions) -> Data: """Step function for the GraphBuilding environment. Args: @@ -50,14 +49,17 @@ def step(self, states: GraphStates, actions: GraphActions) -> GraphStates: raise NonValidActionsError("Invalid action.") graphs: Batch = deepcopy(states.data) - if actions.action_type == GraphActionType.ADD_NODE: + action_type = actions.action_type[0] + assert torch.all(actions.action_type == action_type) + + if action_type == GraphActionType.ADD_NODE: assert len(graphs) == len(actions) if graphs.x is None: graphs.x = actions.features else: graphs.x = torch.cat([graphs.x, actions.features]) - if actions.action_type == GraphActionType.ADD_EDGE: + if action_type == GraphActionType.ADD_EDGE: assert len(graphs) == len(actions) assert actions.edge_index is not None if graphs.edge_attr is None: @@ -70,7 +72,7 @@ def step(self, states: GraphStates, actions: GraphActions) -> GraphStates: [graphs.edge_index, actions.edge_index], dim=1 ) - return self.States(graphs) + return graphs def backward_step(self, states: GraphStates, actions: GraphActions) -> GraphStates: """Backward step function for the GraphBuilding environment. @@ -106,56 +108,57 @@ def backward_step(self, states: GraphStates, actions: GraphActions) -> GraphStat def is_action_valid( self, states: GraphStates, actions: GraphActions, backward: bool = False ) -> bool: - if actions.action_type == GraphActionType.EXIT: - return True # TODO: what are the conditions for exit action? 
- - if actions.action_type == GraphActionType.ADD_NODE: - if actions.edge_index is not None: - return False - if states.data.x is None: - return not backward - + add_node_mask = actions.action_type == GraphActionType.ADD_NODE + if not torch.any(add_node_mask): + add_node_out = True + else: equal_nodes_per_batch = torch.all( - states.data.x == actions.features[:, None], dim=-1 + states[add_node_mask].data.x == actions[add_node_mask].features[:, None], dim=-1 ).reshape(states.data.batch_size, -1) equal_nodes_per_batch = torch.sum(equal_nodes_per_batch, dim=-1) - if backward: # TODO: check if no edge are connected? - return torch.all(equal_nodes_per_batch == 1) - return torch.all(equal_nodes_per_batch == 0) - - if actions.action_type == GraphActionType.ADD_EDGE: - assert actions.edge_index is not None - if torch.any(actions.edge_index[0] == actions.edge_index[1]): + add_node_out = torch.all(equal_nodes_per_batch == 1) + else: + add_node_out = torch.all(equal_nodes_per_batch == 0) + + add_edge_mask = actions.action_type == GraphActionType.ADD_EDGE + if not torch.any(add_edge_mask): + add_edge_out = True + else: + add_edge_states = states[add_edge_mask] + add_edge_actions = actions[add_edge_mask] + + if torch.any(add_edge_actions.edge_index[0] == add_edge_actions.edge_index[1]): return False - if states.data.num_nodes is None or states.data.num_nodes == 0: + if add_edge_states.data.num_nodes == 0: return False - if torch.any(actions.edge_index > states.data.num_nodes): + if torch.any(add_edge_actions.edge_index > add_edge_states.data.num_nodes): return False - batch_dim = actions.features.shape[0] - batch_idx = actions.edge_index % batch_dim + batch_dim = add_edge_actions.features.shape[0] + batch_idx = add_edge_actions.edge_index % batch_dim if torch.any(batch_idx != torch.arange(batch_dim)): return False - if states.data.edge_attr is None: - return True equal_edges_per_batch_attr = torch.all( - states.data.edge_attr == actions.features[:, None], dim=-1 - ).reshape(states.data.batch_size, -1) + add_edge_states.data.edge_attr == add_edge_actions.features[:, None], dim=-1 + ).reshape(add_edge_states.data.batch_size, -1) equal_edges_per_batch_attr = torch.sum(equal_edges_per_batch_attr, dim=-1) - equal_edges_per_batch_index = torch.all( - states.data.edge_index[:, None] == actions.edge_index[:, :, None], dim=0 - ).reshape(states.data.batch_size, -1) + add_edge_states.data.edge_index[:, None] == add_edge_actions.edge_index[:, :, None], dim=0 + ).reshape(add_edge_states.data.batch_size, -1) equal_edges_per_batch_index = torch.sum(equal_edges_per_batch_index, dim=-1) + if backward: - return torch.all(equal_edges_per_batch_attr == 1) and torch.all( + add_edge_out = torch.all(equal_edges_per_batch_attr == 1) and torch.all( equal_edges_per_batch_index == 1 ) - return torch.all(equal_edges_per_batch_attr == 0) and torch.all( - equal_edges_per_batch_index == 0 - ) + else: + add_edge_out = torch.all(equal_edges_per_batch_attr == 0) and torch.all( + equal_edges_per_batch_index == 0 + ) + + return bool(add_node_out) and bool(add_edge_out) def reward(self, final_states: GraphStates) -> torch.Tensor: """The environment's reward given a state. 
@@ -180,7 +183,7 @@ def true_dist_pmf(self) -> torch.Tensor: raise NotImplementedError def make_random_states_tensor(self, batch_shape: Tuple) -> GraphStates: - """Generates random states tensor of shape (*batch_shape, num_nodes, node_feature_dim).""" + """Generates random states tensor of shape (*batch_shape, feature_dim).""" return self.States.from_batch_shape(batch_shape) diff --git a/src/gfn/states.py b/src/gfn/states.py index b6c9bfb7..8e5ec735 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -491,16 +491,16 @@ class GraphStates(ABC): def __init__(self, graphs: Batch): self.data: Batch = graphs - self.batch_shape: int = len(self.data) + self.batch_shape: tuple = (len(self.data),) self._log_rewards: float = None # TODO logic repeated from env.is_valid_action not_empty = self.data.x is not None and self.data.x.shape[0] > 0 - self.forward_masks = torch.ones((self.batch_shape, 3), dtype=torch.bool) + self.forward_masks = torch.ones((*self.batch_shape, 3), dtype=torch.bool) self.forward_masks[:, GraphActionType.ADD_EDGE] = not_empty self.forward_masks[:, GraphActionType.EXIT] = not_empty - self.backward_masks = torch.ones((self.batch_shape, 3), dtype=torch.bool) + self.backward_masks = torch.ones((*self.batch_shape, 3), dtype=torch.bool) self.backward_masks[:, GraphActionType.ADD_NODE] = not_empty self.backward_masks[:, GraphActionType.ADD_EDGE] = ( not_empty and self.data.edge_attr.shape[0] > 0 @@ -580,10 +580,13 @@ def __getitem__( self, index: int | Sequence[int] | slice | torch.Tensor ) -> GraphStates: idxs = np.arange(len(self.data))[index] - data = [] - for i in idxs: - data.append(self.data.get_example(i)) - + data = [self.data.get_example(i) for i in idxs] + if len(data) == 0: + data.append(Data( + x=torch.zeros((0, self.data.x.shape[1]), dtype=torch.float32), + edge_attr=torch.zeros((0, self.data.edge_attr.shape[1]), dtype=torch.float32), + edge_index=torch.zeros((2, 0), dtype=torch.long), + )) out = GraphStates(Batch.from_data_list(data)) if self._log_rewards is not None: @@ -666,3 +669,7 @@ def is_sink_state(self) -> torch.Tensor: return torch.all(self.data.x == self.sf.x, dim=-1).reshape( batch_dim, ) + + @property + def tensor(self) -> Batch: + return self.data diff --git a/src/gfn/utils/distributions.py b/src/gfn/utils/distributions.py index 421a04e9..1fa6aa89 100644 --- a/src/gfn/utils/distributions.py +++ b/src/gfn/utils/distributions.py @@ -43,7 +43,7 @@ def log_prob(self, sample: torch.Tensor) -> torch.Tensor: return super().log_prob(sample.squeeze(-1)) -class ComposedDistribution(Distribution): +class ComposedDistribution(Distribution): # TODO: CompositeDistribution in TensorDict """A mixture distribution.""" def __init__(self, dists: Dict[str, Distribution]): diff --git a/testing/test_environments.py b/testing/test_environments.py index 80198fd2..948742bf 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -325,7 +325,7 @@ def test_graph_env(): BATCH_SIZE = 3 NUM_NODES = 5 - env = GraphBuilding(node_feature_dim=FEATURE_DIM, edge_feature_dim=FEATURE_DIM) + env = GraphBuilding(feature_dim=FEATURE_DIM) states = env.reset(batch_shape=BATCH_SIZE) assert states.batch_shape == BATCH_SIZE action_cls = env.make_actions_class() diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py index 302f31b3..d8a1c260 100644 --- a/testing/test_samplers_and_trajectories.py +++ b/testing/test_samplers_and_trajectories.py @@ -225,7 +225,6 @@ def test_replay_buffer( # ------ GRAPH TESTS ------ - class 
GraphActionNet(nn.Module): def __init__(self, feature_dim: int): super().__init__() @@ -246,7 +245,7 @@ def forward(self, states: GraphStates) -> TensorDict: features = features.reshape(len(states), -1, features.shape[-1]).mean(dim=0) edge_index = self.edge_index_conv(states.data.x, states.data.edge_index) - edge_index = edge_index.reshape(states.batch_shape, -1, 8) + edge_index = edge_index.reshape(*states.batch_shape, -1, 8) edge_index = torch.einsum("bnf,bmf->bnm", edge_index, edge_index) return TensorDict( @@ -261,7 +260,7 @@ def forward(self, states: GraphStates) -> TensorDict: def test_graph_building(): feature_dim = 8 - env = GraphBuilding(node_feature_dim=feature_dim, edge_feature_dim=feature_dim) + env = GraphBuilding(feature_dim=feature_dim) module = GraphActionNet(feature_dim) pf_estimator = GraphActionPolicyEstimator(module=module) From 34781efbd1cd420676682e62bf8d765996218fe3 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Thu, 28 Nov 2024 22:25:15 +0100 Subject: [PATCH 13/27] use tensordict in actions --- src/gfn/actions.py | 57 +++++++++++++++++++++++++++------------------- 1 file changed, 34 insertions(+), 23 deletions(-) diff --git a/src/gfn/actions.py b/src/gfn/actions.py index 89119fa1..400a9480 100644 --- a/src/gfn/actions.py +++ b/src/gfn/actions.py @@ -6,6 +6,7 @@ from typing import ClassVar, Optional, Sequence import torch +from tensordict import TensorDict class Actions(ABC): @@ -196,8 +197,6 @@ def __init__( This must defined if and only if the action type is GraphActionType.AddEdge. """ self.batch_shape = action_type.shape - self.action_type = action_type - if features is None: assert torch.all(action_type == GraphActionType.EXIT) features = torch.zeros((*self.batch_shape, self.features_dim)) @@ -205,19 +204,19 @@ def __init__( assert torch.all(action_type != GraphActionType.ADD_EDGE) edge_index = torch.zeros((2, *self.batch_shape)) - batch_dim, _ = features.shape - assert (batch_dim,) == self.batch_shape - assert edge_index.shape == (2, batch_dim) - self.features = features - self.edge_index = edge_index + self.tensor = TensorDict({ + "action_type": action_type, + "features": features, + "edge_index": edge_index.T, + }, batch_size=self.batch_shape) def __repr__(self): - return f"""GraphAction object of type {self.action_type} and features of shape {self.features.shape}.""" + return f"""GraphAction object with {self.batch_shape} actions.""" @property def device(self) -> torch.device: """Returns the device of the features tensor.""" - return self.features.device + return self.tensor.device def __len__(self) -> int: """Returns the number of actions in the batch.""" @@ -225,18 +224,18 @@ def __len__(self) -> int: def __getitem__(self, index: int | Sequence[int] | Sequence[bool]) -> GraphActions: """Get particular actions of the batch.""" - action_type = self.action_type[index] - features = self.features[index] - edge_index = self.edge_index[:, index] - return GraphActions(action_type, features, edge_index) + tensor = self.tensor[index] + return GraphActions( + tensor["action_type"], + tensor["features"], + tensor["edge_index"].T + ) def __setitem__( self, index: int | Sequence[int] | Sequence[bool], action: GraphActions ) -> None: """Set particular actions of the batch.""" - self.action_type[index] = action.action_type - self.features[index] = action.features - self.edge_index[:, index] = action.edge_index + self.tensor[index] = action.tensor def compare(self, other: GraphActions) -> torch.Tensor: """Compares the actions to another GraphAction object. 
@@ -246,19 +245,38 @@ def compare(self, other: GraphActions) -> torch.Tensor:
         Returns: boolean tensor of shape batch_shape indicating whether the actions are equal.
         """
-        if self.action_type != other.action_type:
-            len_ = self.features.shape[0] if self.features is not None else 1
-            return torch.zeros(len_, dtype=torch.bool, device=self.device)
-        out = torch.all(self.features == other.features, dim=-1)
-        if self.edge_index is not None:
-            out &= torch.all(self.edge_index == other.edge_index, dim=-1)
-        return out
+        compare = self.tensor == other.tensor
+        same_type = compare["action_type"]
+        same_features = compare["features"].all(dim=-1)
+        same_edge_index = compare["edge_index"].all(dim=-1)
+        # Exit actions only need matching action types; all other actions must
+        # also match on features, and ADD_EDGE actions on edge_index as well.
+        return (
+            same_type
+            & ((self.action_type == GraphActionType.EXIT) | same_features)
+            & ((self.action_type != GraphActionType.ADD_EDGE) | same_edge_index)
+        )
 
     @property
     def is_exit(self) -> torch.Tensor:
         """Returns a boolean tensor of shape `batch_shape` indicating whether the actions are exit actions."""
         return self.action_type == GraphActionType.EXIT
 
+    @property
+    def action_type(self) -> torch.Tensor:
+        """Returns the action type tensor."""
+        return self.tensor["action_type"]
+
+    @property
+    def features(self) -> torch.Tensor:
+        """Returns the features tensor."""
+        return self.tensor["features"]
+
+    @property
+    def edge_index(self) -> torch.Tensor:
+        """Returns the edge index tensor."""
+        return self.tensor["edge_index"]
+
     @classmethod
     def make_dummy_actions(
         cls, batch_shape: tuple[int]

From 3e584f2168eec0a6688b0ccbba7f1b00bd68549f Mon Sep 17 00:00:00 2001
From: Omar Younis
Date: Mon, 2 Dec 2024 19:50:59 +0100
Subject: [PATCH 14/27] handle sf

---
 src/gfn/gym/graph_building.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py
index 67c6cbae..f3448068 100644
--- a/src/gfn/gym/graph_building.py
+++ b/src/gfn/gym/graph_building.py
@@ -52,6 +52,9 @@ def step(self, states: GraphStates, actions: GraphActions) -> Data:
         action_type = actions.action_type[0]
         assert torch.all(actions.action_type == action_type)
 
+        if action_type == GraphActionType.EXIT:
+            return self.sf # TODO: not possible to backtrack then... maybe a boolen in state?
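+        # The remaining action types (ADD_NODE / ADD_EDGE) extend the node and
+        # edge tensors of the batch below.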
+ if action_type == GraphActionType.ADD_NODE: assert len(graphs) == len(actions) if graphs.x is None: @@ -72,6 +75,7 @@ def step(self, states: GraphStates, actions: GraphActions) -> Data: [graphs.edge_index, actions.edge_index], dim=1 ) + import pdb; pdb.set_trace() return graphs def backward_step(self, states: GraphStates, actions: GraphActions) -> GraphStates: From d5e438f80444f267f647b8af2a55dd8a2071b951 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Tue, 3 Dec 2024 15:27:40 +0100 Subject: [PATCH 15/27] remove Data --- src/gfn/actions.py | 8 +- src/gfn/env.py | 26 ++-- src/gfn/gym/graph_building.py | 90 +++++++------- src/gfn/modules.py | 6 +- src/gfn/samplers.py | 2 +- src/gfn/states.py | 142 +++++++++------------- src/gfn/utils/distributions.py | 24 ++++ testing/test_samplers_and_trajectories.py | 18 ++- 8 files changed, 154 insertions(+), 162 deletions(-) diff --git a/src/gfn/actions.py b/src/gfn/actions.py index 400a9480..36c52ae5 100644 --- a/src/gfn/actions.py +++ b/src/gfn/actions.py @@ -202,12 +202,12 @@ def __init__( features = torch.zeros((*self.batch_shape, self.features_dim)) if edge_index is None: assert torch.all(action_type != GraphActionType.ADD_EDGE) - edge_index = torch.zeros((2, *self.batch_shape)) + edge_index = torch.zeros((*self.batch_shape, 2), dtype=torch.long) self.tensor = TensorDict({ "action_type": action_type, "features": features, - "edge_index": edge_index.T, + "edge_index": edge_index, }, batch_size=self.batch_shape) def __repr__(self): @@ -228,7 +228,7 @@ def __getitem__(self, index: int | Sequence[int] | Sequence[bool]) -> GraphActio return GraphActions( tensor["action_type"], tensor["features"], - tensor["edge_index"].T + tensor["edge_index"] ) def __setitem__( @@ -268,7 +268,7 @@ def features(self) -> torch.Tensor: @property def edge_index(self) -> torch.Tensor: """Returns the edge index tensor.""" - return self.tensor["edge_index"].T + return self.tensor["edge_index"] @classmethod def make_dummy_actions( diff --git a/src/gfn/env.py b/src/gfn/env.py index d47e4cfa..e09c5b79 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -2,7 +2,7 @@ from typing import Dict, Optional, Tuple, Union import torch -from torch_geometric.data import Batch, Data +from tensordict import TensorDict from gfn.actions import Actions, GraphActions from gfn.preprocessors import IdentityPreprocessor, Preprocessor @@ -256,7 +256,8 @@ def _step( ) new_sink_states_idx = actions.is_exit - new_states.tensor[new_sink_states_idx] = self.sf + sf_tensor = self.States.make_sink_states_tensor(new_sink_states_idx.sum()) + new_states[new_sink_states_idx] = self.States(sf_tensor) new_sink_states_idx = ~valid_states_idx | new_sink_states_idx assert new_sink_states_idx.shape == states.batch_shape @@ -265,14 +266,12 @@ def _step( new_not_done_states_tensor = self.step(not_done_states, not_done_actions) - # TODO: uncomment (change Data to TensorDict) - # if not isinstance(new_not_done_states_tensor, torch.Tensor): - # raise Exception( - # "User implemented env.step function *must* return a torch.Tensor!" - # ) - - new_states.tensor[~new_sink_states_idx] = new_not_done_states_tensor + if not isinstance(new_not_done_states_tensor, (torch.Tensor, TensorDict)): + raise Exception( + "User implemented env.step function *must* return a torch.Tensor!" 
+ ) + new_states[~new_sink_states_idx] = self.States(new_not_done_states_tensor) return new_states def _backward_step( @@ -569,8 +568,8 @@ class GraphEnv(Env): def __init__( self, - s0: Data, - sf: Optional[Data] = None, + s0: TensorDict, + sf: Optional[TensorDict] = None, device_str: Optional[str] = None, preprocessor: Optional[Preprocessor] = None, ): @@ -578,9 +577,6 @@ def __init__( Args: s0: The initial graph state. - action_shape: Tuple representing the shape of the actions. - dummy_action: Tensor of shape "action_shape" representing a dummy action. - exit_action: Tensor of shape "action_shape" representing the exit action. sf: The final graph state. device_str: 'cpu' or 'cuda'. Defaults to None, in which case the device is inferred from s0. @@ -589,7 +585,7 @@ def __init__( the IdentityPreprocessor is used. """ self.s0 = s0.to(device_str) - self.features_dim = s0.x.shape[1] + self.features_dim = s0["node_feature"].shape[-1] self.sf = sf self.States = self.make_states_class() diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index f3448068..12fa8943 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -2,8 +2,8 @@ from typing import Callable, Literal, Tuple import torch -from torch_geometric.data import Batch, Data from torch_geometric.nn import GCNConv +from tensordict import TensorDict from gfn.actions import GraphActions, GraphActionType from gfn.env import GraphEnv, NonValidActionsError @@ -14,17 +14,19 @@ class GraphBuilding(GraphEnv): def __init__( self, feature_dim: int, - state_evaluator: Callable[[Batch], torch.Tensor] | None = None, + state_evaluator: Callable[[GraphStates], torch.Tensor] | None = None, device_str: Literal["cpu", "cuda"] = "cpu", ): - s0 = Data( - x=torch.zeros((0, feature_dim), dtype=torch.float32), - edge_attr=torch.zeros((0, feature_dim), dtype=torch.float32), - edge_index=torch.zeros((2, 0), dtype=torch.long), - ).to(device_str) - sf = Data( - x=torch.ones((1, feature_dim), dtype=torch.float32) * float("inf"), - ).to(device_str) + s0 = TensorDict({ + "node_feature": torch.zeros((0, feature_dim), dtype=torch.float32), + "edge_feature": torch.zeros((0, feature_dim), dtype=torch.float32), + "edge_index": torch.zeros((0, 2), dtype=torch.long), + }, device=device_str) + sf = TensorDict({ + "node_feature": torch.ones((1, feature_dim), dtype=torch.float32) * float("inf"), + "edge_feature": torch.ones((1, feature_dim), dtype=torch.float32) * float("inf"), + "edge_index": torch.zeros((0, 2), dtype=torch.long), + }, device=device_str) if state_evaluator is None: state_evaluator = GCNConvEvaluator(feature_dim) @@ -36,7 +38,7 @@ def __init__( device_str=device_str, ) - def step(self, states: GraphStates, actions: GraphActions) -> Data: + def step(self, states: GraphStates, actions: GraphActions) -> TensorDict: """Step function for the GraphBuilding environment. Args: @@ -47,36 +49,24 @@ def step(self, states: GraphStates, actions: GraphActions) -> Data: """ if not self.is_action_valid(states, actions): raise NonValidActionsError("Invalid action.") - graphs: Batch = deepcopy(states.data) + state_tensor = deepcopy(states.tensor) action_type = actions.action_type[0] assert torch.all(actions.action_type == action_type) - if action_type == GraphActionType.EXIT: - return self.sf # TODO: not possible to backtrack then... maybe a boolen in state? 
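+            # Replace the whole batch with sink-state tensors of matching batch shape.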
+ return self.States.make_sink_states_tensor(states.batch_shape) if action_type == GraphActionType.ADD_NODE: - assert len(graphs) == len(actions) - if graphs.x is None: - graphs.x = actions.features - else: - graphs.x = torch.cat([graphs.x, actions.features]) + assert len(state_tensor) == len(actions) + state_tensor["node_feature"] = torch.cat([state_tensor["node_feature"], actions.features[:, None]], dim=1) if action_type == GraphActionType.ADD_EDGE: - assert len(graphs) == len(actions) - assert actions.edge_index is not None - if graphs.edge_attr is None: - graphs.edge_attr = actions.features - assert graphs.edge_index is None - graphs.edge_index = actions.edge_index - else: - graphs.edge_attr = torch.cat([graphs.edge_attr, actions.features]) - graphs.edge_index = torch.cat( - [graphs.edge_index, actions.edge_index], dim=1 - ) - - import pdb; pdb.set_trace() - return graphs + assert len(state_tensor) == len(actions) + state_tensor["edge_feature"] = torch.cat([state_tensor["edge_feature"], actions.features[:, None]], dim=1) + state_tensor["edge_index"] = torch.cat( + [state_tensor["edge_index"], actions.edge_index[:, None]], dim=1 + ) + return state_tensor def backward_step(self, states: GraphStates, actions: GraphActions) -> GraphStates: """Backward step function for the GraphBuilding environment. @@ -116,9 +106,10 @@ def is_action_valid( if not torch.any(add_node_mask): add_node_out = True else: + node_feature = states.tensor["node_feature"][add_node_mask] equal_nodes_per_batch = torch.all( - states[add_node_mask].data.x == actions[add_node_mask].features[:, None], dim=-1 - ).reshape(states.data.batch_size, -1) + node_feature == actions[add_node_mask].features[:, None], dim=-1 + ).reshape(len(node_feature), -1) equal_nodes_per_batch = torch.sum(equal_nodes_per_batch, dim=-1) if backward: # TODO: check if no edge are connected? 
add_node_out = torch.all(equal_nodes_per_batch == 1) @@ -129,14 +120,14 @@ def is_action_valid( if not torch.any(add_edge_mask): add_edge_out = True else: - add_edge_states = states[add_edge_mask] + add_edge_states = states[add_edge_mask].tensor add_edge_actions = actions[add_edge_mask] - if torch.any(add_edge_actions.edge_index[0] == add_edge_actions.edge_index[1]): + if torch.any(add_edge_actions.edge_index[:, 0] == add_edge_actions.edge_index[:, 1]): return False - if add_edge_states.data.num_nodes == 0: + if add_edge_states["node_feature"].shape[1] == 0: return False - if torch.any(add_edge_actions.edge_index > add_edge_states.data.num_nodes): + if torch.any(add_edge_actions.edge_index > add_edge_states["node_feature"].shape[1]): return False batch_dim = add_edge_actions.features.shape[0] @@ -145,12 +136,12 @@ def is_action_valid( return False equal_edges_per_batch_attr = torch.all( - add_edge_states.data.edge_attr == add_edge_actions.features[:, None], dim=-1 - ).reshape(add_edge_states.data.batch_size, -1) + add_edge_states["edge_feature"] == add_edge_actions.features[:, None], dim=-1 + ).reshape(len(add_edge_states), -1) equal_edges_per_batch_attr = torch.sum(equal_edges_per_batch_attr, dim=-1) equal_edges_per_batch_index = torch.all( - add_edge_states.data.edge_index[:, None] == add_edge_actions.edge_index[:, :, None], dim=0 - ).reshape(add_edge_states.data.batch_size, -1) + add_edge_states["edge_index"] == add_edge_actions.edge_index, dim=0 + ).reshape(len(add_edge_states), -1) equal_edges_per_batch_index = torch.sum(equal_edges_per_batch_index, dim=-1) if backward: @@ -161,7 +152,7 @@ def is_action_valid( add_edge_out = torch.all(equal_edges_per_batch_attr == 0) and torch.all( equal_edges_per_batch_index == 0 ) - + return bool(add_node_out) and bool(add_edge_out) def reward(self, final_states: GraphStates) -> torch.Tensor: @@ -174,7 +165,7 @@ def reward(self, final_states: GraphStates) -> torch.Tensor: Returns: torch.Tensor: Tensor of shape "batch_shape" containing the rewards. 
""" - return self.state_evaluator(final_states.data) + return self.state_evaluator(final_states) @property def log_partition(self) -> float: @@ -193,9 +184,12 @@ def make_random_states_tensor(self, batch_shape: Tuple) -> GraphStates: class GCNConvEvaluator: def __init__(self, num_features): + self.num_features = num_features self.net = GCNConv(num_features, 1) - def __call__(self, batch: Batch) -> torch.Tensor: - out = self.net(batch.x, batch.edge_index) - out = out.reshape(batch.batch_size, -1) + def __call__(self, state: GraphStates) -> torch.Tensor: + node_feature = state.tensor["node_feature"].reshape(-1, self.num_features) + edge_index = state.tensor["edge_index"].reshape(-1, 2).T + out = self.net(node_feature, edge_index) + out = out.reshape(len(state), state.tensor["node_feature"].shape[1]) return out.mean(-1) diff --git a/src/gfn/modules.py b/src/gfn/modules.py index 90087e5a..9d0f52f7 100644 --- a/src/gfn/modules.py +++ b/src/gfn/modules.py @@ -9,7 +9,7 @@ from gfn.actions import GraphActionType from gfn.preprocessors import IdentityPreprocessor, Preprocessor from gfn.states import DiscreteStates, GraphStates, States -from gfn.utils.distributions import ComposedDistribution, UnsqueezedCategorical +from gfn.utils.distributions import CategoricalIndexes, ComposedDistribution, UnsqueezedCategorical REDUCTION_FXNS = { "mean": torch.mean, @@ -518,7 +518,7 @@ def to_probability_distribution( dists["action_type"] = Categorical(probs=action_type_probs) edge_index_logits = module_output["edge_index"] - if edge_index_logits.shape[-1] != 0: + if states.tensor["node_feature"].shape[1] > 1: edge_index_probs = torch.softmax(edge_index_logits / temperature, dim=-1) uniform_dist_probs = ( torch.ones_like(edge_index_probs) / edge_index_probs.shape[-1] @@ -526,7 +526,7 @@ def to_probability_distribution( edge_index_probs = ( 1 - epsilon ) * edge_index_probs + epsilon * uniform_dist_probs - dists["edge_index"] = UnsqueezedCategorical(probs=edge_index_probs) + dists["edge_index"] = CategoricalIndexes(probs=edge_index_probs) dists["features"] = Normal(module_output["features"], temperature) return ComposedDistribution(dists=dists) diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 4b706a39..083d8792 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -167,7 +167,7 @@ def sample_trajectories( step = 0 all_estimator_outputs = [] - + while not all(dones): actions = env.actions_from_batch_shape((n_trajectories,)) # Dummy actions. log_probs = torch.full( diff --git a/src/gfn/states.py b/src/gfn/states.py index 8e5ec735..81762909 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -7,6 +7,7 @@ import numpy as np import torch +from tensordict import TensorDict from torch_geometric.data import Batch, Data from gfn.actions import GraphActionType @@ -486,16 +487,19 @@ class GraphStates(ABC): graph objects as states. 
""" - s0: ClassVar[Data] - sf: ClassVar[Optional[Data]] + s0: ClassVar[TensorDict] + sf: ClassVar[TensorDict] - def __init__(self, graphs: Batch): - self.data: Batch = graphs - self.batch_shape: tuple = (len(self.data),) + def __init__(self, tensor: TensorDict): + self.tensor = tensor + self.node_features_dim = tensor["node_feature"].shape[-1] + self.edge_features_dim = tensor["edge_feature"].shape[-1] + + self.batch_shape: tuple = tensor.batch_size self._log_rewards: float = None # TODO logic repeated from env.is_valid_action - not_empty = self.data.x is not None and self.data.x.shape[0] > 0 + not_empty = self.tensor["node_feature"].shape[1] > 1 self.forward_masks = torch.ones((*self.batch_shape, 3), dtype=torch.bool) self.forward_masks[:, GraphActionType.ADD_EDGE] = not_empty self.forward_masks[:, GraphActionType.EXIT] = not_empty @@ -503,7 +507,7 @@ def __init__(self, graphs: Batch): self.backward_masks = torch.ones((*self.batch_shape, 3), dtype=torch.bool) self.backward_masks[:, GraphActionType.ADD_NODE] = not_empty self.backward_masks[:, GraphActionType.ADD_EDGE] = ( - not_empty and self.data.edge_attr.shape[0] > 0 + not_empty and self.tensor["edge_feature"].shape[1] > 0 > 0 ) self.backward_masks[:, GraphActionType.EXIT] = not_empty @@ -514,15 +518,15 @@ def from_batch_shape( if random and sink: raise ValueError("Only one of `random` and `sink` should be True.") if random: - data = cls.make_random_states_graph(batch_shape) + tensor = cls.make_random_states_tensor(batch_shape) elif sink: - data = cls.make_sink_states_graph(batch_shape) + tensor = cls.make_sink_states_tensor(batch_shape) else: - data = cls.make_initial_states_graph(batch_shape) - return cls(data) + tensor = cls.make_initial_states_tensor(batch_shape) + return cls(tensor) @classmethod - def make_initial_states_graph(cls, batch_shape: int | Tuple) -> Batch: + def make_initial_states_tensor(cls, batch_shape: int | Tuple) -> TensorDict: if isinstance(batch_shape, Tuple) and len(batch_shape) > 1: raise NotImplementedError( "Batch shape with more than one dimension is not supported" @@ -530,11 +534,14 @@ def make_initial_states_graph(cls, batch_shape: int | Tuple) -> Batch: if isinstance(batch_shape, Tuple): batch_shape = batch_shape[0] - data = Batch.from_data_list([cls.s0 for _ in range(batch_shape)]) - return data + return TensorDict({ + "node_feature": cls.s0["node_feature"].repeat(batch_shape, 1, 1), + "edge_feature": cls.s0["edge_feature"].repeat(batch_shape, 1, 1), + "edge_index": cls.s0["edge_index"].repeat(batch_shape, 1, 1) + }, batch_size=batch_shape) @classmethod - def make_sink_states_graph(cls, batch_shape: Tuple) -> Batch: + def make_sink_states_tensor(cls, batch_shape: Tuple) -> TensorDict: if cls.sf is None: raise NotImplementedError("Sink state is not defined") @@ -545,11 +552,14 @@ def make_sink_states_graph(cls, batch_shape: Tuple) -> Batch: if isinstance(batch_shape, Tuple): batch_shape = batch_shape[0] - data = Batch.from_data_list([cls.sf for _ in range(batch_shape)]) - return data + return TensorDict({ + "node_feature": cls.sf["node_feature"].repeat(batch_shape, 1, 1), + "edge_feature": cls.sf["edge_feature"].repeat(batch_shape, 1, 1), + "edge_index": cls.sf["edge_index"].repeat(batch_shape, 1, 1) + }, batch_size=int(batch_shape)) @classmethod - def make_random_states_graph(cls, batch_shape: int) -> Batch: + def make_random_states_tensor(cls, batch_shape: int) -> TensorDict: if isinstance(batch_shape, Tuple) and len(batch_shape) > 1: raise NotImplementedError( "Batch shape with more than one dimension 
is not supported"
             )
         if isinstance(batch_shape, Tuple):
             batch_shape = batch_shape[0]
 
-        data_list = []
-        for _ in range(batch_shape):
-            data = Data(
-                x=torch.rand(cls.s0.num_nodes, cls.s0.x.shape[1]),
-                edge_attr=torch.rand(cls.s0.num_edges, cls.s0.edge_attr.shape[1]),
-                edge_index=cls.s0.edge_index,  # TODO: make it random
-            )
-            data_list.append(data)
-        return Batch.from_data_list(data_list)
+        num_nodes = np.random.randint(10)
+        num_edges = np.random.randint(num_nodes * (num_nodes - 1) // 2)
+        node_features_dim = cls.s0["node_feature"].shape[-1]
+        edge_features_dim = cls.s0["edge_feature"].shape[-1]
+        tensor = TensorDict({
+            "node_feature": torch.rand(batch_shape, num_nodes, node_features_dim),
+            "edge_feature": torch.rand(batch_shape, num_edges, edge_features_dim),
+            "edge_index": torch.randint(num_nodes, size=(batch_shape, num_edges, 2)),
+        }, batch_size=batch_shape)  # batch_size must be set explicitly, as in make_initial_states_tensor
+        return tensor
 
     def __len__(self):
-        return self.data.batch_size
+        return np.prod(self.tensor.batch_size)
 
     def __repr__(self):
         return (
             f"{self.__class__.__name__} object of batch shape {self.batch_shape} and "
-            f"node feature dim {self.s0.x.shape[1]} and edge feature dim {self.s0.edge_attr.shape[1]}"
+            f"node feature dim {self.node_features_dim} and edge feature dim {self.edge_features_dim}"
         )
 
     def __getitem__(
         self, index: int | Sequence[int] | slice | torch.Tensor
     ) -> GraphStates:
-        idxs = np.arange(len(self.data))[index]
-        data = [self.data.get_example(i) for i in idxs]
-        if len(data) == 0:
-            data.append(Data(
-                x=torch.zeros((0, self.data.x.shape[1]), dtype=torch.float32),
-                edge_attr=torch.zeros((0, self.data.edge_attr.shape[1]), dtype=torch.float32),
-                edge_index=torch.zeros((2, 0), dtype=torch.long),
-            ))
-        out = GraphStates(Batch.from_data_list(data))
+        out = GraphStates(self.tensor[index])
 
         if self._log_rewards is not None:
             out._log_rewards = self._log_rewards[index]
 
@@ -598,44 +601,21 @@ def __setitem__(self, index: int | Sequence[int], graph: GraphStates):
         """
         Set particular states of the Batch
         """
-        data_list = self.data.to_data_list()
-        if isinstance(index, int):
-            assert (
-                len(graph) == 1
-            ), "GraphStates must have a batch size of 1 for single index assignment"
-            data_list[index] = graph.data[0]
-            self.data = Batch.from_data_list(data_list)
-        elif isinstance(index, Sequence):
-            assert len(index) == len(
-                graph
-            ), "Index and GraphState must have the same length"
-            for i, idx in enumerate(index):
-                data_list[idx] = graph.data[i]
-            self.data = Batch.from_data_list(data_list)
-        elif isinstance(index, slice):
-            assert index.stop - index.start == len(
-                graph
-            ), "Index slice and GraphStates must have the same length"
-            data_list[index] = graph.data.to_data_list()
-            self.data = Batch.from_data_list(data_list)
-        else:
-            raise NotImplementedError(
-                "Setters with type {} is not implemented".format(type(index))
-            )
+        len_index = len(self.tensor[index])
+        if len_index != 0 and len_index != len(self.tensor):
+            raise ValueError("Can only set states with the same batch size as the original batch")
+
+        self.tensor[index] = graph.tensor
 
     @property
     def device(self) -> torch.device:
-        sample = self.data.get_example(0).x
-        if sample is not None:
-            return sample.device
-        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        return self.tensor.device
 
     def to(self, device: torch.device) -> GraphStates:
         """
         Moves and/or casts the graph states to the specified device
         """
-        if self.device != device:
-            self.data = self.data.to(device)
+        self.tensor = 
self.tensor.to(device) return self def clone(self) -> GraphStates: @@ -644,14 +624,9 @@ def clone(self) -> GraphStates: def extend(self, other: GraphStates): """Concatenates to another GraphStates object along the batch dimension""" - self.data = Batch.from_data_list( - self.data.to_data_list() + other.data.to_data_list() - ) - if self._log_rewards is not None: - assert other._log_rewards is not None - self._log_rewards = torch.cat( - (self._log_rewards, other._log_rewards), dim=0 - ) + self.tensor["node_feature"] = torch.cat([self.tensor["node_feature"], other.tensor["node_feature"]], dim=1) + self.tensor["edge_feature"] = torch.cat([self.tensor["edge_feature"], other.tensor["edge_feature"]], dim=1) + self.tensor["edge_index"] = torch.cat([self.tensor["edge_index"], other.tensor["edge_index"]], dim=1) @property def log_rewards(self) -> torch.Tensor: @@ -663,13 +638,10 @@ def log_rewards(self, log_rewards: torch.Tensor) -> None: @property def is_sink_state(self) -> torch.Tensor: - batch_dim = len(self.data.ptr) - 1 - if len(self.data.x) == 0: - return torch.zeros(batch_dim, dtype=torch.bool) - return torch.all(self.data.x == self.sf.x, dim=-1).reshape( - batch_dim, + if self.tensor["node_feature"].shape[1] == 0: + return torch.zeros(self.batch_shape, dtype=torch.bool) + return ( + torch.all(self.tensor["node_feature"] == self.sf["node_feature"], dim=(1, 2)) & + torch.all(self.tensor["edge_feature"] == self.sf["edge_feature"], dim=(1, 2)) & + torch.all(self.tensor["edge_index"] == self.sf["edge_index"], dim=(1, 2)) ) - - @property - def tensor(self) -> Batch: - return self.data diff --git a/src/gfn/utils/distributions.py b/src/gfn/utils/distributions.py index 1fa6aa89..e1e387c6 100644 --- a/src/gfn/utils/distributions.py +++ b/src/gfn/utils/distributions.py @@ -64,3 +64,27 @@ def log_prob(self, sample: Dict[str, torch.Tensor]) -> torch.Tensor: for k, v in self.dists.items() ] return sum(log_probs) + + +class CategoricalIndexes(Categorical): + """Samples indexes from a categorical distribution.""" + + def __init__(self, probs: torch.Tensor): + """Initializes the distribution. + + Args: + probs: The probabilities of the categorical distribution. 
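+                Expected shape: (batch_size, n, n), interpreted as a joint
+                distribution over (source, target) node index pairs.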
+ """ + self.n = probs.shape[-1] + batch_size = probs.shape[0] + assert probs.shape == (batch_size, self.n, self.n) + super().__init__(probs.reshape(batch_size, self.n * self.n)) + + def sample(self, sample_shape=torch.Size()) -> torch.Tensor: + samples = super().sample(sample_shape) + out = torch.stack([samples // self.n, samples % self.n], dim=-1) + return out + + def log_prob(self, value): + value = value[..., 0] * self.n + value[..., 1] + return super().log_prob(value) \ No newline at end of file diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py index d8a1c260..96911165 100644 --- a/testing/test_samplers_and_trajectories.py +++ b/testing/test_samplers_and_trajectories.py @@ -234,19 +234,24 @@ def __init__(self, feature_dim: int): self.edge_index_conv = GCNConv(feature_dim, 8) def forward(self, states: GraphStates) -> TensorDict: - if len(states.data.x) == 0: + node_feature = states.tensor["node_feature"].reshape(-1, self.feature_dim) + edge_index = states.tensor["edge_index"].reshape(-1, 2).T + + if states.tensor["node_feature"].shape[1] == 0: action_type = torch.zeros((len(states), len(GraphActionType))) action_type[:, GraphActionType.ADD_NODE] = 1 features = torch.zeros((len(states), self.feature_dim)) else: - action_type = self.action_type_conv(states.data.x, states.data.edge_index) - action_type = torch.mean(action_type, dim=0) - features = self.features_conv(states.data.x, states.data.edge_index) - features = features.reshape(len(states), -1, features.shape[-1]).mean(dim=0) + action_type = self.action_type_conv(node_feature, edge_index) + action_type = action_type.reshape(len(states), -1, action_type.shape[-1]).mean(dim=1) + action_type = action_type.mean(dim=0).expand(len(states), -1) + features = self.features_conv(node_feature, edge_index) + features = features.reshape(len(states), -1, features.shape[-1]).mean(dim=1) - edge_index = self.edge_index_conv(states.data.x, states.data.edge_index) + edge_index = self.edge_index_conv(node_feature, edge_index) edge_index = edge_index.reshape(*states.batch_shape, -1, 8) edge_index = torch.einsum("bnf,bmf->bnm", edge_index, edge_index) + torch.diagonal(edge_index, dim1=-2, dim2=-1).fill_(float("-inf")) return TensorDict( { @@ -259,6 +264,7 @@ def forward(self, states: GraphStates) -> TensorDict: def test_graph_building(): + torch.manual_seed(7) feature_dim = 8 env = GraphBuilding(feature_dim=feature_dim) From fba5d509d64c72af5d3cb0f920f497944705b2f7 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Fri, 6 Dec 2024 15:46:45 +0100 Subject: [PATCH 16/27] categorical action type --- src/gfn/containers/trajectories.py | 2 +- src/gfn/gym/graph_building.py | 2 ++ src/gfn/modules.py | 4 ++-- src/gfn/samplers.py | 9 ++++----- src/gfn/states.py | 8 +++++--- src/gfn/utils/distributions.py | 16 +++++++++++++++- testing/test_samplers_and_trajectories.py | 9 ++++++--- 7 files changed, 35 insertions(+), 15 deletions(-) diff --git a/src/gfn/containers/trajectories.py b/src/gfn/containers/trajectories.py index 5feb665a..c56cd1ee 100644 --- a/src/gfn/containers/trajectories.py +++ b/src/gfn/containers/trajectories.py @@ -76,7 +76,7 @@ def __init__( self.states = ( states if states is not None else env.states_from_batch_shape((0, 0)) ) - assert len(self.states.batch_shape) == 2 + assert len(self.states.batch_shape) == 2, self.states.batch_shape self.actions = ( actions if actions is not None else env.actions_from_batch_shape((0, 0)) ) diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py 
index 12fa8943..728f1fd5 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -50,6 +50,8 @@ def step(self, states: GraphStates, actions: GraphActions) -> TensorDict: if not self.is_action_valid(states, actions): raise NonValidActionsError("Invalid action.") state_tensor = deepcopy(states.tensor) + if len(actions) == 0: + return state_tensor action_type = actions.action_type[0] assert torch.all(actions.action_type == action_type) diff --git a/src/gfn/modules.py b/src/gfn/modules.py index 9d0f52f7..0500a157 100644 --- a/src/gfn/modules.py +++ b/src/gfn/modules.py @@ -9,7 +9,7 @@ from gfn.actions import GraphActionType from gfn.preprocessors import IdentityPreprocessor, Preprocessor from gfn.states import DiscreteStates, GraphStates, States -from gfn.utils.distributions import CategoricalIndexes, ComposedDistribution, UnsqueezedCategorical +from gfn.utils.distributions import CategoricalActionType, CategoricalIndexes, ComposedDistribution, UnsqueezedCategorical REDUCTION_FXNS = { "mean": torch.mean, @@ -515,7 +515,7 @@ def to_probability_distribution( action_type_probs = ( 1 - epsilon ) * action_type_probs + epsilon * uniform_dist_probs - dists["action_type"] = Categorical(probs=action_type_probs) + dists["action_type"] = CategoricalActionType(probs=action_type_probs) edge_index_logits = module_output["edge_index"] if states.tensor["node_feature"].shape[1] > 1: diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 083d8792..c8289955 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -194,11 +194,9 @@ def sample_trajectories( # Place estimator outputs into a stackable tensor. Note that this # will be replaced with torch.nested.nested_tensor in the future. estimator_outputs_padded = torch.full_like( - estimator_outputs.expand( - (n_trajectories,) + estimator_outputs.shape[1:] - ), - fill_value=-float("inf"), - ).clone() # TODO: inefficient + estimator_outputs.expand((n_trajectories,) + estimator_outputs.shape[1:]).clone(), + fill_value=-float("inf") + ) estimator_outputs_padded[~dones] = estimator_outputs all_estimator_outputs.append(estimator_outputs_padded) @@ -243,6 +241,7 @@ def sample_trajectories( states = new_states dones = dones | new_dones + import pdb; pdb.set_trace() trajectories_states.extend(deepcopy(states)) trajectories_logprobs = ( diff --git a/src/gfn/states.py b/src/gfn/states.py index 81762909..6df93eca 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -602,11 +602,13 @@ def __setitem__(self, index: int | Sequence[int], graph: GraphStates): Set particular states of the Batch """ len_index = len(self.tensor[index]) - if len_index != 0 and len_index != len(self.tensor): + if len_index == 0: + return + elif len_index == len(self.tensor): + self.tensor = graph.tensor + else: # TODO: fix this raise ValueError("Can only set states with the same batch size as the original batch") - self.tensor[index] = graph.tensor - @property def device(self) -> torch.device: return self.tensor.device diff --git a/src/gfn/utils/distributions.py b/src/gfn/utils/distributions.py index e1e387c6..e6f6beaa 100644 --- a/src/gfn/utils/distributions.py +++ b/src/gfn/utils/distributions.py @@ -87,4 +87,18 @@ def sample(self, sample_shape=torch.Size()) -> torch.Tensor: def log_prob(self, value): value = value[..., 0] * self.n + value[..., 1] - return super().log_prob(value) \ No newline at end of file + return super().log_prob(value) + + +class CategoricalActionType(Categorical): # TODO: remove, just to sample 1 action_type + + def __init__(self, probs: 
torch.Tensor): + self.batch_len = len(probs) + super().__init__(probs[0]) + + def sample(self, sample_shape=torch.Size()) -> torch.Tensor: + samples = super().sample(sample_shape) + return samples.repeat(self.batch_len) + + def log_prob(self, value): + return super().log_prob(value[0]).repeat(self.batch_len) \ No newline at end of file diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py index 96911165..e3ffef3f 100644 --- a/testing/test_samplers_and_trajectories.py +++ b/testing/test_samplers_and_trajectories.py @@ -244,7 +244,6 @@ def forward(self, states: GraphStates) -> TensorDict: else: action_type = self.action_type_conv(node_feature, edge_index) action_type = action_type.reshape(len(states), -1, action_type.shape[-1]).mean(dim=1) - action_type = action_type.mean(dim=0).expand(len(states), -1) features = self.features_conv(node_feature, edge_index) features = features.reshape(len(states), -1, features.shape[-1]).mean(dim=1) @@ -274,7 +273,11 @@ def test_graph_building(): sampler = Sampler(estimator=pf_estimator) trajectories = sampler.sample_trajectories( env, - n=5, + n=7, save_logprobs=True, - save_estimator_outputs=True, + save_estimator_outputs=False, ) + + +if __name__ == "__main__": + test_graph_building() \ No newline at end of file From 478bd148cb68d8ec49e87e8eaaf8008645e5f126 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Tue, 10 Dec 2024 16:46:06 +0100 Subject: [PATCH 17/27] change batching --- src/gfn/gym/graph_building.py | 68 ++++-- src/gfn/modules.py | 2 +- src/gfn/samplers.py | 1 - src/gfn/states.py | 239 ++++++++++++++++------ testing/test_samplers_and_trajectories.py | 4 +- 5 files changed, 230 insertions(+), 84 deletions(-) diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index 728f1fd5..67d3d316 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -59,14 +59,14 @@ def step(self, states: GraphStates, actions: GraphActions) -> TensorDict: return self.States.make_sink_states_tensor(states.batch_shape) if action_type == GraphActionType.ADD_NODE: - assert len(state_tensor) == len(actions) - state_tensor["node_feature"] = torch.cat([state_tensor["node_feature"], actions.features[:, None]], dim=1) + batch_indices = torch.arange(len(states))[actions.action_type == GraphActionType.ADD_NODE] + state_tensor = self._add_node(state_tensor, batch_indices, actions.features) if action_type == GraphActionType.ADD_EDGE: assert len(state_tensor) == len(actions) - state_tensor["edge_feature"] = torch.cat([state_tensor["edge_feature"], actions.features[:, None]], dim=1) + state_tensor["edge_feature"] = torch.cat([state_tensor["edge_feature"], actions.features], dim=0) state_tensor["edge_index"] = torch.cat( - [state_tensor["edge_index"], actions.edge_index[:, None]], dim=1 + [state_tensor["edge_index"], actions.edge_index], dim=0 ) return state_tensor @@ -108,11 +108,10 @@ def is_action_valid( if not torch.any(add_node_mask): add_node_out = True else: - node_feature = states.tensor["node_feature"][add_node_mask] + node_feature = states[add_node_mask].tensor["node_feature"] equal_nodes_per_batch = torch.all( node_feature == actions[add_node_mask].features[:, None], dim=-1 - ).reshape(len(node_feature), -1) - equal_nodes_per_batch = torch.sum(equal_nodes_per_batch, dim=-1) + ).reshape(-1) if backward: # TODO: check if no edge are connected? 
add_node_out = torch.all(equal_nodes_per_batch == 1) else: @@ -127,9 +126,9 @@ def is_action_valid( if torch.any(add_edge_actions.edge_index[:, 0] == add_edge_actions.edge_index[:, 1]): return False - if add_edge_states["node_feature"].shape[1] == 0: + if add_edge_states["node_feature"].shape[0] == 0: return False - if torch.any(add_edge_actions.edge_index > add_edge_states["node_feature"].shape[1]): + if torch.any(add_edge_actions.edge_index > add_edge_states["node_feature"].shape[0]): return False batch_dim = add_edge_actions.features.shape[0] @@ -156,6 +155,47 @@ def is_action_valid( ) return bool(add_node_out) and bool(add_edge_out) + + def _add_node(self, tensor_dict: TensorDict, batch_indices: torch.Tensor, nodes_to_add: torch.Tensor) -> TensorDict: + if isinstance(batch_indices, list): + batch_indices = torch.tensor(batch_indices) + if len(batch_indices) != len(nodes_to_add): + raise ValueError("Number of batch indices must match number of node feature lists") + + modified_dict = tensor_dict.clone() + node_feature_dim = modified_dict['node_feature'].shape[1] + edge_feature_dim = modified_dict['edge_feature'].shape[1] + + for graph_idx, new_nodes in zip(batch_indices, nodes_to_add): + start_ptr = tensor_dict['batch_ptr'][graph_idx] + end_ptr = tensor_dict['batch_ptr'][graph_idx + 1] + num_original_nodes = end_ptr - start_ptr + + if new_nodes.ndim == 1: + new_nodes = new_nodes.unsqueeze(0) + if new_nodes.shape[1] != node_feature_dim: + raise ValueError(f"Node features must have dimension {node_feature_dim}") + + # Update batch pointers for subsequent graphs + shift = new_nodes.shape[0] + modified_dict['batch_ptr'][graph_idx + 1:] += shift + + # Expand node features + original_nodes = modified_dict['node_feature'][start_ptr:end_ptr] + modified_dict['node_feature'] = torch.cat([ + modified_dict['node_feature'][:end_ptr], + new_nodes, + modified_dict['node_feature'][end_ptr:] + ]) + + # Update edge indices + # Increment indices for edges after the current graph + edge_mask_0 = modified_dict['edge_index'][:, 0] >= end_ptr + edge_mask_1 = modified_dict['edge_index'][:, 1] >= end_ptr + modified_dict['edge_index'][edge_mask_0, 0] += shift + modified_dict['edge_index'][edge_mask_1, 1] += shift + + return modified_dict def reward(self, final_states: GraphStates) -> torch.Tensor: """The environment's reward given a state. 
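
For intuition, here is a minimal sketch of the flat batch layout that `_add_node`
operates on (hypothetical values: two graphs of two nodes each, feature dim 1,
one node appended to the first graph):

    import torch
    from tensordict import TensorDict

    td = TensorDict({
        "node_feature": torch.rand(4, 1),   # rows 0-1: graph 0, rows 2-3: graph 1
        "edge_feature": torch.rand(2, 1),
        "edge_index": torch.tensor([[0, 1], [2, 3]]),  # global row positions
        "batch_ptr": torch.tensor([0, 2, 4]),
        "batch_shape": torch.tensor([2]),
    })
    # Appending one node to graph 0 inserts a row at end_ptr == 2: batch_ptr
    # becomes [0, 3, 5] and graph 1's edge shifts to [3, 4], while graph 0's
    # edge [0, 1] is untouched.
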
@@ -186,12 +226,14 @@ def make_random_states_tensor(self, batch_shape: Tuple) -> GraphStates: class GCNConvEvaluator: def __init__(self, num_features): - self.num_features = num_features self.net = GCNConv(num_features, 1) def __call__(self, state: GraphStates) -> torch.Tensor: - node_feature = state.tensor["node_feature"].reshape(-1, self.num_features) - edge_index = state.tensor["edge_index"].reshape(-1, 2).T + node_feature = state.tensor["node_feature"] + edge_index = state.tensor["edge_index"].T + if len(node_feature) == 0: + return torch.zeros(len(state)) + out = self.net(node_feature, edge_index) - out = out.reshape(len(state), state.tensor["node_feature"].shape[1]) + out = out.reshape(*state.batch_shape, -1) return out.mean(-1) diff --git a/src/gfn/modules.py b/src/gfn/modules.py index 0500a157..86a345d4 100644 --- a/src/gfn/modules.py +++ b/src/gfn/modules.py @@ -518,7 +518,7 @@ def to_probability_distribution( dists["action_type"] = CategoricalActionType(probs=action_type_probs) edge_index_logits = module_output["edge_index"] - if states.tensor["node_feature"].shape[1] > 1: + if states.tensor["node_feature"].shape[0] > 1 and torch.any(edge_index_logits != -float("inf")): edge_index_probs = torch.softmax(edge_index_logits / temperature, dim=-1) uniform_dist_probs = ( torch.ones_like(edge_index_probs) / edge_index_probs.shape[-1] diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index c8289955..3085b697 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -241,7 +241,6 @@ def sample_trajectories( states = new_states dones = dones | new_dones - import pdb; pdb.set_trace() trajectories_states.extend(deepcopy(states)) trajectories_logprobs = ( diff --git a/src/gfn/states.py b/src/gfn/states.py index 6df93eca..f91bec88 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -3,7 +3,7 @@ from abc import ABC from copy import deepcopy from math import prod -from typing import Callable, ClassVar, Optional, Sequence, Tuple +from typing import Callable, ClassVar, List, Optional, Sequence, Tuple import numpy as np import torch @@ -495,25 +495,25 @@ def __init__(self, tensor: TensorDict): self.node_features_dim = tensor["node_feature"].shape[-1] self.edge_features_dim = tensor["edge_feature"].shape[-1] - self.batch_shape: tuple = tensor.batch_size + self.batch_shape: tuple = tuple(tensor["batch_shape"].tolist()) self._log_rewards: float = None # TODO logic repeated from env.is_valid_action - not_empty = self.tensor["node_feature"].shape[1] > 1 - self.forward_masks = torch.ones((*self.batch_shape, 3), dtype=torch.bool) - self.forward_masks[:, GraphActionType.ADD_EDGE] = not_empty - self.forward_masks[:, GraphActionType.EXIT] = not_empty - - self.backward_masks = torch.ones((*self.batch_shape, 3), dtype=torch.bool) - self.backward_masks[:, GraphActionType.ADD_NODE] = not_empty - self.backward_masks[:, GraphActionType.ADD_EDGE] = ( - not_empty and self.tensor["edge_feature"].shape[1] > 0 > 0 - ) - self.backward_masks[:, GraphActionType.EXIT] = not_empty + not_empty = self.tensor["batch_ptr"][:-1] + 1 < self.tensor["batch_ptr"][1:] + self.forward_masks = torch.ones((np.prod(self.batch_shape), 3), dtype=torch.bool) + self.forward_masks[..., GraphActionType.ADD_EDGE] = not_empty + self.forward_masks[..., GraphActionType.EXIT] = not_empty + self.forward_masks = self.forward_masks.view(*self.batch_shape, 3) + + self.backward_masks = torch.ones((np.prod(self.batch_shape), 3), dtype=torch.bool) + self.backward_masks[..., GraphActionType.ADD_NODE] = not_empty + self.backward_masks[..., 
GraphActionType.ADD_EDGE] = not_empty # TODO: check at least one edge is present + self.backward_masks[..., GraphActionType.EXIT] = not_empty + self.backward_masks = self.backward_masks.view(*self.batch_shape, 3) @classmethod def from_batch_shape( - cls, batch_shape: int, random: bool = False, sink: bool = False + cls, batch_shape: int | Tuple, random: bool = False, sink: bool = False ) -> GraphStates: if random and sink: raise ValueError("Only one of `random` and `sink` should be True.") @@ -527,71 +527,100 @@ def from_batch_shape( @classmethod def make_initial_states_tensor(cls, batch_shape: int | Tuple) -> TensorDict: - if isinstance(batch_shape, Tuple) and len(batch_shape) > 1: - raise NotImplementedError( - "Batch shape with more than one dimension is not supported" - ) - if isinstance(batch_shape, Tuple): - batch_shape = batch_shape[0] + batch_shape = batch_shape if isinstance(batch_shape, Tuple) else (batch_shape,) return TensorDict({ - "node_feature": cls.s0["node_feature"].repeat(batch_shape, 1, 1), - "edge_feature": cls.s0["edge_feature"].repeat(batch_shape, 1, 1), - "edge_index": cls.s0["edge_index"].repeat(batch_shape, 1, 1) - }, batch_size=batch_shape) + "node_feature": cls.s0["node_feature"].repeat(np.prod(batch_shape), 1), + "edge_feature": cls.s0["edge_feature"].repeat(np.prod(batch_shape), 1), + "edge_index": cls.s0["edge_index"].repeat(np.prod(batch_shape), 1), + "batch_ptr": torch.arange(np.prod(batch_shape) + 1) * cls.s0["node_feature"].shape[0], + "batch_shape": batch_shape + }) @classmethod - def make_sink_states_tensor(cls, batch_shape: Tuple) -> TensorDict: + def make_sink_states_tensor(cls, batch_shape: int | Tuple) -> TensorDict: if cls.sf is None: raise NotImplementedError("Sink state is not defined") - if isinstance(batch_shape, Tuple) and len(batch_shape) > 1: - raise NotImplementedError( - "Batch shape with more than one dimension is not supported" - ) - if isinstance(batch_shape, Tuple): - batch_shape = batch_shape[0] - + batch_shape = batch_shape if isinstance(batch_shape, Tuple) else (batch_shape,) return TensorDict({ - "node_feature": cls.sf["node_feature"].repeat(batch_shape, 1, 1), - "edge_feature": cls.sf["edge_feature"].repeat(batch_shape, 1, 1), - "edge_index": cls.sf["edge_index"].repeat(batch_shape, 1, 1) - }, batch_size=int(batch_shape)) + "node_feature": cls.sf["node_feature"].repeat(np.prod(batch_shape), 1), + "edge_feature": cls.sf["edge_feature"].repeat(np.prod(batch_shape), 1), + "edge_index": cls.sf["edge_index"].repeat(np.prod(batch_shape), 1), + "batch_ptr": torch.arange(np.prod(batch_shape) + 1) * cls.sf["node_feature"].shape[0], + "batch_shape": batch_shape + }) @classmethod - def make_random_states_tensor(cls, batch_shape: int) -> TensorDict: - if isinstance(batch_shape, Tuple) and len(batch_shape) > 1: - raise NotImplementedError( - "Batch shape with more than one dimension is not supported" - ) - if isinstance(batch_shape, Tuple): - batch_shape = batch_shape[0] + def make_random_states_tensor(cls, batch_shape: int | Tuple) -> TensorDict: + batch_shape = batch_shape if isinstance(batch_shape, Tuple) else (batch_shape,) num_nodes = np.random.randint(10) num_edges = np.random.randint(num_nodes * (num_nodes - 1) // 2) node_features_dim = cls.s0["node_feature"].shape[-1] edge_features_dim = cls.s0["edge_feature"].shape[-1] - tensor = TensorDict({ - "node_feature": torch.rand(batch_shape, num_nodes, node_features_dim), - "edge_feature": torch.rand(batch_shape, num_edges, edge_features_dim), - "edge_index": torch.randint(num_nodes, 
size=(batch_shape, num_edges, 2)), + return TensorDict({ + "node_feature": torch.rand(np.prod(batch_shape) * num_nodes, node_features_dim), + "edge_feature": torch.rand(np.prod(batch_shape) * num_edges, edge_features_dim), + "edge_index": torch.randint(num_nodes, size=(np.prod(batch_shape) * num_edges, 2)), + "batch_ptr": torch.arange(np.prod(batch_shape) + 1) * num_nodes, + "batch_shape": batch_shape }) - return tensor def __len__(self): - return np.prod(self.tensor.batch_size) + return np.prod(self.batch_shape) def __repr__(self): return ( - f"{self.__class__.__name__} object of batch shape {self.batch_shape} and " + f"{self.__class__.__name__} object of batch shape {self.tensor['batch_shape']} and " f"node feature dim {self.node_features_dim} and edge feature dim {self.edge_features_dim}" ) def __getitem__( self, index: int | Sequence[int] | slice | torch.Tensor ) -> GraphStates: - out = GraphStates(self.tensor[index]) - + if isinstance(index, (int, list)): + index = torch.tensor(index) + if index.dtype == torch.bool: + index = torch.where(index)[0] + + if torch.any(index >= len(self.tensor['batch_ptr']) - 1): + raise ValueError("Graph index out of bounds") + + start_ptrs = self.tensor['batch_ptr'][:-1][index] + end_ptrs = self.tensor['batch_ptr'][1:][index] + + node_features = [torch.empty(0, self.node_features_dim)] + edge_features = [torch.empty(0, self.edge_features_dim)] + edge_indices = [torch.empty(0, 2, dtype=torch.long)] + batch_ptr = [0] + + for start, end in zip(start_ptrs, end_ptrs): + graph_nodes = self.tensor['node_feature'][start:end] + node_features.append(graph_nodes) + batch_ptr.append(batch_ptr[-1] + len(graph_nodes)) + + # Find edges for this graph + edge_mask = ((self.tensor['edge_index'][:, 0] >= start) & + (self.tensor['edge_index'][:, 0] < end)) + graph_edges = self.tensor['edge_feature'][edge_mask] + edge_features.append(graph_edges) + + # Adjust edge indices to be local to this graph + graph_edge_index = self.tensor['edge_index'][edge_mask] + graph_edge_index[:, 0] -= start + graph_edge_index[:, 1] -= start + edge_indices.append(graph_edge_index) + + out = self.__class__(TensorDict({ + 'node_feature': torch.cat(node_features), + 'edge_feature': torch.cat(edge_features), + 'edge_index': torch.cat(edge_indices), + 'batch_ptr': torch.tensor(batch_ptr), + 'batch_shape': (len(index),) + })) + + if self._log_rewards is not None: out._log_rewards = self._log_rewards[index] @@ -601,13 +630,64 @@ def __setitem__(self, index: int | Sequence[int], graph: GraphStates): """ Set particular states of the Batch """ - len_index = len(self.tensor[index]) - if len_index == 0: - return - elif len_index == len(self.tensor): - self.tensor = graph.tensor - else: # TODO: fix this - raise ValueError("Can only set states with the same batch size as the original batch") + if isinstance(index, (int, list)): + index = torch.tensor(index) + if index.dtype == torch.bool: + index = torch.where(index)[0] + + # Validate indices + if torch.any(index >= len(self.tensor['batch_ptr']) - 1): + raise ValueError("Target graph index out of bounds") + + # Get batch pointers for target and source + target_start_ptrs = self.tensor['batch_ptr'][:-1][index] + target_end_ptrs = self.tensor['batch_ptr'][1:][index] + + # Source graph details + source_tensor_dict = graph.tensor + source_num_graphs = torch.prod(source_tensor_dict['batch_shape']) + + # Validate source and target indices match + if len(index) != source_num_graphs: + raise ValueError("Number of source graphs must match number of target indices") + 
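+        # Replace each target graph's node slice inside the flat node_feature
+        # tensor, shifting batch_ptr and the edge endpoints of every later
+        # graph by the resulting change in node count.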
+        for i, graph_idx in enumerate(index):
+            # Get start and end pointers for the current graph
+            start_ptr = self.tensor['batch_ptr'][graph_idx]
+            end_ptr = self.tensor['batch_ptr'][graph_idx + 1]
+
+            new_nodes = source_tensor_dict['node_feature'][
+                source_tensor_dict['batch_ptr'][i]:source_tensor_dict['batch_ptr'][i + 1]
+            ]
+
+            # Ensure new nodes have correct feature dimension
+            if new_nodes.ndim == 1:
+                new_nodes = new_nodes.unsqueeze(0)
+
+            if new_nodes.shape[1] != self.node_features_dim:
+                raise ValueError(f"Node features must have dimension {self.node_features_dim}")
+
+            # Number of new nodes to add
+            shift = new_nodes.shape[0] - (end_ptr - start_ptr)
+
+            # Concatenate node features
+            self.tensor['node_feature'] = torch.cat([
+                self.tensor['node_feature'][:start_ptr],  # Nodes before the current graph
+                new_nodes,  # New nodes to add
+                self.tensor['node_feature'][end_ptr:]  # Nodes after the current graph
+            ])
+
+            # Update edge indices for subsequent graphs
+            edge_mask_0 = self.tensor['edge_index'][:, 0] >= end_ptr
+            edge_mask_1 = self.tensor['edge_index'][:, 1] >= end_ptr
+            self.tensor['edge_index'][edge_mask_0, 0] += shift
+            self.tensor['edge_index'][edge_mask_1, 1] += shift
+
+            # Update batch pointers
+            self.tensor['batch_ptr'][graph_idx + 1:] += shift
+
+        # TODO: add new edges
+
 
     @property
     def device(self) -> torch.device:
@@ -626,9 +706,10 @@ def clone(self) -> GraphStates:
 
     def extend(self, other: GraphStates):
         """Concatenates to another GraphStates object along the batch dimension"""
-        self.tensor["node_feature"] = torch.cat([self.tensor["node_feature"], other.tensor["node_feature"]], dim=1)
-        self.tensor["edge_feature"] = torch.cat([self.tensor["edge_feature"], other.tensor["edge_feature"]], dim=1)
-        self.tensor["edge_index"] = torch.cat([self.tensor["edge_index"], other.tensor["edge_index"]], dim=1)
+        self.tensor["node_feature"] = torch.cat([self.tensor["node_feature"], other.tensor["node_feature"]], dim=0)
+        self.tensor["edge_feature"] = torch.cat([self.tensor["edge_feature"], other.tensor["edge_feature"]], dim=0)
+        self.tensor["edge_index"] = torch.cat([self.tensor["edge_index"], other.tensor["edge_index"]], dim=0)
+
 
     @property
     def log_rewards(self) -> torch.Tensor:
@@ -640,10 +721,34 @@ def log_rewards(self, log_rewards: torch.Tensor) -> None:
 
     @property
     def is_sink_state(self) -> torch.Tensor:
-        if self.tensor["node_feature"].shape[1] == 0:
+        if len(self.tensor["node_feature"]) != np.prod(self.batch_shape):
             return torch.zeros(self.batch_shape, dtype=torch.bool)
-        return (
-            torch.all(self.tensor["node_feature"] == self.sf["node_feature"], dim=(1, 2)) &
-            torch.all(self.tensor["edge_feature"] == self.sf["edge_feature"], dim=(1, 2)) &
-            torch.all(self.tensor["edge_index"] == self.sf["edge_index"], dim=(1, 2))
-        )
+        return torch.all(self.tensor["node_feature"] == self.sf["node_feature"], dim=-1).view(self.batch_shape)
+
+
+def stack_states(states: List[States]):
+    """Given a list of states, stacks them along a new dimension (0)."""
+    state_example = states[0]  # We assume all elems of `states` are the same.
+
+    stacked_states = state_example.from_batch_shape((0, 0))  # Empty.
+    stacked_states.tensor = torch.stack([s.tensor for s in states], dim=0)
+    if state_example._log_rewards:
+        stacked_states._log_rewards = torch.stack(
+            [s._log_rewards for s in states], dim=0
+        )
+
+    # We are dealing with a list of DiscretrStates instances.
+ if hasattr(state_example, "forward_masks"): + stacked_states.forward_masks = torch.stack( + [s.forward_masks for s in states], dim=0 + ) + stacked_states.backward_masks = torch.stack( + [s.backward_masks for s in states], dim=0 + ) + + # Adds the trajectory dimension. + stacked_states.batch_shape = ( + stacked_states.tensor.shape[0], + ) + state_example.batch_shape + + return stacked_states \ No newline at end of file diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py index e3ffef3f..90bdfa49 100644 --- a/testing/test_samplers_and_trajectories.py +++ b/testing/test_samplers_and_trajectories.py @@ -235,9 +235,9 @@ def __init__(self, feature_dim: int): def forward(self, states: GraphStates) -> TensorDict: node_feature = states.tensor["node_feature"].reshape(-1, self.feature_dim) - edge_index = states.tensor["edge_index"].reshape(-1, 2).T + edge_index = states.tensor["edge_index"].T - if states.tensor["node_feature"].shape[1] == 0: + if states.tensor["node_feature"].shape[0] == 0: action_type = torch.zeros((len(states), len(GraphActionType))) action_type[:, GraphActionType.ADD_NODE] = 1 features = torch.zeros((len(states), self.feature_dim)) From dd80f2815237d7b64ec3885bc0139b063edcc122 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Wed, 11 Dec 2024 14:52:22 +0100 Subject: [PATCH 18/27] fix stacking --- src/gfn/actions.py | 13 ++++++++++ src/gfn/samplers.py | 13 ++++------ src/gfn/states.py | 58 ++++++++++++++++++++++----------------------- 3 files changed, 47 insertions(+), 37 deletions(-) diff --git a/src/gfn/actions.py b/src/gfn/actions.py index 36c52ae5..f31ac032 100644 --- a/src/gfn/actions.py +++ b/src/gfn/actions.py @@ -280,3 +280,16 @@ def make_dummy_actions( #features=torch.zeros((*batch_shape, 0, cls.nodes_features_dim)), #edge_index=torch.zeros((2, *batch_shape, 0)), ) + + @classmethod + def stack(cls, actions_list: list[GraphActions]) -> GraphActions: + """Stacks a list of GraphActions objects into a single GraphActions object.""" + actions_tensor = torch.stack( + [actions.tensor for actions in actions_list], dim=0 + ) + return cls( + actions_tensor["action_type"], + actions_tensor["features"], + actions_tensor["edge_index"] + ) + diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 3085b697..e056de12 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -155,8 +155,8 @@ def sample_trajectories( else states.is_sink_state ) - trajectories_states: States = deepcopy(states) - trajectories_actions: Optional[Actions] = None + trajectories_states: List[States] = [deepcopy(states)] + trajectories_actions: List[Actions] = [] trajectories_logprobs: List[torch.Tensor] = [] trajectories_dones = torch.zeros( n_trajectories, dtype=torch.long, device=device @@ -205,10 +205,7 @@ def sample_trajectories( # When off_policy, actions_log_probs are None. 
log_probs[~dones] = actions_log_probs - if trajectories_actions is None: - trajectories_actions = actions - else: - trajectories_actions.extend(actions) + trajectories_actions.append(actions) trajectories_logprobs.append(log_probs) if self.estimator.is_backward: @@ -241,8 +238,8 @@ def sample_trajectories( states = new_states dones = dones | new_dones - trajectories_states.extend(deepcopy(states)) - + trajectories_states = env.States.stack(trajectories_states) + trajectories_actions = env.Actions.stack(trajectories_actions) trajectories_logprobs = ( torch.stack(trajectories_logprobs, dim=0) if save_logprobs else None ) diff --git a/src/gfn/states.py b/src/gfn/states.py index f91bec88..07969891 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -293,6 +293,34 @@ def log_rewards(self, log_rewards: torch.Tensor) -> None: def sample(self, n_samples: int) -> States: """Samples a subset of the States object.""" return self[torch.randperm(len(self))[:n_samples]] + + @classmethod + def stack(cls, states: List[States]): + """Given a list of states, stacks them along a new dimension (0).""" + state_example = states[0] # We assume all elems of `states` are the same. + + stacked_states = state_example.from_batch_shape((0, 0)) # Empty. + stacked_states.tensor = torch.stack([s.tensor for s in states], dim=0) + if state_example._log_rewards: + stacked_states._log_rewards = torch.stack( + [s._log_rewards for s in states], dim=0 + ) + + # We are dealing with a list of DiscretrStates instances. + if hasattr(state_example, "forward_masks"): + stacked_states.forward_masks = torch.stack( + [s.forward_masks for s in states], dim=0 + ) + stacked_states.backward_masks = torch.stack( + [s.backward_masks for s in states], dim=0 + ) + + # Adds the trajectory dimension. + stacked_states.batch_shape = ( + stacked_states.tensor.shape[0], + ) + state_example.batch_shape + + return stacked_states class DiscreteStates(States, ABC): @@ -480,7 +508,7 @@ def init_forward_masks(self, set_ones: bool = True): self.forward_masks = torch.zeros(shape).bool() -class GraphStates(ABC): +class GraphStates(States): """ Base class for Graph as a state representation. The `GraphStates` object is a batched collection of multiple graph objects. The `Batch` object from PyTorch Geometric is used to represent the batch of @@ -724,31 +752,3 @@ def is_sink_state(self) -> torch.Tensor: if len(self.tensor["node_feature"]) != np.prod(self.batch_shape): return torch.zeros(self.batch_shape, dtype=torch.bool) return torch.all(self.tensor["node_feature"] == self.sf["node_feature"], dim=-1).view(self.batch_shape) - - -def stack_states(states: List[States]): - """Given a list of states, stacks them along a new dimension (0).""" - state_example = states[0] # We assume all elems of `states` are the same. - - stacked_states = state_example.from_batch_shape((0, 0)) # Empty. - stacked_states.tensor = torch.stack([s.tensor for s in states], dim=0) - if state_example._log_rewards: - stacked_states._log_rewards = torch.stack( - [s._log_rewards for s in states], dim=0 - ) - - # We are dealing with a list of DiscretrStates instances. - if hasattr(state_example, "forward_masks"): - stacked_states.forward_masks = torch.stack( - [s.forward_masks for s in states], dim=0 - ) - stacked_states.backward_masks = torch.stack( - [s.backward_masks for s in states], dim=0 - ) - - # Adds the trajectory dimension. 
- stacked_states.batch_shape = ( - stacked_states.tensor.shape[0], - ) + state_example.batch_shape - - return stacked_states \ No newline at end of file From 616551c46f56e2b53334ab3f169432d976d4f317 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Wed, 11 Dec 2024 17:17:00 +0100 Subject: [PATCH 19/27] fix graph stacking --- src/gfn/env.py | 2 +- src/gfn/samplers.py | 9 ++++++--- src/gfn/states.py | 34 ++++++++++++++++++++++++++-------- 3 files changed, 33 insertions(+), 12 deletions(-) diff --git a/src/gfn/env.py b/src/gfn/env.py index e09c5b79..25026dd7 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -256,7 +256,7 @@ def _step( ) new_sink_states_idx = actions.is_exit - sf_tensor = self.States.make_sink_states_tensor(new_sink_states_idx.sum()) + sf_tensor = self.States.make_sink_states_tensor((new_sink_states_idx.sum(),)) new_states[new_sink_states_idx] = self.States(sf_tensor) new_sink_states_idx = ~valid_states_idx | new_sink_states_idx assert new_sink_states_idx.shape == states.batch_shape diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index e056de12..d0f580fd 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -193,9 +193,11 @@ def sample_trajectories( if estimator_outputs is not None: # Place estimator outputs into a stackable tensor. Note that this # will be replaced with torch.nested.nested_tensor in the future. - estimator_outputs_padded = torch.full_like( - estimator_outputs.expand((n_trajectories,) + estimator_outputs.shape[1:]).clone(), - fill_value=-float("inf") + estimator_outputs_padded = torch.full( + (n_trajectories,) + estimator_outputs.shape[1:], + fill_value=-float("inf"), + dtype=torch.float, + device=device, ) estimator_outputs_padded[~dones] = estimator_outputs all_estimator_outputs.append(estimator_outputs_padded) @@ -237,6 +239,7 @@ def sample_trajectories( ) states = new_states dones = dones | new_dones + trajectories_states.append(deepcopy(states)) trajectories_states = env.States.stack(trajectories_states) trajectories_actions = env.Actions.stack(trajectories_actions) diff --git a/src/gfn/states.py b/src/gfn/states.py index 07969891..408a41d8 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -523,9 +523,7 @@ def __init__(self, tensor: TensorDict): self.node_features_dim = tensor["node_feature"].shape[-1] self.edge_features_dim = tensor["edge_feature"].shape[-1] - self.batch_shape: tuple = tuple(tensor["batch_shape"].tolist()) self._log_rewards: float = None - # TODO logic repeated from env.is_valid_action not_empty = self.tensor["batch_ptr"][:-1] + 1 < self.tensor["batch_ptr"][1:] self.forward_masks = torch.ones((np.prod(self.batch_shape), 3), dtype=torch.bool) @@ -538,6 +536,10 @@ def __init__(self, tensor: TensorDict): self.backward_masks[..., GraphActionType.ADD_EDGE] = not_empty # TODO: check at least one edge is present self.backward_masks[..., GraphActionType.EXIT] = not_empty self.backward_masks = self.backward_masks.view(*self.batch_shape, 3) + + @property + def batch_shape(self) -> tuple: + return tuple(self.tensor["batch_shape"].tolist()) @classmethod def from_batch_shape( @@ -667,10 +669,6 @@ def __setitem__(self, index: int | Sequence[int], graph: GraphStates): if torch.any(index >= len(self.tensor['batch_ptr']) - 1): raise ValueError("Target graph index out of bounds") - # Get batch pointers for target and source - target_start_ptrs = self.tensor['batch_ptr'][:-1][index] - target_end_ptrs = self.tensor['batch_ptr'][1:][index] - # Source graph details source_tensor_dict = graph.tensor source_num_graphs = 
torch.prod(source_tensor_dict['batch_shape']) @@ -736,8 +734,10 @@ def extend(self, other: GraphStates): """Concatenates to another GraphStates object along the batch dimension""" self.tensor["node_feature"] = torch.cat([self.tensor["node_feature"], other.tensor["node_feature"]], dim=0) self.tensor["edge_feature"] = torch.cat([self.tensor["edge_feature"], other.tensor["edge_feature"]], dim=0) - self.tensor["edge_index"] = torch.cat([self.tensor["edge_index"], other.tensor["edge_index"]], dim=0) - + self.tensor["edge_index"] = torch.cat([self.tensor["edge_index"], other.tensor["edge_index"] + self.tensor["batch_ptr"][-1]], dim=0) + self.tensor["batch_ptr"] = torch.cat([self.tensor["batch_ptr"], other.tensor["batch_ptr"][1:] + self.tensor["batch_ptr"][-1]], dim=0) + assert torch.all(self.tensor["batch_shape"][1:] == other.tensor["batch_shape"][1:]) + self.tensor["batch_shape"] = (self.tensor["batch_shape"][0] + other.tensor["batch_shape"][0],) + self.batch_shape[1:] @property def log_rewards(self) -> torch.Tensor: @@ -752,3 +752,21 @@ def is_sink_state(self) -> torch.Tensor: if len(self.tensor["node_feature"]) != np.prod(self.batch_shape): return torch.zeros(self.batch_shape, dtype=torch.bool) return torch.all(self.tensor["node_feature"] == self.sf["node_feature"], dim=-1).view(self.batch_shape) + + @classmethod + def stack(cls, states: List[GraphStates]): + """Given a list of states, stacks them along a new dimension (0).""" + stacked_states = cls.from_batch_shape(0) + state_batch_shape = states[0].batch_shape + for state in states: + assert state.batch_shape == state_batch_shape + stacked_states.extend(state) + + stacked_states.forward_masks = torch.stack( + [s.forward_masks for s in states], dim=0 + ) + stacked_states.backward_masks = torch.stack( + [s.backward_masks for s in states], dim=0 + ) + stacked_states.tensor["batch_shape"] = (len(states),) + state_batch_shape + return stacked_states From 77611d41dccf4317427e9a86dbcb80808c608ee7 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Thu, 12 Dec 2024 12:57:19 +0100 Subject: [PATCH 20/27] fix test graph env --- src/gfn/env.py | 4 +- src/gfn/gym/graph_building.py | 49 ++++++++++-------------- src/gfn/states.py | 8 ++-- testing/test_environments.py | 71 +++++++++++++++++------------------ 4 files changed, 60 insertions(+), 72 deletions(-) diff --git a/src/gfn/env.py b/src/gfn/env.py index 25026dd7..3c543d51 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -630,11 +630,11 @@ def actions_from_tensor(self, tensor: Dict[str, torch.Tensor]): return self.Actions(**tensor) @abstractmethod - def step(self, states: GraphStates, actions: Actions) -> GraphStates: + def step(self, states: GraphStates, actions: Actions) -> torch.Tensor: """Function that takes a batch of graph states and actions and returns a batch of next graph states.""" @abstractmethod - def backward_step(self, states: GraphStates, actions: Actions) -> GraphStates: + def backward_step(self, states: GraphStates, actions: Actions) -> torch.Tensor: """Function that takes a batch of graph states and actions and returns a batch of previous graph states.""" diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index 67d3d316..8a2936d3 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -63,14 +63,12 @@ def step(self, states: GraphStates, actions: GraphActions) -> TensorDict: state_tensor = self._add_node(state_tensor, batch_indices, actions.features) if action_type == GraphActionType.ADD_EDGE: - assert len(state_tensor) == len(actions) 
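+            # edge_index rows are global positions into the flat node_feature
+            # tensor, so new edges are simply appended along dim 0.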
state_tensor["edge_feature"] = torch.cat([state_tensor["edge_feature"], actions.features], dim=0) - state_tensor["edge_index"] = torch.cat( - [state_tensor["edge_index"], actions.edge_index], dim=0 - ) + state_tensor["edge_index"] = torch.cat([state_tensor["edge_index"], actions.edge_index], dim=0) + return state_tensor - def backward_step(self, states: GraphStates, actions: GraphActions) -> GraphStates: + def backward_step(self, states: GraphStates, actions: GraphActions) -> torch.Tensor: """Backward step function for the GraphBuilding environment. Args: @@ -81,25 +79,26 @@ def backward_step(self, states: GraphStates, actions: GraphActions) -> GraphStat """ if not self.is_action_valid(states, actions, backward=True): raise NonValidActionsError("Invalid action.") - graphs: Batch = deepcopy(states.data) - assert len(graphs) == len(actions) + state_tensor = deepcopy(states.tensor) - if actions.action_type == GraphActionType.ADD_NODE: - assert graphs.x is not None + action_type = actions.action_type[0] + assert torch.all(actions.action_type == action_type) + if action_type == GraphActionType.ADD_NODE: is_equal = torch.any( - torch.all(graphs.x[:, None] == actions.features, dim=-1), dim=-1 + torch.all(state_tensor["node_feature"][:, None] == actions.features, dim=-1), + dim=-1 ) - graphs.x = graphs.x[~is_equal] - elif actions.action_type == GraphActionType.ADD_EDGE: + state_tensor["node_feature"] = state_tensor["node_feature"][~is_equal] + elif action_type == GraphActionType.ADD_EDGE: assert actions.edge_index is not None is_equal = torch.all( - graphs.edge_index[:, None] == actions.edge_index[:, :, None], dim=0 + state_tensor["edge_index"] == actions.edge_index[:, None], dim=-1 ) is_equal = torch.any(is_equal, dim=0) - graphs.edge_attr = graphs.edge_attr[~is_equal] - graphs.edge_index = graphs.edge_index[:, ~is_equal] + state_tensor["edge_feature"] = state_tensor["edge_feature"][~is_equal] + state_tensor["edge_index"] = state_tensor["edge_index"][~is_equal] - return self.States(graphs) + return state_tensor def is_action_valid( self, states: GraphStates, actions: GraphActions, backward: bool = False @@ -111,8 +110,8 @@ def is_action_valid( node_feature = states[add_node_mask].tensor["node_feature"] equal_nodes_per_batch = torch.all( node_feature == actions[add_node_mask].features[:, None], dim=-1 - ).reshape(-1) - if backward: # TODO: check if no edge are connected? + ).sum(dim=-1) + if backward: # TODO: check if no edge is connected? 
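                # A backward ADD_NODE removes a node, so it must match exactly
                # one existing node; forward, the node must be new (no match).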
add_node_out = torch.all(equal_nodes_per_batch == 1) else: add_node_out = torch.all(equal_nodes_per_batch == 0) @@ -131,18 +130,13 @@ def is_action_valid( if torch.any(add_edge_actions.edge_index > add_edge_states["node_feature"].shape[0]): return False - batch_dim = add_edge_actions.features.shape[0] - batch_idx = add_edge_actions.edge_index % batch_dim - if torch.any(batch_idx != torch.arange(batch_dim)): - return False - equal_edges_per_batch_attr = torch.all( add_edge_states["edge_feature"] == add_edge_actions.features[:, None], dim=-1 - ).reshape(len(add_edge_states), -1) + ) equal_edges_per_batch_attr = torch.sum(equal_edges_per_batch_attr, dim=-1) equal_edges_per_batch_index = torch.all( - add_edge_states["edge_index"] == add_edge_actions.edge_index, dim=0 - ).reshape(len(add_edge_states), -1) + add_edge_states["edge_index"] == add_edge_actions.edge_index[:, None], dim=-1 + ) equal_edges_per_batch_index = torch.sum(equal_edges_per_batch_index, dim=-1) if backward: @@ -164,12 +158,10 @@ def _add_node(self, tensor_dict: TensorDict, batch_indices: torch.Tensor, nodes_ modified_dict = tensor_dict.clone() node_feature_dim = modified_dict['node_feature'].shape[1] - edge_feature_dim = modified_dict['edge_feature'].shape[1] for graph_idx, new_nodes in zip(batch_indices, nodes_to_add): start_ptr = tensor_dict['batch_ptr'][graph_idx] end_ptr = tensor_dict['batch_ptr'][graph_idx + 1] - num_original_nodes = end_ptr - start_ptr if new_nodes.ndim == 1: new_nodes = new_nodes.unsqueeze(0) @@ -181,7 +173,6 @@ def _add_node(self, tensor_dict: TensorDict, batch_indices: torch.Tensor, nodes_ modified_dict['batch_ptr'][graph_idx + 1:] += shift # Expand node features - original_nodes = modified_dict['node_feature'][start_ptr:end_ptr] modified_dict['node_feature'] = torch.cat([ modified_dict['node_feature'][:end_ptr], new_nodes, diff --git a/src/gfn/states.py b/src/gfn/states.py index 408a41d8..cccab384 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -628,7 +628,6 @@ def __getitem__( for start, end in zip(start_ptrs, end_ptrs): graph_nodes = self.tensor['node_feature'][start:end] node_features.append(graph_nodes) - batch_ptr.append(batch_ptr[-1] + len(graph_nodes)) # Find edges for this graph edge_mask = ((self.tensor['edge_index'][:, 0] >= start) & @@ -638,10 +637,11 @@ def __getitem__( # Adjust edge indices to be local to this graph graph_edge_index = self.tensor['edge_index'][edge_mask] - graph_edge_index[:, 0] -= start - graph_edge_index[:, 1] -= start + graph_edge_index[:, 0] -= (batch_ptr[-1] - start) + graph_edge_index[:, 1] -= (batch_ptr[-1] - start) edge_indices.append(graph_edge_index) - + batch_ptr.append(batch_ptr[-1] + len(graph_nodes)) + out = self.__class__(TensorDict({ 'node_feature': torch.cat(node_features), 'edge_feature': torch.cat(edge_features), diff --git a/testing/test_environments.py b/testing/test_environments.py index 948742bf..a061014d 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -327,64 +327,59 @@ def test_graph_env(): env = GraphBuilding(feature_dim=FEATURE_DIM) states = env.reset(batch_shape=BATCH_SIZE) - assert states.batch_shape == BATCH_SIZE + assert states.batch_shape == (BATCH_SIZE,) action_cls = env.make_actions_class() with pytest.raises(NonValidActionsError): actions = action_cls( - GraphActionType.ADD_EDGE, + torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), torch.rand((BATCH_SIZE, FEATURE_DIM)), - torch.randint(0, 10, (2, BATCH_SIZE), dtype=torch.long), + torch.randint(0, 10, (BATCH_SIZE, 2), dtype=torch.long), ) 
states = env.step(states, actions) for _ in range(NUM_NODES): actions = action_cls( - GraphActionType.ADD_NODE, + torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), torch.rand((BATCH_SIZE, FEATURE_DIM)), ) states = env.step(states, actions) + states = env.States(states) - assert states.data.x.shape == (BATCH_SIZE * NUM_NODES, FEATURE_DIM) + assert states.tensor["node_feature"].shape == (BATCH_SIZE * NUM_NODES, FEATURE_DIM) with pytest.raises(NonValidActionsError): - first_node_mask = torch.arange(len(states.data.x)) // BATCH_SIZE == 0 + first_node_mask = torch.arange(len(states.tensor["node_feature"])) // BATCH_SIZE == 0 actions = action_cls( - GraphActionType.ADD_NODE, - states.data.x[first_node_mask], + torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), + states.tensor["node_feature"][first_node_mask], ) states = env.step(states, actions) with pytest.raises(NonValidActionsError): edge_index = torch.randint(0, 3, (BATCH_SIZE,), dtype=torch.long) actions = action_cls( - GraphActionType.ADD_EDGE, + torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), torch.rand((BATCH_SIZE, FEATURE_DIM)), - torch.stack([edge_index, edge_index]), + torch.stack([edge_index, edge_index], dim=1), ) states = env.step(states, actions) for i in range(NUM_NODES - 1): - node_is = torch.arange(i * BATCH_SIZE, (i + 1) * BATCH_SIZE) - node_js = torch.arange((i + 1) * BATCH_SIZE, (i + 2) * BATCH_SIZE) + node_is = states.tensor["batch_ptr"][:-1] + i + node_js = states.tensor["batch_ptr"][:-1] + i + 1 actions = action_cls( - GraphActionType.ADD_EDGE, + torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), torch.rand((BATCH_SIZE, FEATURE_DIM)), - torch.stack([node_is, node_js]), + torch.stack([node_is, node_js], dim=1), ) states = env.step(states, actions) + states = env.States(states) - with pytest.raises(NonValidActionsError): - edge_index = torch.tensor([[0, 1]] * BATCH_SIZE) - actions = action_cls( - GraphActionType.ADD_EDGE, - torch.rand((BATCH_SIZE, FEATURE_DIM)), - edge_index.T, - ) - states = env.step(states, actions) - - actions = action_cls(GraphActionType.EXIT) - states = env.step(states, actions) + actions = action_cls(torch.full((BATCH_SIZE,), GraphActionType.EXIT)) + sf_states = env.step(states, actions) + sf_states = env.States(sf_states) + assert torch.all(sf_states.is_sink_state) env.reward(states) # with pytest.raises(NonValidActionsError): @@ -395,37 +390,39 @@ def test_graph_env(): # ) # states = env.backward_step(states, actions) - num_edges_per_batch = states.data.edge_attr.shape[0] // BATCH_SIZE + num_edges_per_batch = len(states.tensor["edge_feature"]) // BATCH_SIZE for i in reversed(range(num_edges_per_batch)): edge_idx = torch.arange(i * BATCH_SIZE, (i + 1) * BATCH_SIZE) actions = action_cls( - GraphActionType.ADD_EDGE, - states.data.edge_attr[edge_idx], - states.data.edge_index[:, edge_idx], + torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), + states.tensor["edge_feature"][edge_idx], + states.tensor["edge_index"][edge_idx], ) states = env.backward_step(states, actions) + states = env.States(states) with pytest.raises(NonValidActionsError): actions = action_cls( - GraphActionType.ADD_EDGE, + torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), torch.rand((BATCH_SIZE, FEATURE_DIM)), - torch.randint(0, 10, (2, BATCH_SIZE), dtype=torch.long), + torch.randint(0, 10, (BATCH_SIZE, 2), dtype=torch.long), ) states = env.backward_step(states, actions) - for i in reversed(range(NUM_NODES)): - edge_idx = torch.arange(i * BATCH_SIZE, (i + 1) * BATCH_SIZE) + for i in reversed(range(1, NUM_NODES + 1)): + 
edge_idx = torch.arange(BATCH_SIZE) * i actions = action_cls( - GraphActionType.ADD_NODE, - states.data.x[edge_idx], + torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), + states.tensor["node_feature"][edge_idx], ) states = env.backward_step(states, actions) + states = env.States(states) - assert states.data.x.shape == (0, FEATURE_DIM) + assert states.tensor["node_feature"].shape == (0, FEATURE_DIM) with pytest.raises(NonValidActionsError): actions = action_cls( - GraphActionType.ADD_NODE, + torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), torch.rand((BATCH_SIZE, FEATURE_DIM)), ) states = env.backward_step(states, actions) From 5874ff6443f0496b0ef13ece7355d5f8adfe06d9 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Fri, 20 Dec 2024 00:56:55 +0100 Subject: [PATCH 21/27] add ring example --- src/gfn/actions.py | 39 ++---- src/gfn/env.py | 12 +- src/gfn/gflownet/flow_matching.py | 16 +-- src/gfn/gym/__init__.py | 1 + src/gfn/modules.py | 3 +- src/gfn/states.py | 28 ++-- testing/test_environments.py | 87 ++++++------ testing/test_samplers_and_trajectories.py | 4 - tutorials/examples/test_graph_ring.py | 159 ++++++++++++++++++++++ 9 files changed, 249 insertions(+), 100 deletions(-) create mode 100644 tutorials/examples/test_graph_ring.py diff --git a/src/gfn/actions.py b/src/gfn/actions.py index f31ac032..d2e9b3b0 100644 --- a/src/gfn/actions.py +++ b/src/gfn/actions.py @@ -181,12 +181,7 @@ class GraphActionType(enum.IntEnum): class GraphActions(Actions): features_dim: ClassVar[int] - def __init__( - self, - action_type: torch.Tensor, - features: Optional[torch.Tensor] = None, - edge_index: Optional[torch.Tensor] = None, - ): + def __init__(self, tensor: TensorDict): """Initializes a GraphAction object. Args: @@ -196,16 +191,18 @@ def __init__( edge_index: an tensor of shape (batch_shape, 2) representing the edge to add. This must defined if and only if the action type is GraphActionType.AddEdge. 
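
        These fields are passed together as entries of the single TensorDict
        `tensor`, batched over batch_shape.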
""" - self.batch_shape = action_type.shape + self.batch_shape = tensor["action_type"].shape + features = tensor.get("features", None) if features is None: - assert torch.all(action_type == GraphActionType.EXIT) + assert torch.all(tensor["action_type"] == GraphActionType.EXIT) features = torch.zeros((*self.batch_shape, self.features_dim)) + edge_index = tensor.get("edge_index", None) if edge_index is None: - assert torch.all(action_type != GraphActionType.ADD_EDGE) + assert torch.all(tensor["action_type"] != GraphActionType.ADD_EDGE) edge_index = torch.zeros((*self.batch_shape, 2), dtype=torch.long) self.tensor = TensorDict({ - "action_type": action_type, + "action_type": tensor["action_type"], "features": features, "edge_index": edge_index, }, batch_size=self.batch_shape) @@ -224,12 +221,8 @@ def __len__(self) -> int: def __getitem__(self, index: int | Sequence[int] | Sequence[bool]) -> GraphActions: """Get particular actions of the batch.""" - tensor = self.tensor[index] - return GraphActions( - tensor["action_type"], - tensor["features"], - tensor["edge_index"] - ) + return GraphActions(self.tensor[index]) + def __setitem__( self, index: int | Sequence[int] | Sequence[bool], action: GraphActions @@ -276,9 +269,11 @@ def make_dummy_actions( ) -> GraphActions: """Creates an Actions object of dummy actions with the given batch shape.""" return cls( - action_type=torch.full(batch_shape, fill_value=GraphActionType.EXIT), - #features=torch.zeros((*batch_shape, 0, cls.nodes_features_dim)), - #edge_index=torch.zeros((2, *batch_shape, 0)), + TensorDict({ + "action_type": torch.full(batch_shape, fill_value=GraphActionType.EXIT), + # "features": torch.zeros((*batch_shape, 0, cls.nodes_features_dim)), + # "edge_index": torch.zeros((2, *batch_shape, 0)), + }, batch_size=batch_shape) ) @classmethod @@ -287,9 +282,5 @@ def stack(cls, actions_list: list[GraphActions]) -> GraphActions: actions_tensor = torch.stack( [actions.tensor for actions in actions_list], dim=0 ) - return cls( - actions_tensor["action_type"], - actions_tensor["features"], - actions_tensor["edge_index"] - ) + return cls(actions_tensor) diff --git a/src/gfn/env.py b/src/gfn/env.py index 3c543d51..86c592b5 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -219,6 +219,7 @@ def reset( batch_shape = (1,) if isinstance(batch_shape, int): batch_shape = (batch_shape,) + return self.States.from_batch_shape( batch_shape=batch_shape, random=random, sink=sink ) @@ -618,17 +619,6 @@ class DefaultGraphAction(GraphActions): return DefaultGraphAction - def actions_from_tensor(self, tensor: Dict[str, torch.Tensor]): - """Wraps the supplied Tensor in an Actions instance. - - Args: - tensor: The tensor of shape "action_shape" representing the actions. - - Returns: - Actions: An instance of Actions. - """ - return self.Actions(**tensor) - @abstractmethod def step(self, states: GraphStates, actions: Actions) -> torch.Tensor: """Function that takes a batch of graph states and actions and returns a batch of next diff --git a/src/gfn/gflownet/flow_matching.py b/src/gfn/gflownet/flow_matching.py index 38072080..8347b835 100644 --- a/src/gfn/gflownet/flow_matching.py +++ b/src/gfn/gflownet/flow_matching.py @@ -33,10 +33,10 @@ class FMGFlowNet(GFlowNet[Tuple[DiscreteStates, DiscreteStates]]): def __init__(self, logF: DiscretePolicyEstimator, alpha: float = 1.0): super().__init__() - assert isinstance( # TODO: need a more flexible type check. 
- logF, - DiscretePolicyEstimator | ConditionalDiscretePolicyEstimator, - ), "logF must be a DiscretePolicyEstimator or ConditionalDiscretePolicyEstimator" + # assert isinstance( # TODO: need a more flexible type check. + # logF, + # DiscretePolicyEstimator | ConditionalDiscretePolicyEstimator, + # ), "logF must be a DiscretePolicyEstimator or ConditionalDiscretePolicyEstimator" self.logF = logF self.alpha = alpha @@ -50,10 +50,10 @@ def sample_trajectories( **policy_kwargs: Any, ) -> Trajectories: """Sample trajectory with optional kwargs controling the policy.""" - if not env.is_discrete: - raise NotImplementedError( - "Flow Matching GFlowNet only supports discrete environments for now." - ) + # if not env.is_discrete: + # raise NotImplementedError( + # "Flow Matching GFlowNet only supports discrete environments for now." + # ) sampler = Sampler(estimator=self.logF) trajectories = sampler.sample_trajectories( env, diff --git a/src/gfn/gym/__init__.py b/src/gfn/gym/__init__.py index fbec4831..20490566 100644 --- a/src/gfn/gym/__init__.py +++ b/src/gfn/gym/__init__.py @@ -1,3 +1,4 @@ from gfn.gym.box import Box from gfn.gym.discrete_ebm import DiscreteEBM from gfn.gym.hypergrid import HyperGrid +from gfn.gym.graph_building import GraphBuilding \ No newline at end of file diff --git a/src/gfn/modules.py b/src/gfn/modules.py index 86a345d4..169a1f57 100644 --- a/src/gfn/modules.py +++ b/src/gfn/modules.py @@ -79,7 +79,6 @@ def __init__( ) preprocessor = IdentityPreprocessor(module.input_dim) self.preprocessor = preprocessor - self._output_dim_is_checked = False self.is_backward = is_backward def forward(self, input: States | torch.Tensor) -> torch.Tensor: @@ -236,7 +235,7 @@ def forward(self, states: DiscreteStates) -> torch.Tensor: Returns the output of the module, as a tensor of shape (*batch_shape, output_dim). 
""" out = super().forward(states) - assert out.shape[-1] == self.expected_output_dim + assert out.shape[-1] == self.expected_output_dim, f"Expected output dim: {self.expected_output_dim}, got: {out.shape[-1]}" return out def to_probability_distribution( diff --git a/src/gfn/states.py b/src/gfn/states.py index cccab384..215d49b0 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -597,8 +597,8 @@ def make_random_states_tensor(cls, batch_shape: int | Tuple) -> TensorDict: "batch_shape": batch_shape }) - def __len__(self): - return np.prod(self.batch_shape) + def __len__(self) -> int: + return int(np.prod(self.batch_shape)) def __repr__(self): return ( @@ -609,10 +609,8 @@ def __repr__(self): def __getitem__( self, index: int | Sequence[int] | slice | torch.Tensor ) -> GraphStates: - if isinstance(index, (int, list)): - index = torch.tensor(index) - if index.dtype == torch.bool: - index = torch.where(index)[0] + tensor_idx = torch.arange(len(self)).view(*self.batch_shape) + index = tensor_idx[index].flatten() if torch.any(index >= len(self.tensor['batch_ptr']) - 1): raise ValueError("Graph index out of bounds") @@ -747,11 +745,23 @@ def log_rewards(self) -> torch.Tensor: def log_rewards(self, log_rewards: torch.Tensor) -> None: self._log_rewards = log_rewards + def _compare(self, other: TensorDict) -> torch.Tensor: + out = torch.zeros(len(self.tensor["batch_ptr"]) - 1, dtype=torch.bool) + for i in range(len(self.tensor["batch_ptr"]) - 1): + start, end = self.tensor["batch_ptr"][i], self.tensor["batch_ptr"][i + 1] + if end - start != len(other["node_feature"]): + out[i] = False + else: + out[i] = torch.all(self.tensor["node_feature"][start:end] == other["node_feature"]) + return out.view(self.batch_shape) + @property def is_sink_state(self) -> torch.Tensor: - if len(self.tensor["node_feature"]) != np.prod(self.batch_shape): - return torch.zeros(self.batch_shape, dtype=torch.bool) - return torch.all(self.tensor["node_feature"] == self.sf["node_feature"], dim=-1).view(self.batch_shape) + return self._compare(self.sf) + + @property + def is_initial_state(self) -> torch.Tensor: + return self._compare(self.s0) @classmethod def stack(cls, states: List[GraphStates]): diff --git a/testing/test_environments.py b/testing/test_environments.py index a061014d..a2919d50 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -1,6 +1,7 @@ import numpy as np import pytest import torch +from tensordict import TensorDict from gfn.actions import GraphActionType from gfn.env import NonValidActionsError @@ -331,18 +332,18 @@ def test_graph_env(): action_cls = env.make_actions_class() with pytest.raises(NonValidActionsError): - actions = action_cls( - torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), - torch.rand((BATCH_SIZE, FEATURE_DIM)), - torch.randint(0, 10, (BATCH_SIZE, 2), dtype=torch.long), - ) + actions = action_cls(TensorDict({ + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), + "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), + "edge_index": torch.randint(0, 10, (BATCH_SIZE, 2), dtype=torch.long), + }, batch_size=BATCH_SIZE)) states = env.step(states, actions) for _ in range(NUM_NODES): - actions = action_cls( - torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), - torch.rand((BATCH_SIZE, FEATURE_DIM)), - ) + actions = action_cls(TensorDict({ + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), + "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), + }, batch_size=BATCH_SIZE)) states = env.step(states, actions) states = 
@@ -350,33 +351,35 @@ def test_graph_env():

     with pytest.raises(NonValidActionsError):
         first_node_mask = torch.arange(len(states.tensor["node_feature"])) // BATCH_SIZE == 0
-        actions = action_cls(
-            torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE),
-            states.tensor["node_feature"][first_node_mask],
-        )
+        actions = action_cls(TensorDict({
+            "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE),
+            "features": states.tensor["node_feature"][first_node_mask],
+        }, batch_size=BATCH_SIZE))
         states = env.step(states, actions)

     with pytest.raises(NonValidActionsError):
         edge_index = torch.randint(0, 3, (BATCH_SIZE,), dtype=torch.long)
-        actions = action_cls(
-            torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE),
-            torch.rand((BATCH_SIZE, FEATURE_DIM)),
-            torch.stack([edge_index, edge_index], dim=1),
-        )
+        actions = action_cls(TensorDict({
+            "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE),
+            "features": torch.rand((BATCH_SIZE, FEATURE_DIM)),
+            "edge_index": torch.stack([edge_index, edge_index], dim=1),
+        }, batch_size=BATCH_SIZE))
         states = env.step(states, actions)

     for i in range(NUM_NODES - 1):
         node_is = states.tensor["batch_ptr"][:-1] + i
         node_js = states.tensor["batch_ptr"][:-1] + i + 1
-        actions = action_cls(
-            torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE),
-            torch.rand((BATCH_SIZE, FEATURE_DIM)),
-            torch.stack([node_is, node_js], dim=1),
-        )
+        actions = action_cls(TensorDict({
+            "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE),
+            "features": torch.rand((BATCH_SIZE, FEATURE_DIM)),
+            "edge_index": torch.stack([node_is, node_js], dim=1),
+        }, batch_size=BATCH_SIZE))
         states = env.step(states, actions)

     states = env.States(states)
-    actions = action_cls(torch.full((BATCH_SIZE,), GraphActionType.EXIT))
+    actions = action_cls(TensorDict({
+        "action_type": torch.full((BATCH_SIZE,), GraphActionType.EXIT),
+    }, batch_size=BATCH_SIZE))
     sf_states = env.step(states, actions)
     sf_states = env.States(sf_states)
     assert torch.all(sf_states.is_sink_state)
@@ -393,36 +396,36 @@ def test_graph_env():
     num_edges_per_batch = len(states.tensor["edge_feature"]) // BATCH_SIZE
     for i in reversed(range(num_edges_per_batch)):
         edge_idx = torch.arange(i * BATCH_SIZE, (i + 1) * BATCH_SIZE)
-        actions = action_cls(
-            torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE),
-            states.tensor["edge_feature"][edge_idx],
-            states.tensor["edge_index"][edge_idx],
-        )
+        actions = action_cls(TensorDict({
+            "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE),
+            "features": states.tensor["edge_feature"][edge_idx],
+            "edge_index": states.tensor["edge_index"][edge_idx],
+        }, batch_size=BATCH_SIZE))
         states = env.backward_step(states, actions)
     states = env.States(states)

     with pytest.raises(NonValidActionsError):
-        actions = action_cls(
-            torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE),
-            torch.rand((BATCH_SIZE, FEATURE_DIM)),
-            torch.randint(0, 10, (BATCH_SIZE, 2), dtype=torch.long),
-        )
+        actions = action_cls(TensorDict({
+            "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE),
+            "features": torch.rand((BATCH_SIZE, FEATURE_DIM)),
+            "edge_index": torch.randint(0, 10, (BATCH_SIZE, 2), dtype=torch.long),
+        }, batch_size=BATCH_SIZE))
         states = env.backward_step(states, actions)

     for i in reversed(range(1, NUM_NODES + 1)):
         edge_idx = torch.arange(BATCH_SIZE) * i
-        actions = action_cls(
-            torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE),
-            states.tensor["node_feature"][edge_idx],
-        )
+        actions = action_cls(TensorDict({
+            "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE),
+            "features": states.tensor["node_feature"][edge_idx],
+        }, batch_size=BATCH_SIZE))
         states = env.backward_step(states, actions)
     states = env.States(states)
     assert states.tensor["node_feature"].shape == (0, FEATURE_DIM)

     with pytest.raises(NonValidActionsError):
-        actions = action_cls(
-            torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE),
-            torch.rand((BATCH_SIZE, FEATURE_DIM)),
-        )
+        actions = action_cls(TensorDict({
+            "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE),
+            "features": torch.rand((BATCH_SIZE, FEATURE_DIM)),
+        }, batch_size=BATCH_SIZE))
         states = env.backward_step(states, actions)
diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py
index 90bdfa49..470c8b09 100644
--- a/testing/test_samplers_and_trajectories.py
+++ b/testing/test_samplers_and_trajectories.py
@@ -277,7 +277,3 @@ def test_graph_building():
         save_logprobs=True,
         save_estimator_outputs=False,
     )
-
-
-if __name__ == "__main__":
-    test_graph_building()
\ No newline at end of file
diff --git a/tutorials/examples/test_graph_ring.py b/tutorials/examples/test_graph_ring.py
new file mode 100644
index 00000000..8546ec41
--- /dev/null
+++ b/tutorials/examples/test_graph_ring.py
@@ -0,0 +1,159 @@
+"""Write an example where we want to create graphs that are rings."""
+
+import torch
+from torch import nn
+from gfn.actions import Actions, GraphActionType, GraphActions
+from gfn.gflownet.flow_matching import FMGFlowNet
+from gfn.gym import GraphBuilding
+from gfn.modules import DiscretePolicyEstimator
+from gfn.preprocessors import Preprocessor
+from gfn.states import GraphStates
+from tensordict import TensorDict
+from torch_geometric.nn import GCNConv
+
+
+def state_evaluator(states: GraphStates) -> torch.Tensor:
+    if states.tensor["edge_index"].shape[0] == 0:
+        return torch.zeros(states.batch_shape)
+    if states.tensor["edge_index"].shape[0] != states.tensor["node_feature"].shape[0]:
+        return torch.zeros(states.batch_shape)
+
+    i0 = torch.unique(states.tensor["edge_index"][0], sorted=False)
+    i1 = torch.unique(states.tensor["edge_index"][1], sorted=False)
+
+    if len(i0) == len(i1) == states.tensor["node_feature"].shape[0]:
+        return torch.ones(states.batch_shape)
+    return torch.zeros(states.batch_shape)
+
+
+class RingPolicyEstimator(nn.Module):
+    def __init__(self, n_nodes: int):
+        super().__init__()
+        self.action_type_conv = GCNConv(1, 1)
+        self.edge_index_conv = GCNConv(1, 8)
+        self.n_nodes = n_nodes
+
+    def _group_sum(self, tensor: torch.Tensor, batch_ptr: torch.Tensor) -> torch.Tensor:
+        cumsum = torch.zeros((len(tensor) + 1, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device)
+        cumsum[1:] = torch.cumsum(tensor, dim=0)
+        return cumsum[batch_ptr[1:]] - cumsum[batch_ptr[:-1]]
+
+    def forward(self, states_tensor: TensorDict) -> torch.Tensor:
+        node_feature = states_tensor["node_feature"].reshape(-1, 1)
+        edge_index = states_tensor["edge_index"].T
+        batch_ptr = states_tensor["batch_ptr"]
+
+        action_type = self.action_type_conv(node_feature, edge_index)
+        action_type = self._group_sum(action_type, batch_ptr)
+
+        edge_index = self.edge_index_conv(node_feature, edge_index)
+        #edge_index = self._group_sum(edge_index, batch_ptr)
+        edge_index = edge_index.reshape(*states_tensor["batch_shape"], -1, 8)
+        edge_index = torch.einsum("bnf,bmf->bnm", edge_index, edge_index)
+        torch.diagonal(edge_index, dim1=-2, dim2=-1).fill_(float("-inf"))
+        edge_actions = edge_index.reshape(*states_tensor["batch_shape"], -1)
+
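+        # `action_type` holds one pooled logit per graph (the 1-dim GCN output
+        # summed over that graph's nodes), while the einsum above scores every
+        # ordered node pair by a dot product of its 8-dim embeddings; the
+        # diagonal is set to -inf so self-loops are never proposed. Together
+        # they cover all 1 + n_nodes * n_nodes actions of the environment.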
+        return torch.cat([action_type, edge_actions], dim=-1)
+
+class RingGraphBuilding(GraphBuilding):
+    def __init__(self, nodes: int = 10):
+        self.nodes = nodes
+        self.n_actions = 1 + nodes * nodes
+        super().__init__(feature_dim=1, state_evaluator=state_evaluator)
+
+
+    def make_actions_class(self) -> type[Actions]:
+        env = self
+        class RingActions(Actions):
+            action_shape = (1,)
+            dummy_action = torch.tensor([env.n_actions])
+            exit_action = torch.zeros(1,)
+
+        return RingActions
+
+
+    def make_states_class(self) -> type[GraphStates]:
+        env = self
+
+        class RingStates(GraphStates):
+            s0 = TensorDict({
+                "node_feature": torch.zeros((env.nodes, 1)),
+                "edge_feature": torch.zeros((0, 1)),
+                "edge_index": torch.zeros((0, 2), dtype=torch.long),
+            }, batch_size=())
+            sf = TensorDict({
+                "node_feature": torch.ones((env.nodes, 1)),
+                "edge_feature": torch.zeros((0, 1)),
+                "edge_index": torch.zeros((0, 2), dtype=torch.long),
+            }, batch_size=())
+            n_actions = env.n_actions
+
+            def __init__(self, *args, **kwargs):
+                super().__init__(*args, **kwargs)
+
+                self.forward_masks = torch.ones(self.batch_shape + (self.n_actions,), dtype=torch.bool)
+                self.backward_masks = torch.ones(self.batch_shape + (self.n_actions,), dtype=torch.bool)
+        return RingStates
+
+    def _step(self, states: GraphStates, actions: Actions) -> GraphStates:
+        actions = self.convert_actions(actions)
+        return super()._step(states, actions)
+
+    def _backward_step(self, states: GraphStates, actions: Actions) -> GraphStates:
+        actions = self.convert_actions(actions)
+        return super()._backward_step(states, actions)
+
+    def convert_actions(self, actions: Actions) -> GraphActions:
+        action_tensor = actions.tensor.squeeze(-1)
+        action_type = torch.where(action_tensor == 0, GraphActionType.EXIT, GraphActionType.ADD_EDGE)
+        edge_index_i0 = (action_tensor - 1) // (self.nodes)
+        edge_index_i1 = (action_tensor - 1) % (self.nodes)
+        # edge_index_i1 = edge_index_i1 + (edge_index_i1 >= edge_index_i0)
+
+        edge_index = torch.stack([edge_index_i0, edge_index_i1], dim=-1)
+        return GraphActions(TensorDict({
+            "action_type": action_type,
+            "features": torch.ones(action_tensor.shape + (1,)),
+            "edge_index": edge_index,
+        }, batch_size=action_tensor.shape))
+
+
+class GraphPreprocessor(Preprocessor):
+
+    def __init__(self, feature_dim: int = 1):
+        super().__init__(output_dim=feature_dim)
+
+    def preprocess(self, states: GraphStates) -> TensorDict:
+        return states.tensor
+
+    def __call__(self, states: GraphStates) -> torch.Tensor:
+        return self.preprocess(states)
+
+
+if __name__ == "__main__":
+    torch.random.manual_seed(42)
+    env = RingGraphBuilding(nodes=10)
+    module = RingPolicyEstimator(env.nodes)
+
+    pf_estimator = DiscretePolicyEstimator(module=module, n_actions=env.n_actions, preprocessor=GraphPreprocessor())
+
+    gflownet = FMGFlowNet(pf_estimator)
+    optimizer = torch.optim.Adam(gflownet.parameters(), lr=1e-3)
+
+    visited_terminating_states = env.States.from_batch_shape((0,))
+    losses = []
+
+    for iteration in range(100):
+        print(f"Iteration {iteration}")
+        trajectories = gflownet.sample_trajectories(env, n=128)
+        samples = gflownet.to_training_samples(trajectories)
+        optimizer.zero_grad()
+        loss = gflownet.loss(env, samples)
+        loss.backward()
+        optimizer.step()
+
+        visited_terminating_states.extend(trajectories.last_states)
+        losses.append(loss.item())
+
+

From 9d42332b6946a01f0457a66a6255deb8455ee48a Mon Sep 17 00:00:00 2001
From: Omar Younis
Date: Fri, 20 Dec 2024 13:09:59 +0100
Subject: [PATCH 22/27] remove check edge_features

---
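After this patch, ADD_EDGE validity is judged on edge indices alone. A small
runnable sketch of the comparison that remains, where `state_edge_index`
stands in for `add_edge_states["edge_index"]` and `action_edge_index` for
`add_edge_actions.edge_index` (the example tensors are illustrative):

    import torch

    state_edge_index = torch.tensor([[0, 1], [1, 2]])  # existing edges, (E, 2)
    action_edge_index = torch.tensor([[2, 0]])         # proposed edges, (B, 2)

    matches = torch.all(state_edge_index == action_edge_index[:, None], dim=-1)
    counts = matches.sum(dim=-1)             # occurrences of each proposed edge
    valid_forward = torch.all(counts == 0)   # forward: the edge must not exist yet
    valid_backward = torch.all(counts == 1)  # backward: it must exist exactly once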
 src/gfn/gym/graph_building.py | 13 +++----------
 1 file changed, 3 insertions(+), 10 deletions(-)

diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py
index 8a2936d3..d1eb9af1 100644
--- a/src/gfn/gym/graph_building.py
+++ b/src/gfn/gym/graph_building.py
@@ -123,6 +123,7 @@ def is_action_valid(
             add_edge_states = states[add_edge_mask].tensor
             add_edge_actions = actions[add_edge_mask]

+            import pdb; pdb.set_trace()
             if torch.any(add_edge_actions.edge_index[:, 0] == add_edge_actions.edge_index[:, 1]):
                 return False
             if add_edge_states["node_feature"].shape[0] == 0:
@@ -130,23 +131,15 @@ def is_action_valid(
             if torch.any(add_edge_actions.edge_index > add_edge_states["node_feature"].shape[0]):
                 return False

-            equal_edges_per_batch_attr = torch.all(
-                add_edge_states["edge_feature"] == add_edge_actions.features[:, None], dim=-1
-            )
-            equal_edges_per_batch_attr = torch.sum(equal_edges_per_batch_attr, dim=-1)
             equal_edges_per_batch_index = torch.all(
                 add_edge_states["edge_index"] == add_edge_actions.edge_index[:, None], dim=-1
             )
             equal_edges_per_batch_index = torch.sum(equal_edges_per_batch_index, dim=-1)

             if backward:
-                add_edge_out = torch.all(equal_edges_per_batch_attr == 1) and torch.all(
-                    equal_edges_per_batch_index == 1
-                )
+                add_edge_out = torch.all(equal_edges_per_batch_index == 1)
             else:
-                add_edge_out = torch.all(equal_edges_per_batch_attr == 0) and torch.all(
-                    equal_edges_per_batch_index == 0
-                )
+                add_edge_out = torch.all(equal_edges_per_batch_index == 0)

         return bool(add_node_out) and bool(add_edge_out)

From 2d442426521ca58709e2afb9d63c6335aeee3852 Mon Sep 17 00:00:00 2001
From: Omar Younis
Date: Fri, 20 Dec 2024 13:10:17 +0100
Subject: [PATCH 23/27] fix GraphStates set

---
 src/gfn/states.py | 23 +++++++++++++----------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/src/gfn/states.py b/src/gfn/states.py
index 215d49b0..9be90bc4 100644
--- a/src/gfn/states.py
+++ b/src/gfn/states.py
@@ -658,10 +658,8 @@ def __setitem__(self, index: int | Sequence[int], graph: GraphStates):
         """
         Set particular states of the Batch
         """
-        if isinstance(index, (int, list)):
-            index = torch.tensor(index)
-        if index.dtype == torch.bool:
-            index = torch.where(index)[0]
+        tensor_idx = torch.arange(len(self)).view(*self.batch_shape)
+        index = tensor_idx[index].flatten()

         # Validate indices
         if torch.any(index >= len(self.tensor['batch_ptr']) - 1):
@@ -679,10 +677,10 @@ def __setitem__(self, index: int | Sequence[int], graph: GraphStates):
             # Get start and end pointers for the current graph
             start_ptr = self.tensor['batch_ptr'][graph_idx]
             end_ptr = self.tensor['batch_ptr'][graph_idx + 1]
+            source_start_ptr = source_tensor_dict['batch_ptr'][i]
+            source_end_ptr = source_tensor_dict['batch_ptr'][i + 1]

-            new_nodes = source_tensor_dict['node_feature'][
-                source_tensor_dict['batch_ptr'][i]:source_tensor_dict['batch_ptr'][i + 1]
-            ]
+            new_nodes = source_tensor_dict['node_feature'][source_start_ptr:source_end_ptr]

             # Ensure new nodes have correct feature dimension
             if new_nodes.ndim == 1:
@@ -706,13 +704,18 @@ def __setitem__(self, index: int | Sequence[int], graph: GraphStates):
             edge_mask_1 = self.tensor['edge_index'][:, 1] >= end_ptr
             self.tensor['edge_index'][edge_mask_0, 0] += shift
             self.tensor['edge_index'][edge_mask_1, 1] += shift
+            self.tensor['edge_index'] = torch.cat([
+                self.tensor['edge_index'],
+                source_tensor_dict['edge_index'] - source_start_ptr + start_ptr,
+            ], dim=0)
+            self.tensor['edge_feature'] = torch.cat([
+                self.tensor['edge_feature'],
+                source_tensor_dict['edge_feature'],
+            ], dim=0)

             # Update batch pointers
             self.tensor['batch_ptr'][graph_idx + 1:] += shift

-        # TODO: add new edges
-
-
     @property
     def device(self) -> torch.device:
         return self.tensor.device

From 173d4fb029eebda8ce8065a880c27ed3b3ffd885 Mon Sep 17 00:00:00 2001
From: Omar Younis
Date: Fri, 20 Dec 2024 13:12:50 +0100
Subject: [PATCH 24/27] remove debug

---
 src/gfn/gym/graph_building.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py
index d1eb9af1..1bbb4704 100644
--- a/src/gfn/gym/graph_building.py
+++ b/src/gfn/gym/graph_building.py
@@ -123,7 +123,6 @@ def is_action_valid(
             add_edge_states = states[add_edge_mask].tensor
             add_edge_actions = actions[add_edge_mask]

-            import pdb; pdb.set_trace()
             if torch.any(add_edge_actions.edge_index[:, 0] == add_edge_actions.edge_index[:, 1]):
                 return False
             if add_edge_states["node_feature"].shape[0] == 0:

From 7265857e21198dada0a281d4c0397f04f2c26d27 Mon Sep 17 00:00:00 2001
From: Omar Younis
Date: Fri, 20 Dec 2024 15:19:09 +0100
Subject: [PATCH 25/27] fix add_edge action

---
 src/gfn/gym/graph_building.py | 19 ++++++++++++-------
 src/gfn/states.py             |  7 +++++--
 2 files changed, 17 insertions(+), 9 deletions(-)

diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py
index 1bbb4704..cec3c6c6 100644
--- a/src/gfn/gym/graph_building.py
+++ b/src/gfn/gym/graph_building.py
@@ -64,7 +64,10 @@ def step(self, states: GraphStates, actions: GraphActions) -> TensorDict:

         if action_type == GraphActionType.ADD_EDGE:
             state_tensor["edge_feature"] = torch.cat([state_tensor["edge_feature"], actions.features], dim=0)
-            state_tensor["edge_index"] = torch.cat([state_tensor["edge_index"], actions.edge_index], dim=0)
+            state_tensor["edge_index"] = torch.cat([
+                state_tensor["edge_index"],
+                actions.edge_index + state_tensor["batch_ptr"][:-1][:, None]
+            ], dim=0)

         return state_tensor

@@ -91,8 +94,9 @@ def backward_step(self, states: GraphStates, actions: GraphActions) -> torch.Ten
             state_tensor["node_feature"] = state_tensor["node_feature"][~is_equal]
         elif action_type == GraphActionType.ADD_EDGE:
             assert actions.edge_index is not None
+            global_edge_index = actions.edge_index + state_tensor["batch_ptr"][:-1][:, None]
             is_equal = torch.all(
-                state_tensor["edge_index"] == actions.edge_index[:, None], dim=-1
+                state_tensor["edge_index"] == global_edge_index[:, None], dim=-1
             )
             is_equal = torch.any(is_equal, dim=0)
             state_tensor["edge_feature"] = state_tensor["edge_feature"][~is_equal]
@@ -130,15 +134,16 @@ def is_action_valid(
             if torch.any(add_edge_actions.edge_index > add_edge_states["node_feature"].shape[0]):
                 return False

-            equal_edges_per_batch_index = torch.all(
-                add_edge_states["edge_index"] == add_edge_actions.edge_index[:, None], dim=-1
+            global_edge_index = add_edge_actions.edge_index + add_edge_states["batch_ptr"][:-1][:, None]
+            equal_edges_per_batch = torch.all(
+                add_edge_states["edge_index"] == global_edge_index[:, None], dim=-1
             )
-            equal_edges_per_batch_index = torch.sum(equal_edges_per_batch_index, dim=-1)
+            equal_edges_per_batch = equal_edges_per_batch.sum(dim=-1)

             if backward:
-                add_edge_out = torch.all(equal_edges_per_batch_index == 1)
+                add_edge_out = torch.all(equal_edges_per_batch == 1)
             else:
-                add_edge_out = torch.all(equal_edges_per_batch_index == 0)
+                add_edge_out = torch.all(equal_edges_per_batch == 0)

         return bool(add_node_out) and bool(add_edge_out)

diff --git a/src/gfn/states.py b/src/gfn/states.py
index 9be90bc4..471145bf 100644
--- a/src/gfn/states.py
+++ b/src/gfn/states.py
@@ -704,13 +704,15 @@ def __setitem__(self, index: int | Sequence[int], graph: GraphStates):
             edge_mask_1 = self.tensor['edge_index'][:, 1] >= end_ptr
             self.tensor['edge_index'][edge_mask_0, 0] += shift
             self.tensor['edge_index'][edge_mask_1, 1] += shift
+            edge_to_add_mask = torch.all(source_tensor_dict['edge_index'] >= source_start_ptr, dim=-1)
+            edge_to_add_mask &= torch.all(source_tensor_dict['edge_index'] < source_end_ptr, dim=-1)
             self.tensor['edge_index'] = torch.cat([
                 self.tensor['edge_index'],
-                source_tensor_dict['edge_index'] - source_start_ptr + start_ptr,
+                source_tensor_dict['edge_index'][edge_to_add_mask] - source_start_ptr + start_ptr,
             ], dim=0)
             self.tensor['edge_feature'] = torch.cat([
                 self.tensor['edge_feature'],
-                source_tensor_dict['edge_feature'],
+                source_tensor_dict['edge_feature'][edge_to_add_mask],
             ], dim=0)

             # Update batch pointers
@@ -735,6 +737,7 @@ def extend(self, other: GraphStates):
         """Concatenates to another GraphStates object along the batch dimension"""
         self.tensor["node_feature"] = torch.cat([self.tensor["node_feature"], other.tensor["node_feature"]], dim=0)
         self.tensor["edge_feature"] = torch.cat([self.tensor["edge_feature"], other.tensor["edge_feature"]], dim=0)
+        # TODO: fix indices
         self.tensor["edge_index"] = torch.cat([self.tensor["edge_index"], other.tensor["edge_index"] + self.tensor["batch_ptr"][-1]], dim=0)
         self.tensor["batch_ptr"] = torch.cat([self.tensor["batch_ptr"], other.tensor["batch_ptr"][1:] + self.tensor["batch_ptr"][-1]], dim=0)
         assert torch.all(self.tensor["batch_shape"][1:] == other.tensor["batch_shape"][1:])

From 2b3208fd738aabaa873191635ddaa0628e96012e Mon Sep 17 00:00:00 2001
From: Omar Younis
Date: Fri, 20 Dec 2024 16:12:42 +0100
Subject: [PATCH 26/27] fix edge_index after get

---
 src/gfn/gym/graph_building.py | 3 +--
 src/gfn/states.py             | 4 ++--
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py
index cec3c6c6..48ea17e3 100644
--- a/src/gfn/gym/graph_building.py
+++ b/src/gfn/gym/graph_building.py
@@ -137,8 +137,7 @@ def is_action_valid(
             global_edge_index = add_edge_actions.edge_index + add_edge_states["batch_ptr"][:-1][:, None]
             equal_edges_per_batch = torch.all(
                 add_edge_states["edge_index"] == global_edge_index[:, None], dim=-1
-            )
-            equal_edges_per_batch = equal_edges_per_batch.sum(dim=-1)
+            ).sum(dim=-1)

             if backward:
                 add_edge_out = torch.all(equal_edges_per_batch == 1)

diff --git a/src/gfn/states.py b/src/gfn/states.py
index 471145bf..2b419696 100644
--- a/src/gfn/states.py
+++ b/src/gfn/states.py
@@ -635,8 +635,8 @@ def __getitem__(

             # Adjust edge indices to be local to this graph
             graph_edge_index = self.tensor['edge_index'][edge_mask]
-            graph_edge_index[:, 0] -= (batch_ptr[-1] - start)
-            graph_edge_index[:, 1] -= (batch_ptr[-1] - start)
+            graph_edge_index[:, 0] -= (start - batch_ptr[-1])
+            graph_edge_index[:, 1] -= (start - batch_ptr[-1])
             edge_indices.append(graph_edge_index)

             batch_ptr.append(batch_ptr[-1] + len(graph_nodes))

From b84246f2aa497dd668552145ff0fd1641b2cd5ac Mon Sep 17 00:00:00 2001
From: Omar Younis
Date: Sun, 22 Dec 2024 23:44:05 +0100
Subject: [PATCH 27/27] push updated code

---
 src/gfn/env.py                        |  2 +-
 src/gfn/gym/graph_building.py         |  6 +-
 src/gfn/states.py                     | 20 ++++--
 tutorials/examples/test_graph_ring.py | 98 ++++++++++++++++++---------
 4 files changed, 83 insertions(+), 43 deletions(-)

diff --git a/src/gfn/env.py b/src/gfn/env.py
index 86c592b5..38884c97 100644
--- a/src/gfn/env.py
+++ b/src/gfn/env.py
@@ -300,7 +300,7 @@ def _backward_step(
         # Calculate the backward step, and update only the states which are not Done.
         new_not_done_states_tensor = self.backward_step(valid_states, valid_actions)
-        new_states.tensor[valid_states_idx] = new_not_done_states_tensor
+        new_states[valid_states_idx] = self.States(new_not_done_states_tensor)

         if isinstance(new_states, DiscreteStates):
             self.update_masks(new_states)

diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py
index 48ea17e3..9a15d468 100644
--- a/src/gfn/gym/graph_building.py
+++ b/src/gfn/gym/graph_building.py
@@ -71,7 +71,7 @@ def step(self, states: GraphStates, actions: GraphActions) -> TensorDict:

         return state_tensor

-    def backward_step(self, states: GraphStates, actions: GraphActions) -> torch.Tensor:
+    def backward_step(self, states: GraphStates, actions: GraphActions) -> TensorDict:
         """Backward step function for the GraphBuilding environment.

         Args:
@@ -83,6 +83,8 @@ def backward_step(self, states: GraphStates, actions: GraphActions) -> torch.Ten
         if not self.is_action_valid(states, actions, backward=True):
             raise NonValidActionsError("Invalid action.")
         state_tensor = deepcopy(states.tensor)
+        if len(actions) == 0:
+            return state_tensor

         action_type = actions.action_type[0]
         assert torch.all(actions.action_type == action_type)
@@ -126,14 +128,12 @@ def is_action_valid(
         else:
             add_edge_states = states[add_edge_mask].tensor
             add_edge_actions = actions[add_edge_mask]
-
             if torch.any(add_edge_actions.edge_index[:, 0] == add_edge_actions.edge_index[:, 1]):
                 return False
             if add_edge_states["node_feature"].shape[0] == 0:
                 return False
             if torch.any(add_edge_actions.edge_index > add_edge_states["node_feature"].shape[0]):
                 return False
-
             global_edge_index = add_edge_actions.edge_index + add_edge_states["batch_ptr"][:-1][:, None]
             equal_edges_per_batch = torch.all(
                 add_edge_states["edge_index"] == global_edge_index[:, None], dim=-1
             ).sum(dim=-1)

diff --git a/src/gfn/states.py b/src/gfn/states.py
index 2b419696..7875038c 100644
--- a/src/gfn/states.py
+++ b/src/gfn/states.py
@@ -523,7 +523,7 @@ def __init__(self, tensor: TensorDict):
         self.node_features_dim = tensor["node_feature"].shape[-1]
         self.edge_features_dim = tensor["edge_feature"].shape[-1]
-        self._log_rewards: float = None
+        self._log_rewards: Optional[float] = None

         # TODO logic repeated from env.is_valid_action
         not_empty = self.tensor["batch_ptr"][:-1] + 1 < self.tensor["batch_ptr"][1:]
         self.forward_masks = torch.ones((np.prod(self.batch_shape), 3), dtype=torch.bool)
@@ -700,18 +700,19 @@ def __setitem__(self, index: int | Sequence[int], graph: GraphStates):

             # Update edge indices for subsequent graphs
-            edge_mask_0 = self.tensor['edge_index'][:, 0] >= end_ptr
-            edge_mask_1 = self.tensor['edge_index'][:, 1] >= end_ptr
-            self.tensor['edge_index'][edge_mask_0, 0] += shift
-            self.tensor['edge_index'][edge_mask_1, 1] += shift
+            edge_mask = self.tensor['edge_index'] >= end_ptr
+            assert torch.all(edge_mask[..., 0] == edge_mask[..., 1])
+            edge_mask = torch.all(edge_mask, dim=-1)
+            self.tensor['edge_index'][edge_mask] += shift
+            edge_mask |= torch.all(self.tensor['edge_index'] < start_ptr, dim=-1)
             edge_to_add_mask = torch.all(source_tensor_dict['edge_index'] >= source_start_ptr, dim=-1)
             edge_to_add_mask &= torch.all(source_tensor_dict['edge_index'] < source_end_ptr, dim=-1)
             self.tensor['edge_index'] = torch.cat([
-                self.tensor['edge_index'],
+                self.tensor['edge_index'][edge_mask],
                 source_tensor_dict['edge_index'][edge_to_add_mask] - source_start_ptr + start_ptr,
             ], dim=0)
             self.tensor['edge_feature'] = torch.cat([
-                self.tensor['edge_feature'],
+                self.tensor['edge_feature'][edge_mask],
                 source_tensor_dict['edge_feature'][edge_to_add_mask],
             ], dim=0)
@@ -759,6 +760,11 @@ def _compare(self, other: TensorDict) -> torch.Tensor:
                 out[i] = False
             else:
                 out[i] = torch.all(self.tensor["node_feature"][start:end] == other["node_feature"])
+                edge_mask = torch.all((self.tensor["edge_index"] >= start) & (self.tensor["edge_index"] < end), dim=-1)
+                edge_index = self.tensor["edge_index"][edge_mask] - start
+                out[i] &= len(edge_index) == len(other["edge_index"]) and torch.all(edge_index == other["edge_index"])
+                edge_feature = self.tensor["edge_feature"][edge_mask]
+                out[i] &= len(edge_feature) == len(other["edge_feature"]) and torch.all(edge_feature == other["edge_feature"])
         return out.view(self.batch_shape)

     @property

diff --git a/tutorials/examples/test_graph_ring.py b/tutorials/examples/test_graph_ring.py
index 8546ec41..3b2679b8 100644
--- a/tutorials/examples/test_graph_ring.py
+++ b/tutorials/examples/test_graph_ring.py
@@ -1,5 +1,6 @@
 """Write an example where we want to create graphs that are rings."""

+from typing import Optional
 import torch
 from torch import nn
 from gfn.actions import Actions, GraphActionType, GraphActions
@@ -13,17 +14,24 @@

 def state_evaluator(states: GraphStates) -> torch.Tensor:
+    eps = 1e-6
     if states.tensor["edge_index"].shape[0] == 0:
-        return torch.zeros(states.batch_shape)
+        return torch.full(states.batch_shape, eps)
     if states.tensor["edge_index"].shape[0] != states.tensor["node_feature"].shape[0]:
-        return torch.zeros(states.batch_shape)
-
-    i0 = torch.unique(states.tensor["edge_index"][0], sorted=False)
-    i1 = torch.unique(states.tensor["edge_index"][1], sorted=False)
-
-    if len(i0) == len(i1) == states.tensor["node_feature"].shape[0]:
-        return torch.ones(states.batch_shape)
-    return torch.zeros(states.batch_shape)
+        return torch.full(states.batch_shape, eps)
+
+    out = torch.zeros(len(states))
+    for i in range(len(states)):
+        start, end = states.tensor["batch_ptr"][i], states.tensor["batch_ptr"][i + 1]
+        edge_index_mask = torch.all(states.tensor["edge_index"] >= start, dim=-1) & torch.all(states.tensor["edge_index"] < end, dim=-1)
+        edge_index = states.tensor["edge_index"][edge_index_mask]
+        arange = torch.arange(start, end)
+        # TODO: not correct, accepts multiple rings
+        if torch.all(torch.sort(edge_index[:, 0])[0] == arange) and torch.all(torch.sort(edge_index[:, 1])[0] == arange):
+            out[i] = 1
+        else:
+            out[i] = eps
+    return out.view(*states.batch_shape)


 class RingPolicyEstimator(nn.Module):
@@ -47,18 +55,16 @@ def forward(self, states_tensor: TensorDict) -> torch.Tensor:
         action_type = self._group_sum(action_type, batch_ptr)

         edge_index = self.edge_index_conv(node_feature, edge_index)
-        #edge_index = self._group_sum(edge_index, batch_ptr)
-        edge_index = edge_index.reshape(*states_tensor["batch_shape"], -1, 8)
+        edge_index = edge_index.reshape(*states_tensor["batch_shape"], self.n_nodes, 8)
         edge_index = torch.einsum("bnf,bmf->bnm", edge_index, edge_index)
-        torch.diagonal(edge_index, dim1=-2, dim2=-1).fill_(float("-inf"))
-        edge_actions = edge_index.reshape(*states_tensor["batch_shape"], -1)
+        edge_actions = edge_index.reshape(*states_tensor["batch_shape"], self.n_nodes * self.n_nodes)

         return torch.cat([action_type, edge_actions], dim=-1)

 class RingGraphBuilding(GraphBuilding):
-    def __init__(self, nodes: int = 10):
-        self.nodes = nodes
-        self.n_actions = 1 + nodes * nodes
+    def __init__(self, n_nodes: int = 10):
+        self.n_nodes = n_nodes
+        self.n_actions = 1 + n_nodes * n_nodes
         super().__init__(feature_dim=1, state_evaluator=state_evaluator)
@@ -77,22 +83,52 @@ def make_states_class(self) -> type[GraphStates]:

         class RingStates(GraphStates):
             s0 = TensorDict({
-                "node_feature": torch.zeros((env.nodes, 1)),
+                "node_feature": torch.zeros((env.n_nodes, 1)),
                 "edge_feature": torch.zeros((0, 1)),
                 "edge_index": torch.zeros((0, 2), dtype=torch.long),
             }, batch_size=())
             sf = TensorDict({
-                "node_feature": torch.ones((env.nodes, 1)),
+                "node_feature": torch.ones((env.n_nodes, 1)),
                 "edge_feature": torch.zeros((0, 1)),
                 "edge_index": torch.zeros((0, 2), dtype=torch.long),
             }, batch_size=())
-            n_actions = env.n_actions
-
-            def __init__(self, *args, **kwargs):
-                super().__init__(*args, **kwargs)
-
-                self.forward_masks = torch.ones(self.batch_shape + (self.n_actions,), dtype=torch.bool)
-                self.backward_masks = torch.ones(self.batch_shape + (self.n_actions,), dtype=torch.bool)
+            def __init__(self, tensor: TensorDict):
+                self.tensor = tensor
+                self.node_features_dim = tensor["node_feature"].shape[-1]
+                self.edge_features_dim = tensor["edge_feature"].shape[-1]
+                self._log_rewards: Optional[float] = None
+
+                self.n_nodes = env.n_nodes
+                self.n_actions = env.n_actions
+
+            @property
+            def forward_masks(self):
+                forward_masks = torch.ones(len(self), self.n_actions, dtype=torch.bool)
+                forward_masks[:, 1::self.n_nodes + 1] = False
+                for i in range(len(self)):
+                    existing_edges = self[i].tensor["edge_index"]
+                    forward_masks[i, 1 + existing_edges[:, 0] * self.n_nodes + existing_edges[:, 1]] = False
+
+                return forward_masks.view(*self.batch_shape, self.n_actions)
+
+            @forward_masks.setter
+            def forward_masks(self, value: torch.Tensor):
+                pass  # forward masks are computed on the fly
+
+            @property
+            def backward_masks(self):
+                backward_masks = torch.zeros(len(self), self.n_actions, dtype=torch.bool)
+                for i in range(len(self)):
+                    existing_edges = self[i].tensor["edge_index"]
+                    backward_masks[i, 1 + existing_edges[:, 0] * self.n_nodes + existing_edges[:, 1]] = True
+
+                return backward_masks.view(*self.batch_shape, self.n_actions)
+
+            @backward_masks.setter
+            def backward_masks(self, value: torch.Tensor):
+                pass  # backward masks are computed on the fly
+
         return RingStates

     def _step(self, states: GraphStates, actions: Actions) -> GraphStates:
@@ -106,9 +142,8 @@ def _backward_step(self, states: GraphStates, actions: Actions) -> GraphStates:
     def convert_actions(self, actions: Actions) -> GraphActions:
         action_tensor = actions.tensor.squeeze(-1)
         action_type = torch.where(action_tensor == 0, GraphActionType.EXIT, GraphActionType.ADD_EDGE)
-        edge_index_i0 = (action_tensor - 1) // (self.nodes)
-        edge_index_i1 = (action_tensor - 1) % (self.nodes)
-        # edge_index_i1 = edge_index_i1 + (edge_index_i1 >= edge_index_i0)
+        edge_index_i0 = (action_tensor - 1) // (self.n_nodes)
+        edge_index_i1 = (action_tensor - 1) % (self.n_nodes)

         edge_index = torch.stack([edge_index_i0, edge_index_i1], dim=-1)
         return GraphActions(TensorDict({
@@ -132,8 +167,8 @@ def __call__(self, states: GraphStates) -> torch.Tensor:

 if __name__ == "__main__":
     torch.random.manual_seed(42)
-    env = RingGraphBuilding(nodes=10)
-    module = RingPolicyEstimator(env.nodes)
+    env = RingGraphBuilding(n_nodes=3)
+    module = RingPolicyEstimator(env.n_nodes)
@@ -143,9 +178,8 @@ def __call__(self, states: GraphStates) -> torch.Tensor:

     visited_terminating_states = env.States.from_batch_shape((0,))
     losses = []

-    for iteration in range(100):
-        print(f"Iteration {iteration}")
-        trajectories = gflownet.sample_trajectories(env, n=128)
+    for iteration in range(128):
+        trajectories = gflownet.sample_trajectories(env, n=32)
         samples = gflownet.to_training_samples(trajectories)
         optimizer.zero_grad()
         loss = gflownet.loss(env, samples)
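A possible follow-up to the training loop above (a sketch, not part of the
patch series): once training finishes, sample a fresh batch and reuse the
`state_evaluator` from tutorials/examples/test_graph_ring.py to measure how
often the learned policy terminates on a ring.

    trajectories = gflownet.sample_trajectories(env, n=128)
    rewards = state_evaluator(trajectories.last_states)  # ~1 for rings, eps otherwise
    ring_fraction = (rewards > 0.5).float().mean().item()
    print(f"ring fraction: {ring_fraction:.1%}; final loss: {losses[-1]:.4f}")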