Skip to content

Commit

Permalink
Merge pull request #572 from WenjieDu/dev
Browse files Browse the repository at this point in the history
Expose more models for tuning, bump dependency PyGrinder version num, and overwrite torch.autocast
  • Loading branch information
WenjieDu authored Feb 21, 2025
2 parents 872a296 + adbac5b commit a3afa98
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 3 deletions.
22 changes: 22 additions & 0 deletions pypots/cli/tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .utils import load_package_from_path
from ..classification import BRITS as BRITS_classification
from ..classification import GRUD as GRUD_classification
from ..classification import CSAI as CSAI_classification
from ..classification import Raindrop
from ..clustering import CRLI, VaDER
from ..data.saving.h5 import load_dict_from_h5
Expand Down Expand Up @@ -47,6 +48,16 @@
TiDE,
Reformer,
RevIN_SCINet,
FEDformer,
TCN,
ImputeFormer,
TimeMixer,
ModernTCN,
TEFN,
CSAI,
SegRNN,
TRMF,
TimeLLM,
)
from ..optim import Adam
from ..utils.logging import logger
Expand Down Expand Up @@ -89,10 +100,21 @@
"pypots.imputation.TiDE": TiDE,
"pypots.imputation.Transformer": Transformer,
"pypots.imputation.USGAN": USGAN,
"pypots.imputation.FEDformer": FEDformer,
"pypots.imputation.TCN": TCN,
"pypots.imputation.ImputeFormer": ImputeFormer,
"pypots.imputation.TimeMixer": TimeMixer,
"pypots.imputation.ModernTCN": ModernTCN,
"pypots.imputation.TEFN": TEFN,
"pypots.imputation.CSAI": CSAI,
"pypots.imputation.SegRNN": SegRNN,
"pypots.imputation.TRMF": TRMF,
"pypots.imputation.TimeLLM": TimeLLM,
# classification models
"pypots.classification.BRITS": BRITS_classification,
"pypots.classification.GRUD": GRUD_classification,
"pypots.classification.Raindrop": Raindrop,
"pypots.classification.CSAI": CSAI_classification,
# clustering models
"pypots.clustering.CRLI": CRLI,
"pypots.clustering.VaDER": VaDER,
Expand Down
13 changes: 12 additions & 1 deletion pypots/nn/modules/reformer/local_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,22 @@
from einops import rearrange
from einops import repeat, pack, unpack
from torch import nn, einsum
from torch.cuda.amp import autocast

TOKEN_SELF_ATTN_VALUE = -5e4


# overwrite autocast to make it compatible with both torch >=2.4 and <2.4
def autocast(**kwargs):
if torch.__version__ >= "2.4":
from torch.cuda.amp import autocast

return autocast(**kwargs)
else:
from torch.amp import autocast

return autocast("cuda", **kwargs)


def exists(val):
return val is not None

Expand Down
2 changes: 1 addition & 1 deletion requirements/conda_env.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ dependencies:
- conda-forge::tensorboard
- conda-forge::scikit-learn
- conda-forge::transformers
- conda-forge::pygrinder >=0.6.4
- conda-forge::pygrinder >=0.7
- conda-forge::tsdb >=0.6.1
- conda-forge::benchpots >=0.3.2
- conda-forge::ai4ts
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ scikit-learn
transformers
torch>=1.10.0
tsdb>=0.6.1
pygrinder>=0.6.4
pygrinder>=0.7
benchpots>=0.3.2
ai4ts

0 comments on commit a3afa98

Please sign in to comment.