diff --git a/src/gfn/states.py b/src/gfn/states.py index f4fa1a20..fac0ac09 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -6,6 +6,7 @@ from typing import Callable, ClassVar, List, Optional, Sequence, cast import torch +from torch import Tensor from torchtyping import TensorType as TT @@ -126,7 +127,9 @@ def __repr__(self): def device(self) -> torch.device: return self.tensor.device - def __getitem__(self, index: int | Sequence[int] | Sequence[bool]) -> States: + def __getitem__( + self, index: int | Sequence[int] | Sequence[bool] | Tensor + ) -> States: """Access particular states of the batch.""" out = self.__class__( self.tensor[index]