From 6b47e06dad648b3e7030866f4cf61de5ab0816d8 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Tue, 8 Oct 2024 23:33:28 -0400 Subject: [PATCH] typing --- src/gfn/states.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/gfn/states.py b/src/gfn/states.py index f4fa1a20..fac0ac09 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -6,6 +6,7 @@ from typing import Callable, ClassVar, List, Optional, Sequence, cast import torch +from torch import Tensor from torchtyping import TensorType as TT @@ -126,7 +127,9 @@ def __repr__(self): def device(self) -> torch.device: return self.tensor.device - def __getitem__(self, index: int | Sequence[int] | Sequence[bool]) -> States: + def __getitem__( + self, index: int | Sequence[int] | Sequence[bool] | Tensor + ) -> States: """Access particular states of the batch.""" out = self.__class__( self.tensor[index]