Skip to content

Commit

Permalink
fix error in backward_masks
Browse files Browse the repository at this point in the history
  • Loading branch information
hyeok9855 committed Nov 5, 2024
1 parent a3af467 commit d7d95ca
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/gfn/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ def reset(
return states

@abstractmethod
def update_masks(self, states: type[States]) -> None:
def update_masks(self, states: States) -> None:
"""Updates the masks in States.
Called automatically after each step for discrete environments.
Expand Down
2 changes: 1 addition & 1 deletion src/gfn/gym/discrete_ebm.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(
preprocessor=preprocessor,
)

def update_masks(self, states: type[States]) -> None:
def update_masks(self, states: States) -> None:
states.forward_masks[..., : self.ndim] = states.tensor == -1
states.forward_masks[..., self.ndim : 2 * self.ndim] = states.tensor == -1
states.forward_masks[..., -1] = torch.all(states.tensor != -1, dim=-1)
Expand Down
2 changes: 1 addition & 1 deletion src/gfn/gym/hypergrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __init__(
preprocessor=preprocessor,
)

def update_masks(self, states: type[DiscreteStates]) -> None:
def update_masks(self, states: DiscreteStates) -> None:
"""Update the masks based on the current states."""
# Not allowed to take any action beyond the environment height, but
# allow early termination.
Expand Down
6 changes: 3 additions & 3 deletions src/gfn/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,6 @@ def local_search(
) # Episodic reward is assumed.
trajectories_logprobs = None # TODO

len_recon = recon_trajectories.states.tensor.shape[0]
for i in range(bs):
# Backward part
back_source_idx = backward_trajectories.when_is_done[i]
Expand All @@ -373,10 +372,11 @@ def local_search(
back_source_idx - n_back : back_source_idx, i
].flip(0)

len_recon = recon_trajectories.when_is_done[i] + 1
# Forward part
trajectories_states_tsr[
n_back : n_back + len_recon, i
] = recon_trajectories.states.tensor[:, i]
] = recon_trajectories.states.tensor[:len_recon, i]
trajectories_actions_tsr[
n_back : n_back + len_recon - 1, i
] = recon_trajectories.actions.tensor[: len_recon - 1, i]
Expand All @@ -385,7 +385,7 @@ def local_search(

new_trajectories = Trajectories(
env=env,
states=env.States(trajectories_states_tsr),
states=env.states_from_tensor(trajectories_states_tsr),

This comment has been minimized.

Copy link
@josephdviviano

josephdviviano Nov 5, 2024

Collaborator

Nice -- was this the cause of the backwards mask error?

conditioning=conditioning,
actions=env.Actions(trajectories_actions_tsr),
when_is_done=trajectories_dones,
Expand Down

0 comments on commit d7d95ca

Please sign in to comment.