Skip to content

Commit e5b928d

Browse files
author
Daniel Mann
committed
typing
1 parent a81a16b commit e5b928d

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

i6_models/parts/fsa.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
__all__ = ["TorchFsaBuilder", "WeightedFsa"]
44

55
from functools import reduce
6-
from typing import Iterable, NamedTuple, Tuple, TypeVar
6+
from typing import Iterable, NamedTuple, Tuple, Union
77

88
import numpy as np
99
import torch
@@ -37,7 +37,7 @@ def __mul__(self, scale: float) -> WeightedFsa:
3737
self.start_end_states,
3838
)
3939

40-
def to(self, device: torch.device) -> WeightedFsa:
40+
def to(self, device: Union[str, torch.device]) -> WeightedFsa:
4141
"""Move the tensors to a given device. This wraps around the
4242
PyTorch `Tensor.to(device)` method."""
4343
return WeightedFsa._make(tensor.to(device) for tensor in self)

0 commit comments

Comments
 (0)