We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent a81a16b commit e5b928dCopy full SHA for e5b928d
i6_models/parts/fsa.py
@@ -3,7 +3,7 @@
3
__all__ = ["TorchFsaBuilder", "WeightedFsa"]
4
5
from functools import reduce
6
-from typing import Iterable, NamedTuple, Tuple, TypeVar
+from typing import Iterable, NamedTuple, Tuple, Union
7
8
import numpy as np
9
import torch
@@ -37,7 +37,7 @@ def __mul__(self, scale: float) -> WeightedFsa:
37
self.start_end_states,
38
)
39
40
- def to(self, device: torch.device) -> WeightedFsa:
+ def to(self, device: Union[str, torch.device]) -> WeightedFsa:
41
"""Move the tensors to a given device. This wraps around the
42
PyTorch `Tensor.to(device)` method."""
43
return WeightedFsa._make(tensor.to(device) for tensor in self)
0 commit comments