diff --git a/src/gfn/env.py b/src/gfn/env.py index 7a60e8ec..68c30d8a 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -449,7 +449,7 @@ def reset( return states @abstractmethod - def update_masks(self, states: type[States]) -> None: + def update_masks(self, states: States) -> None: """Updates the masks in States. Called automatically after each step for discrete environments. diff --git a/src/gfn/gym/discrete_ebm.py b/src/gfn/gym/discrete_ebm.py index 5823736d..3820a735 100644 --- a/src/gfn/gym/discrete_ebm.py +++ b/src/gfn/gym/discrete_ebm.py @@ -113,7 +113,7 @@ def __init__( preprocessor=preprocessor, ) - def update_masks(self, states: type[States]) -> None: + def update_masks(self, states: States) -> None: states.forward_masks[..., : self.ndim] = states.tensor == -1 states.forward_masks[..., self.ndim : 2 * self.ndim] = states.tensor == -1 states.forward_masks[..., -1] = torch.all(states.tensor != -1, dim=-1) diff --git a/src/gfn/gym/hypergrid.py b/src/gfn/gym/hypergrid.py index ac76a8df..847fb32e 100644 --- a/src/gfn/gym/hypergrid.py +++ b/src/gfn/gym/hypergrid.py @@ -82,7 +82,7 @@ def __init__( preprocessor=preprocessor, ) - def update_masks(self, states: type[DiscreteStates]) -> None: + def update_masks(self, states: DiscreteStates) -> None: """Update the masks based on the current states.""" # Not allowed to take any action beyond the environment height, but # allow early termination. diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 19695ac8..5525a2e1 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -358,7 +358,6 @@ def local_search( ) # Episodic reward is assumed. trajectories_logprobs = None # TODO - len_recon = recon_trajectories.states.tensor.shape[0] for i in range(bs): # Backward part back_source_idx = backward_trajectories.when_is_done[i] @@ -373,10 +372,11 @@ def local_search( back_source_idx - n_back : back_source_idx, i ].flip(0) + len_recon = recon_trajectories.when_is_done[i] + 1 # Forward part trajectories_states_tsr[ n_back : n_back + len_recon, i - ] = recon_trajectories.states.tensor[:, i] + ] = recon_trajectories.states.tensor[:len_recon, i] trajectories_actions_tsr[ n_back : n_back + len_recon - 1, i ] = recon_trajectories.actions.tensor[: len_recon - 1, i] @@ -385,7 +385,7 @@ def local_search( new_trajectories = Trajectories( env=env, - states=env.States(trajectories_states_tsr), + states=env.states_from_tensor(trajectories_states_tsr), conditioning=conditioning, actions=env.Actions(trajectories_actions_tsr), when_is_done=trajectories_dones,