Skip to content

Commit

Permalink
function to stack a list of states
Browse files Browse the repository at this point in the history
  • Loading branch information
josephdviviano committed Feb 22, 2024
1 parent 617cc22 commit 89027dd
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion src/gfn/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from copy import deepcopy
from math import prod
from typing import Callable, ClassVar, Optional, Sequence, cast
from typing import Callable, ClassVar, Optional, Sequence, List, cast

import torch
from torchtyping import TensorType as TT
Expand Down Expand Up @@ -446,3 +446,20 @@ def init_forward_masks(self, set_ones: bool = True):
self.forward_masks = torch.ones(shape).bool()
else:
self.forward_masks = torch.zeros(shape).bool()


def stack_states(states: "List[States]") -> "States":
    """Given a list of states, stacks them along a new dimension (0).

    All elements of `states` are assumed to share the same batch shape and
    the same `States` subclass; the first element is used as the template.

    Args:
        states: non-empty list of `States` objects to stack.

    Returns:
        A single `States` object whose `tensor` (and `_log_rewards` /
        masks, when present) gains a leading dimension of size
        ``len(states)``; `batch_shape` is extended accordingly.
    """
    state_example = states[0]  # We assume all elems of `states` are the same.

    stacked_states = state_example.from_batch_shape((0, 0))  # Empty.
    stacked_states.tensor = torch.stack([s.tensor for s in states], dim=0)

    # Compare against None rather than relying on truthiness: `bool(tensor)`
    # raises a RuntimeError for multi-element tensors, and an all-zero
    # log-rewards tensor would (wrongly) be skipped.
    if state_example._log_rewards is not None:
        stacked_states._log_rewards = torch.stack(
            [s._log_rewards for s in states], dim=0
        )

    # Only mask-carrying States (e.g. discrete ones) have these attributes;
    # guard so plain States can be stacked too.
    if hasattr(state_example, "forward_masks"):
        stacked_states.forward_masks = torch.stack(
            [s.forward_masks for s in states], dim=0
        )
    if hasattr(state_example, "backward_masks"):
        stacked_states.backward_masks = torch.stack(
            [s.backward_masks for s in states], dim=0
        )

    # Adds the trajectory dimension.
    stacked_states.batch_shape = (
        stacked_states.tensor.shape[0],
    ) + state_example.batch_shape

    return stacked_states

0 comments on commit 89027dd

Please sign in to comment.