-
Notifications
You must be signed in to change notification settings - Fork 33
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft: Add Graphs as States for torchgfn #210
Draft
alip67
wants to merge
31
commits into
GFNOrg:master
Choose a base branch
from
alip67:graph-states
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from 3 commits
Commits
Show all changes
31 commits
Select commit
Hold shift + click to select a range
9ae28b2
including Graphs as States for torchgfn
alip67 de6ab1c
add GraphEnv
younik 24e23e8
add deps and reformat
younik 1f7b220
add test, fix errors, add valid action check
younik 63e4f1c
fix formatting
younik 8034fb2
add GraphAction
younik d179671
fix batching mechanism
younik e018f4e
Merge branch 'GFNOrg:master' into graph-states
alip67 7ff96d5
add support for EXIT action
younik cf482da
Merge branch 'graph-states' of https://github.com/alip67/torchgfn int…
younik dacbbf7
add GraphActionPolicyEstimator
younik 98ea448
Merge branch 'GFNOrg:master' into graph-states
alip67 e74e500
Sampler integration work
younik a862bb4
Merge branch 'graph-states' of https://github.com/alip67/torchgfn int…
younik 5e64c84
use TensorDict
younik 81f8b71
solve some errors
younik 34781ef
use tensordict in actions
younik 3e584f2
handle sf
younik d5e438f
remove Data
younik fba5d50
categorical action type
younik 478bd14
change batching
younik dd80f28
fix stacking
younik 616551c
fix graph stacking
younik 77611d4
fix test graph env
younik 5874ff6
add ring example
younik 9d42332
remove check edge_features
younik 2d44242
fix GraphStates set
younik 173d4fb
remove debug
younik 7265857
fix add_edge action
younik 2b3208f
fix edge_index after get
younik b84246f
push updated code
younik File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
from copy import copy | ||
from typing import Callable, Literal | ||
|
||
import torch | ||
from torch_geometric.data import Batch, Data | ||
from torch_geometric.nn import GCNConv | ||
|
||
from gfn.actions import Actions | ||
from gfn.env import GraphEnv | ||
from gfn.states import GraphStates | ||
|
||
|
||
class GraphBuilding(GraphEnv): | ||
def __init__( | ||
self, | ||
num_nodes: int, | ||
node_feature_dim: int, | ||
edge_feature_dim: int, | ||
state_evaluator: Callable[[Batch], torch.Tensor] | None = None, | ||
device_str: Literal["cpu", "cuda"] = "cpu", | ||
): | ||
s0 = Data(x=torch.zeros((num_nodes, node_feature_dim)).to(device_str)) | ||
exit_action = torch.tensor( | ||
[-float("inf"), -float("inf")], device=torch.device(device_str) | ||
) | ||
dummy_action = torch.tensor( | ||
[float("inf"), float("inf")], device=torch.device(device_str) | ||
) | ||
if state_evaluator is None: | ||
state_evaluator = GCNConvEvaluator(node_feature_dim) | ||
self.state_evaluator = state_evaluator | ||
|
||
super().__init__( | ||
s0=s0, | ||
node_feature_dim=node_feature_dim, | ||
edge_feature_dim=edge_feature_dim, | ||
action_shape=(2,), | ||
dummy_action=dummy_action, | ||
exit_action=exit_action, | ||
device_str=device_str, | ||
) | ||
|
||
def step(self, states: GraphStates, actions: Actions) -> GraphStates: | ||
"""Step function for the GraphBuilding environment. | ||
|
||
Args: | ||
states: GraphStates object representing the current graph. | ||
actions: Actions indicating which edge to add. | ||
|
||
Returns the next graph the new GraphStates. | ||
""" | ||
graphs: Batch = copy.deepcopy(states.data) | ||
assert len(graphs) == len(actions) | ||
|
||
for i, act in enumerate(actions.tensor): | ||
edge_index = torch.cat([graphs[i].edge_index, act.unsqueeze(1)], dim=1) | ||
graphs[i].edge_index = edge_index | ||
|
||
return GraphStates(graphs) | ||
|
||
def backward_step(self, states: GraphStates, actions: Actions) -> GraphStates: | ||
"""Backward step function for the GraphBuilding environment. | ||
|
||
Args: | ||
states: GraphStates object representing the current graph. | ||
actions: Actions indicating which edge to remove. | ||
|
||
Returns the previous graph as a new GraphStates. | ||
""" | ||
graphs: Batch = copy.deepcopy(states.data) | ||
assert len(graphs) == len(actions) | ||
|
||
for i, act in enumerate(actions.tensor): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. note to self - we should evaluate if this can be vectorized. |
||
edge_index = graphs[i].edge_index | ||
edge_index = edge_index[:, edge_index[1] != act] | ||
graphs[i].edge_index = edge_index | ||
|
||
return GraphStates(graphs) | ||
|
||
def is_action_valid( | ||
self, states: GraphStates, actions: Actions, backward: bool = False | ||
) -> bool: | ||
for i, act in enumerate(actions.tensor): | ||
if backward and len(states.data[i].edge_index[1]) == 0: | ||
return False | ||
if not backward and torch.any(states.data[i].edge_index[1] == act): | ||
return False | ||
return True | ||
|
||
def reward(self, final_states: GraphStates) -> torch.Tensor: | ||
"""The environment's reward given a state. | ||
This or log_reward must be implemented. | ||
|
||
Args: | ||
final_states: A batch of final states. | ||
|
||
Returns: | ||
torch.Tensor: Tensor of shape "batch_shape" containing the rewards. | ||
""" | ||
return self.state_evaluator(final_states.data).sum(dim=1) | ||
|
||
@property | ||
def log_partition(self) -> float: | ||
"Returns the logarithm of the partition function." | ||
raise NotImplementedError | ||
|
||
@property | ||
def true_dist_pmf(self) -> torch.Tensor: | ||
"Returns a one-dimensional tensor representing the true distribution." | ||
raise NotImplementedError | ||
|
||
|
||
class GCNConvEvaluator: | ||
def __init__(self, num_features): | ||
self.net = GCNConv(num_features, 1) | ||
|
||
def __call__(self, batch: Batch) -> torch.Tensor: | ||
return self.net(batch.x, batch.edge_index) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
note to self - we should evaluate if this can be vectorized.