Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

No more class factories #149

Merged
merged 30 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
0535e9f
removed one order of magnitude precision required
josephdviviano Nov 27, 2023
7753041
merge conflicts
josephdviviano Nov 27, 2023
f7a562e
replaced State method call with Env method call
josephdviviano Nov 29, 2023
742bfb2
replaced State method call with Env method call, and removed is_tenso…
josephdviviano Nov 29, 2023
f545167
replaced State method call with Env method call
josephdviviano Nov 29, 2023
f945b39
switch name of backward/forward step, and replaced State method call …
josephdviviano Nov 29, 2023
2a51704
removed States/Actions class definition, added the appropriate args t…
josephdviviano Nov 29, 2023
cdd425c
moved environment to Gym
josephdviviano Nov 29, 2023
6ae846b
renamed maskless_?_step functions, and made the generic step/backward…
josephdviviano Nov 29, 2023
a09c9a5
removed comment
josephdviviano Nov 29, 2023
93f6a65
States methods moved to Env methods, also, name change for step
josephdviviano Nov 29, 2023
2ab5885
changes to the handling of forward / backward masks. in addition, mak…
josephdviviano Nov 29, 2023
bfd6bbf
method renaming
josephdviviano Nov 29, 2023
d6d30fe
docs update (TODO: this might need a full rework)
josephdviviano Nov 29, 2023
25b7527
changes to support new API
josephdviviano Nov 29, 2023
f12cbec
tweaks (TODO: fix in follow up PR)
josephdviviano Nov 29, 2023
cdffab1
black / isort
josephdviviano Nov 29, 2023
3b756a5
cleanup
josephdviviano Nov 29, 2023
c98f423
gradient clipping added back in
Nov 30, 2023
6af395f
renaming make_States_class to follow pep
josephdviviano Dec 8, 2023
6c0d8aa
updated documentation
josephdviviano Dec 8, 2023
364e52d
rename methods
josephdviviano Dec 8, 2023
8d8a4c1
rename method
josephdviviano Dec 8, 2023
ebf0db2
deps
josephdviviano Feb 13, 2024
71e6603
requirements
josephdviviano Feb 13, 2024
c393014
deps
josephdviviano Feb 13, 2024
3cb9914
merge
josephdviviano Feb 16, 2024
b85f1eb
merged
josephdviviano Feb 16, 2024
ad80e7e
update
josephdviviano Feb 16, 2024
ae3fa2e
update
josephdviviano Feb 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
removed States/Actions class definition, added the appropriate args t…
…o be passed to subclasses.
  • Loading branch information
josephdviviano committed Nov 29, 2023
commit 2a51704e92f7bde08641c394e2b6cdb8997f44f5
49 changes: 16 additions & 33 deletions src/gfn/gym/box.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from math import log
from typing import ClassVar, Literal, Tuple
from typing import Literal, Tuple

import torch
from torchtyping import TensorType as TT
Expand All @@ -25,49 +25,32 @@ def __init__(
self.delta = delta
self.epsilon = epsilon
s0 = torch.tensor([0.0, 0.0], device=torch.device(device_str))
exit_action = torch.tensor([-float("inf"), -float("inf")], device=torch.device(device_str))
dummy_action = torch.tensor([float("inf"), float("inf")], device=torch.device(device_str))

self.R0 = R0
self.R1 = R1
self.R2 = R2

super().__init__(s0=s0)

def make_States_class(self) -> type[States]:
env = self

class BoxStates(States):
state_shape: ClassVar[Tuple[int, ...]] = (2,)
s0 = env.s0
sf = env.sf # should be (-inf, -inf)

@classmethod
def make_random_states_tensor(
cls, batch_shape: Tuple[int, ...]
) -> TT["batch_shape", 2, torch.float]:
return torch.rand(batch_shape + (2,), device=env.device)

return BoxStates

def make_Actions_class(self) -> type[Actions]:
env = self

class BoxActions(Actions):
action_shape: ClassVar[Tuple[int, ...]] = (2,)
dummy_action: ClassVar[TT[2]] = torch.tensor(
[float("inf"), float("inf")], device=env.device
)
exit_action: ClassVar[TT[2]] = torch.tensor(
[-float("inf"), -float("inf")], device=env.device
)
super().__init__(
s0=s0,
state_shape=(2,), # ()
action_shape=(2,),
dummy_action=dummy_action,
exit_action=exit_action,
)

return BoxActions
def make_random_states_tensor(
self, batch_shape: Tuple[int, ...]
) -> TT["batch_shape", 2, torch.float]:
return torch.rand(batch_shape + (2,), device=self.device)

def maskless_step(
def step(
self, states: States, actions: Actions
) -> TT["batch_shape", 2, torch.float]:
return states.tensor + actions.tensor

def maskless_backward_step(
def backward_step(
self, states: States, actions: Actions
) -> TT["batch_shape", 2, torch.float]:
return states.tensor - actions.tensor
Expand Down
94 changes: 33 additions & 61 deletions src/gfn/gym/discrete_ebm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from abc import ABC, abstractmethod
from typing import ClassVar, Literal, Tuple
from typing import Literal, Tuple

import torch
from torch import Tensor
import torch.nn as nn
from torchtyping import TensorType as TT

Expand Down Expand Up @@ -87,69 +88,37 @@ def __init__(
raise ValueError(f"Unknown preprocessor {preprocessor_name}")

super().__init__(
n_actions=n_actions,
s0=s0,
state_shape=(self.ndim, ),
# dummy_action=,
# exit_action=,
n_actions=n_actions,
sf=sf,
device_str=device_str,
preprocessor=preprocessor,
)

def make_States_class(self) -> type[DiscreteStates]:
env = self

class DiscreteEBMStates(DiscreteStates):
state_shape: ClassVar[tuple[int, ...]] = (env.ndim,)
s0 = env.s0
sf = env.sf
n_actions = env.n_actions
device = env.device

@classmethod
def make_random_states_tensor(
cls, batch_shape: Tuple[int, ...]
) -> TT["batch_shape", "state_shape", torch.float]:
return torch.randint(
-1,
2,
batch_shape + (env.ndim,),
dtype=torch.long,
device=env.device,
)

# TODO: Look into make masks - I don't think this is being called.
def make_masks(
self,
) -> Tuple[
TT["batch_shape", "n_actions", torch.bool],
TT["batch_shape", "n_actions - 1", torch.bool],
]:
forward_masks = torch.zeros(
self.batch_shape + (env.n_actions,),
device=env.device,
dtype=torch.bool,
)
backward_masks = torch.zeros(
self.batch_shape + (env.n_actions - 1,),
device=env.device,
dtype=torch.bool,
)

return forward_masks, backward_masks

def update_masks(self) -> None:
self.set_default_typing()
self.forward_masks[..., : env.ndim] = self.tensor == -1
self.forward_masks[..., env.ndim : 2 * env.ndim] = self.tensor == -1
self.forward_masks[..., -1] = torch.all(self.tensor != -1, dim=-1)
self.backward_masks[..., : env.ndim] = self.tensor == 0
self.backward_masks[..., env.ndim : 2 * env.ndim] = self.tensor == 1

return DiscreteEBMStates
def update_masks(self, states: type[States]) -> None:
states.set_default_typing()
states.forward_masks[..., : self.ndim] = states.tensor == -1
states.forward_masks[..., self.ndim : 2 * self.ndim] = states.tensor == -1
states.forward_masks[..., -1] = torch.all(states.tensor != -1, dim=-1)
states.backward_masks[..., : self.ndim] = states.tensor == 0
states.backward_masks[..., self.ndim : 2 * self.ndim] = states.tensor == 1

def make_random_states_tensor(self, batch_shape: Tuple) -> Tensor:
return torch.randint(
-1,
2,
batch_shape + (self.ndim,),
dtype=torch.long,
device=self.device,
)

def is_exit_actions(self, actions: TT["batch_shape"]) -> TT["batch_shape"]:
return actions == self.n_actions - 1

def maskless_step(
def step(
self, states: States, actions: Actions
) -> TT["batch_shape", "state_shape", torch.float]:
# First, we select that actions that replace a -1 with a 0.
Expand All @@ -169,15 +138,18 @@ def maskless_step(
)
return states.tensor

def maskless_backward_step(
def backward_step(
self, states: States, actions: Actions
) -> TT["batch_shape", "state_shape", torch.float]:
# In this env, states are n-dim vectors. s0 is empty (represented as -1),
# so s0=[-1, -1, ..., -1], each action is replacing a -1 with either a
# 0 or 1. Action i in [0, ndim-1] is replacing s[i] with 0, whereas
# action i in [ndim, 2*ndim-1] corresponds to replacing s[i - ndim] with 1.
# A backward action asks "what index should be set back to -1", hence the fmod
# to enable wrapping of indices.
"""Performs a backward step.

In this env, states are n-dim vectors. s0 is empty (represented as -1),
so s0=[-1, -1, ..., -1], each action is replacing a -1 with either a
0 or 1. Action i in [0, ndim-1] is replacing s[i] with 0, whereas
action i in [ndim, 2*ndim-1] corresponds to replacing s[i - ndim] with 1.
A backward action asks "what index should be set back to -1", hence the fmod
to enable wrapping of indices.
"""
return states.tensor.scatter(-1, actions.tensor.fmod(self.ndim), -1)

def reward(self, final_states: DiscreteStates) -> TT["batch_shape"]:
Expand Down
63 changes: 25 additions & 38 deletions src/gfn/gym/hypergrid.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
Copied and Adapted from https://github.com/Tikquuss/GflowNets_Tutorial
"""
from typing import ClassVar, Literal, Tuple
from typing import Literal, Tuple

import torch
from einops import rearrange
Expand Down Expand Up @@ -53,7 +53,6 @@ def __init__(
sf = torch.full(
(ndim,), fill_value=-1, dtype=torch.long, device=torch.device(device_str)
)

n_actions = ndim + 1

if preprocessor_name == "Identity":
Expand All @@ -74,55 +73,43 @@ def __init__(
else:
raise ValueError(f"Unknown preprocessor {preprocessor_name}")

state_shape = (self.ndim,)

super().__init__(
n_actions=n_actions,
s0=s0,
state_shape=state_shape,
sf=sf,
device_str=device_str,
preprocessor=preprocessor,
)

def make_States_class(self) -> type[DiscreteStates]:
"Creates a States class for this environment"
env = self

class HyperGridStates(DiscreteStates):
state_shape: ClassVar[tuple[int, ...]] = (env.ndim,)
s0 = env.s0
sf = env.sf
n_actions = env.n_actions
device = env.device

@classmethod
def make_random_states_tensor(
cls, batch_shape: Tuple[int, ...]
) -> TT["batch_shape", "state_shape", torch.float]:
"Creates a batch of random states."
states_tensor = torch.randint(
0, env.height, batch_shape + env.s0.shape, device=env.device
)
return states_tensor

def update_masks(self) -> None:
"Update the masks based on the current states."
self.set_default_typing()
# Not allowed to take any action beyond the environment height, but
# allow early termination.
self.set_nonexit_action_masks(
self.tensor == env.height - 1,
allow_exit=True,
)
self.backward_masks = self.tensor != 0

return HyperGridStates

def maskless_step(
def update_masks(self, states: type[DiscreteStates]) -> None:
"""Update the masks based on the current states."""
states.set_default_typing()
# Not allowed to take any action beyond the environment height, but
# allow early termination.
states.set_nonexit_action_masks(
states.tensor == self.height - 1,
allow_exit=True,
)
states.backward_masks = states.tensor != 0

def make_random_states_tensor(
self, batch_shape: Tuple[int, ...]
) -> TT["batch_shape", "state_shape", torch.float]:
"""Creates a batch of random states."""
return torch.randint(
0, self.height, batch_shape + self.s0.shape, device=self.device
)

def step(
self, states: DiscreteStates, actions: Actions
) -> TT["batch_shape", "state_shape", torch.float]:
new_states_tensor = states.tensor.scatter(-1, actions.tensor, 1, reduce="add")
return new_states_tensor

def maskless_backward_step(
def backward_step(
self, states: DiscreteStates, actions: Actions
) -> TT["batch_shape", "state_shape", torch.float]:
new_states_tensor = states.tensor.scatter(-1, actions.tensor, -1, reduce="add")
Expand Down