Skip to content

Commit

Permalink
typing
Browse files Browse the repository at this point in the history
  • Loading branch information
josephdviviano committed Oct 9, 2024
1 parent e03c03a commit 6b47e06
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/gfn/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Callable, ClassVar, List, Optional, Sequence, cast

import torch
from torch import Tensor
from torchtyping import TensorType as TT


Expand Down Expand Up @@ -126,7 +127,9 @@ def __repr__(self):
def device(self) -> torch.device:
return self.tensor.device

def __getitem__(self, index: int | Sequence[int] | Sequence[bool]) -> States:
def __getitem__(
self, index: int | Sequence[int] | Sequence[bool] | Tensor
) -> States:
"""Access particular states of the batch."""
out = self.__class__(
self.tensor[index]
Expand Down

0 comments on commit 6b47e06

Please sign in to comment.