Skip to content

Commit

Permalink
add Metropolis-Hastings acceptance rule
Browse files Browse the repository at this point in the history
  • Loading branch information
hyeok9855 committed Nov 29, 2024
1 parent 4e11c27 commit e7fa8b6
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 190 deletions.
4 changes: 2 additions & 2 deletions src/gfn/containers/trajectories.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,8 +472,8 @@ def reverse_backward_trajectories(trajectories: Trajectories) -> Trajectories:
0
)

trajectories_states = trajectories.env.States(new_states)
trajectories_actions = trajectories.env.Actions(new_actions)
trajectories_states = trajectories.env.states_from_tensor(new_states)
trajectories_actions = trajectories.env.actions_from_tensor(new_actions)

return Trajectories(
env=trajectories.env,
Expand Down
14 changes: 9 additions & 5 deletions src/gfn/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def __init__(
self.dummy_action = dummy_action
self.exit_action = exit_action

# Warning: don't use self.States or self.Actions to initialize an instance of the class.
# Use self.states_from_tensor or self.actions_from_tensor instead.
self.States = self.make_states_class()
self.Actions = self.make_actions_class()

Expand All @@ -85,7 +87,9 @@ def states_from_tensor(self, tensor: torch.Tensor):
"""
return self.States(tensor)

def states_from_batch_shape(self, batch_shape: Tuple):
def states_from_batch_shape(
self, batch_shape: Tuple, random: bool = False, sink: bool = False
):
"""Returns a batch of s0 states with a given batch_shape.
Args:
Expand All @@ -94,7 +98,7 @@ def states_from_batch_shape(self, batch_shape: Tuple):
Returns:
States: A batch of initial states.
"""
return self.States.from_batch_shape(batch_shape)
return self.States.from_batch_shape(batch_shape, random=random, sink=sink)

def actions_from_tensor(self, tensor: torch.Tensor):
"""Wraps the supplied Tensor in an Actions instance.
Expand Down Expand Up @@ -218,7 +222,7 @@ def reset(
batch_shape = (1,)
if isinstance(batch_shape, int):
batch_shape = (batch_shape,)
return self.States.from_batch_shape(
return self.states_from_batch_shape(
batch_shape=batch_shape, random=random, sink=sink
)

Expand Down Expand Up @@ -441,7 +445,7 @@ def reset(
batch_shape = (1,)
if isinstance(batch_shape, int):
batch_shape = (batch_shape,)
states = self.States.from_batch_shape(
states = self.states_from_batch_shape(
batch_shape=batch_shape, random=random, sink=sink
)
self.update_masks(states)
Expand All @@ -455,7 +459,7 @@ def update_masks(self, states: States) -> None:
Called automatically after each step for discrete environments.
"""

def make_states_class(self) -> type[States]:
def make_states_class(self) -> type[DiscreteStates]:
env = self

class DiscreteEnvStates(DiscreteStates):
Expand Down
6 changes: 3 additions & 3 deletions src/gfn/gym/discrete_ebm.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(
preprocessor=preprocessor,
)

def update_masks(self, states: States) -> None:
def update_masks(self, states: DiscreteStates) -> None:
states.forward_masks[..., : self.ndim] = states.tensor == -1
states.forward_masks[..., self.ndim : 2 * self.ndim] = states.tensor == -1
states.forward_masks[..., -1] = torch.all(states.tensor != -1, dim=-1)
Expand Down Expand Up @@ -248,13 +248,13 @@ def all_states(self) -> DiscreteStates:
digits = torch.arange(3, device=self.device)
all_states = torch.cartesian_prod(*[digits] * self.ndim)
all_states = all_states - 1
return self.States(all_states)
return self.states_from_tensor(all_states)

@property
def terminating_states(self) -> DiscreteStates:
digits = torch.arange(2, device=self.device)
all_states = torch.cartesian_prod(*[digits] * self.ndim)
return self.States(all_states)
return self.states_from_tensor(all_states)

@property
def true_dist_pmf(self) -> torch.Tensor:
Expand Down
4 changes: 2 additions & 2 deletions src/gfn/gym/hypergrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,13 +223,13 @@ def build_grid(self) -> DiscreteStates:
rearrange_string += " ".join([f"n{i}" for i in range(ndim, 0, -1)])
rearrange_string += " ndim"
grid = rearrange(grid, rearrange_string).long()
return self.States(grid)
return self.states_from_tensor(grid)

@property
def all_states(self) -> DiscreteStates:
grid = self.build_grid()
flat_grid = rearrange(grid.tensor, "... ndim -> (...) ndim")
return self.States(flat_grid)
return self.states_from_tensor(flat_grid)

@property
def terminating_states(self) -> DiscreteStates:
Expand Down
Loading

0 comments on commit e7fa8b6

Please sign in to comment.