diff --git a/rul_datasets/adaption.py b/rul_datasets/adaption.py index 9ec024d..dc6e21b 100644 --- a/rul_datasets/adaption.py +++ b/rul_datasets/adaption.py @@ -2,15 +2,13 @@ import warnings from copy import deepcopy -from typing import List, Optional, Any, Tuple, Callable, Sequence, Union, cast +from typing import List, Optional, Any, Tuple, Callable, Sequence, cast import numpy as np import pytorch_lightning as pl import torch -from torch.utils.data import DataLoader, Dataset -from torch.utils.data.dataset import ConcatDataset, TensorDataset +from torch.utils.data import DataLoader, Dataset, ConcatDataset -from rul_datasets import utils from rul_datasets.core import PairedRulDataset, RulDataModule, RulDataset diff --git a/rul_datasets/core.py b/rul_datasets/core.py index b987edb..3618123 100644 --- a/rul_datasets/core.py +++ b/rul_datasets/core.py @@ -416,7 +416,7 @@ class RulDataset(Dataset): def __init__( self, features: List[np.ndarray], - *targets: Tuple[List[np.ndarray]], + *targets: List[np.ndarray], copy_tensors: bool = False, ) -> None: """ diff --git a/rul_datasets/utils.py b/rul_datasets/utils.py index e2273bf..68469ca 100644 --- a/rul_datasets/utils.py +++ b/rul_datasets/utils.py @@ -169,7 +169,7 @@ def to_tensor( ) -> Tuple[List[torch.Tensor], ...]: dtype = torch.float32 tensor_feats = [feature_to_tensor(f, dtype, copy) for f in features] - convert = torch.tensor if copy else torch.as_tensor + convert: Callable = torch.tensor if copy else torch.as_tensor # type: ignore tensor_targets = [[convert(t, dtype=dtype) for t in target] for target in targets] return tensor_feats, *tensor_targets