Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

No more class factories #149

Merged
merged 30 commits into from
Feb 16, 2024
Merged
Changes from 1 commit
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
0535e9f
removed one order of magnitude precision required
josephdviviano Nov 27, 2023
7753041
merge conflicts
josephdviviano Nov 27, 2023
f7a562e
replaced State method call with Env method call
josephdviviano Nov 29, 2023
742bfb2
replaced State method call with Env method call, and removed is_tenso…
josephdviviano Nov 29, 2023
f545167
replaced State method call with Env method call
josephdviviano Nov 29, 2023
f945b39
switch name of backward/forward step, and replaced State method call …
josephdviviano Nov 29, 2023
2a51704
removed States/Actions class definition, added the appropriate args t…
josephdviviano Nov 29, 2023
cdd425c
moved environment to Gym
josephdviviano Nov 29, 2023
6ae846b
renamed maskless_?_step functions, and made the generic step/backward…
josephdviviano Nov 29, 2023
a09c9a5
removed comment
josephdviviano Nov 29, 2023
93f6a65
States methods moved to Env methods, also, name change for step
josephdviviano Nov 29, 2023
2ab5885
changes to the handling of forward / backward masks. in addition, mak…
josephdviviano Nov 29, 2023
bfd6bbf
method renaming
josephdviviano Nov 29, 2023
d6d30fe
docs update (TODO: this might need a full rework)
josephdviviano Nov 29, 2023
25b7527
changes to support new API
josephdviviano Nov 29, 2023
f12cbec
tweaks (TODO: fix in follow up PR)
josephdviviano Nov 29, 2023
cdffab1
black / isort
josephdviviano Nov 29, 2023
3b756a5
cleanup
josephdviviano Nov 29, 2023
c98f423
gradient clipping added back in
Nov 30, 2023
6af395f
renaming make_States_class to follow pep
josephdviviano Dec 8, 2023
6c0d8aa
updated documentation
josephdviviano Dec 8, 2023
364e52d
rename methods
josephdviviano Dec 8, 2023
8d8a4c1
rename method
josephdviviano Dec 8, 2023
ebf0db2
deps
josephdviviano Feb 13, 2024
71e6603
requirements
josephdviviano Feb 13, 2024
c393014
deps
josephdviviano Feb 13, 2024
3cb9914
merge
josephdviviano Feb 16, 2024
b85f1eb
merged
josephdviviano Feb 16, 2024
ad80e7e
update
josephdviviano Feb 16, 2024
ae3fa2e
update
josephdviviano Feb 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
changes to the handling of forward / backward masks. in addition, make_random_state_tensor is now a function passed to the States class as inheritance can no longer be relied on to overwrite the default method.
  • Loading branch information
josephdviviano committed Nov 29, 2023
commit 2ab5885dbe8104d542f06e229c35212a00f267b5
44 changes: 17 additions & 27 deletions src/gfn/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from abc import ABC, abstractmethod
from math import prod
from typing import ClassVar, Optional, Sequence, cast
from typing import ClassVar, Optional, Sequence, cast, Callable

import torch
from torchtyping import TensorType as TT
Expand Down Expand Up @@ -49,6 +49,7 @@ class States(ABC):
sf: ClassVar[
TT["state_shape", torch.float]
] # Dummy state, used to pad a batch of states
make_random_states_tensor: Callable = lambda x: (_ for _ in ()).throw(NotImplementedError("The environment does not support initialization of random states."))

def __init__(self, tensor: TT["batch_shape", "state_shape"]):
"""Initalize the State container with a batch of states.
Expand Down Expand Up @@ -101,15 +102,6 @@ def make_initial_states_tensor(
assert cls.s0 is not None and state_ndim is not None
return cls.s0.repeat(*batch_shape, *((1,) * state_ndim))

@classmethod
def make_random_states_tensor(
cls, batch_shape: tuple[int]
) -> TT["batch_shape", "state_shape", torch.float]:
"""Makes a tensor with a `batch_shape` of random states, placeholder."""
raise NotImplementedError(
"The environment does not support initialization of random states."
)

@classmethod
def make_sink_states_tensor(
cls, batch_shape: tuple[int]
Expand All @@ -133,7 +125,7 @@ def __getitem__(self, index: int | Sequence[int] | Sequence[bool]) -> States:
"""Access particular states of the batch."""
return self.__class__(
self.tensor[index]
) # TODO: Inefficient - this make a copy of the tensor!
) # TODO: Inefficient - this makes a copy of the tensor!

def __setitem__(
self, index: int | Sequence[int] | Sequence[bool], states: States
Expand Down Expand Up @@ -275,7 +267,6 @@ class DiscreteStates(States, ABC):
forward_masks: A boolean tensor of allowable forward policy actions.
backward_masks: A boolean tensor of allowable backward policy actions.
"""

n_actions: ClassVar[int]
device: ClassVar[torch.device]

Expand All @@ -285,32 +276,35 @@ def __init__(
forward_masks: Optional[TT["batch_shape", "n_actions", torch.bool]] = None,
backward_masks: Optional[TT["batch_shape", "n_actions - 1", torch.bool]] = None,
) -> None:

"""Initalize a DiscreteStates container with a batch of states and masks.
Args:
tensor: A batch of states.
forward_masks (optional): Initializes a boolean tensor of allowable forward
policy actions.
backward_masks (optional): Initializes a boolean tensor of allowable backward
policy actions.
forward_masks: Initializes a boolean tensor of allowable forward policy
actions.
backward_masks: Initializes a boolean tensor of allowable backward policy
actions.
"""
super().__init__(tensor)

if forward_masks is None and backward_masks is None:
self.forward_masks = torch.ones(
# In the usual case, no masks are provided and we produce these defaults.
# Note: this **must** be updated externally by the env.
if isinstance(forward_masks, type(None)):
forward_masks = torch.ones(
(*self.batch_shape, self.__class__.n_actions),
dtype=torch.bool,
device=self.__class__.device,
)
self.backward_masks = torch.ones(
if isinstance(backward_masks, type(None)):
backward_masks = torch.ones(
(*self.batch_shape, self.__class__.n_actions - 1),
dtype=torch.bool,
device=self.__class__.device,
)
self.update_masks()
else:
self.forward_masks = cast(torch.Tensor, forward_masks)
self.backward_masks = cast(torch.Tensor, backward_masks)

# Ensures typecasting is consistent no matter what is submitted to init.
self.forward_masks = cast(torch.Tensor, forward_masks) # TODO: Required?
self.backward_masks = cast(torch.Tensor, backward_masks) # TODO: Required?
self.set_default_typing()

def clone(self) -> States:
Expand All @@ -332,10 +326,6 @@ def set_default_typing(self) -> None:
self.backward_masks,
)

@abstractmethod
def update_masks(self) -> None:
"""Updates the masks, called after each action is taken."""

def _check_both_forward_backward_masks_exist(self):
assert self.forward_masks is not None and self.backward_masks is not None

Expand Down