
[BUG] Saving and Loading of TFT model in etna #1120

Closed
1 task done
DionisMuzenitov opened this issue Feb 16, 2023 · 10 comments
Labels
bug Something isn't working


@DionisMuzenitov

🐛 Bug Report

I tried to implement and save the etna TFT model from this example:
It appears that save/load does not work as intended. Loading the saved model raises
ValueError: Type names and field names must be valid identifiers: 'BaseModel.to_network_output.<locals>.Output'
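For context, this exact message text is produced by Python's `collections.namedtuple`, which rejects any type name that is not a valid identifier. The sketch below only reproduces the message in isolation. It assumes (without tracing pytorch_forecasting's code) that something during loading ends up passing a qualified name such as `BaseModel.to_network_output.<locals>.Output` (the kind of name Python gives a class defined inside a function) to `namedtuple`. The helper names `make_local_class` and `namedtuple_error` are hypothetical.

```python
from collections import namedtuple
from typing import Optional


def make_local_class():
    # A class defined inside a function gets "<locals>" in its __qualname__.
    class Output:
        pass

    return Output


def namedtuple_error(typename: str) -> Optional[str]:
    """Return the ValueError message if namedtuple rejects `typename`, else None."""
    try:
        namedtuple(typename, ["prediction"])
    except ValueError as exc:
        return str(exc)
    return None


print(make_local_class().__qualname__)  # make_local_class.<locals>.Output
print(namedtuple_error("Output"))  # None: a plain identifier is accepted
# A dotted qualified name containing "<locals>" is not a valid identifier,
# so this prints the same ValueError text as in the report:
print(namedtuple_error("BaseModel.to_network_output.<locals>.Output"))
```

This does not explain why the qualified name reaches `namedtuple` after fitting; it only narrows down where the error text originates.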

Expected behavior

The saved model loads normally.

How To Reproduce

Code

import pandas as pd
import numpy as np

from etna.datasets.tsdataset import TSDataset
from etna.pipeline import Pipeline
from etna.transforms import DateFlagsTransform
from etna.transforms import LagTransform
from etna.transforms import PytorchForecastingTransform
from pytorch_forecasting.data import GroupNormalizer
from etna.models.nn import TFTModel


original_df = pd.DataFrame(np.array([["2021-05-31", 1, 3],
                                     ["2021-06-07", 1, 6],
                                     ["2021-06-14", 1, 9],
                                     ["2021-06-21", 1, 12],
                                     ["2021-06-28", 1, 15]]),
                           columns=['timestamp', 'segment', 'target'])
original_df['timestamp'] = pd.to_datetime(original_df['timestamp'])
original_df['target'] = original_df['target'].astype(float)
df = TSDataset.to_dataset(original_df)
ts = TSDataset(df, freq="W-MON")

HORIZON = 1
transform_date = DateFlagsTransform(day_number_in_week=True, day_number_in_month=False, out_column="dateflag")
num_lags = 2
transform_lag = LagTransform(
    in_column="target",
    lags=[HORIZON + i for i in range(num_lags)],
    out_column="target_lag",
)

transform_tft = PytorchForecastingTransform(
    max_encoder_length=HORIZON,
    max_prediction_length=HORIZON,
    time_varying_known_reals=["time_idx"],
    time_varying_unknown_reals=["target"],
    time_varying_known_categoricals=["dateflag_day_number_in_week"],
    static_categoricals=["segment"],
    target_normalizer=GroupNormalizer(groups=["segment"]),
)

model_tft = TFTModel(max_epochs=5, learning_rate=[0.1], gpus=0, batch_size=64)

pipeline_tft = Pipeline(
    model=model_tft,
    horizon=HORIZON,
    transforms=[transform_lag, transform_date, transform_tft],
)

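# Saving and loading the unfitted pipeline works
pipeline_tft.save("666")
pipeline_tft.load("666")

# After fitting, loading the saved pipeline raises
# ValueError: Type names and field names must be valid identifiers: ...
pipeline_tft.fit(ts)
pipeline_tft.save("666")
pipeline_tft.load("666")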

Environment

Python 3.9.13
etna 1.15.0

Additional context

It seems that this bug appears only after pipeline_tft.fit(ts) is called; saving and loading the unfitted pipeline works.
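To make the before/after-fit comparison explicit, a small helper can run a save/load round trip and report which step raised. This is a minimal sketch: `check_roundtrip` and `ToyModel` are hypothetical names, and `ToyModel` is only a JSON-backed stand-in so the helper can be run without etna or torch installed.

```python
import json
import tempfile
from pathlib import Path
from typing import Callable


def check_roundtrip(save: Callable[[Path], None], load: Callable[[Path], object], path: Path) -> str:
    """Save to `path`, load it back, and report which step (if any) raised."""
    try:
        save(path)
    except Exception as exc:  # report the failure instead of stopping the script
        return f"save failed: {type(exc).__name__}: {exc}"
    try:
        load(path)
    except Exception as exc:
        return f"load failed: {type(exc).__name__}: {exc}"
    return "ok"


class ToyModel:
    """Hypothetical stand-in for a pipeline; stores its state as JSON."""

    def __init__(self, fitted: bool = False):
        self.fitted = fitted

    def fit(self) -> "ToyModel":
        self.fitted = True
        return self

    def save(self, path: Path) -> None:
        path.write_text(json.dumps({"fitted": self.fitted}))

    @classmethod
    def load(cls, path: Path) -> "ToyModel":
        return cls(**json.loads(path.read_text()))


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "model.json"
    model = ToyModel()
    before_fit = check_roundtrip(model.save, ToyModel.load, path)
    model.fit()
    after_fit = check_roundtrip(model.save, ToyModel.load, path)

print(before_fit, after_fit)  # ok ok
```

With the pipeline from the reproduction script, the same helper could presumably be called as `check_roundtrip(pipeline_tft.save, pipeline_tft.load, Path("666"))` before and after `pipeline_tft.fit(ts)`; per this report, the second call would return a `load failed: ValueError: ...` string instead of crashing.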

Checklist

  • Bug appears in the latest library version
@DionisMuzenitov DionisMuzenitov added the bug Something isn't working label Feb 16, 2023
@DionisMuzenitov DionisMuzenitov changed the title from [BUG] to [BUG] Saving and Loading of TFT model in etna Feb 16, 2023
@github-project-automation github-project-automation bot moved this to Specification in etna board Feb 17, 2023
@Mr-Geekman Mr-Geekman moved this from Specification to Todo in etna board Feb 17, 2023
@Mr-Geekman
Contributor

Could you please share the versions of your external libraries, especially the torch-related ones: torch, pytorch_forecasting, pytorch_lightning?
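One way to collect exactly these versions without pasting the whole environment is the standard library's `importlib.metadata` (Python 3.8+). A minimal sketch; the function name and the default package tuple are assumptions, and the names are PyPI distribution names rather than import names.

```python
from importlib.metadata import PackageNotFoundError, version


def torch_stack_versions(packages=("torch", "pytorch-forecasting", "pytorch-lightning", "etna")):
    """Map each distribution name to its installed version, or None if it is missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None  # not installed in this environment
    return found


for name, ver in torch_stack_versions().items():
    print(f"{name}: {ver}")
```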

@DionisMuzenitov
Author

> Can you please share with us the versions of the external libraries? Especially connected to torch like: torch, pytorch_forecasting, pytorch_lightning.

I created a new conda environment with only Python and pip, and then just ran 'pip install etna'. After that, to get rid of warnings, I ran:
pip install etna[wandb]
pip install etna[prophet]
pip install tsfresh==0.19.0 && pip install protobuf==3.20.1

Here is the full list of installed packages:
aiohttp 3.8.3
aiosignal 1.3.1
alabaster 0.7.12
alembic 1.9.3
anaconda-client 1.11.0
anaconda-project 0.11.1
antlr4-python3-runtime 4.9.3
anyio 3.5.0
appdirs 1.4.4
argon2-cffi 21.3.0
argon2-cffi-bindings 21.2.0
arrow 1.2.2
astroid 2.11.7
astropy 5.1
async-timeout 4.0.2
atomicwrites 1.4.0
attrs 21.4.0
Automat 20.2.0
autopage 0.5.1
autopep8 1.6.0
Babel 2.9.1
backcall 0.2.0
backports.functools-lru-cache 1.6.4
backports.tempfile 1.0
backports.weakref 1.0.post1
bcrypt 3.2.0
beautifulsoup4 4.11.1
binaryornot 0.4.4
bitarray 2.5.1
bkcharts 0.2
black 22.6.0
bleach 4.1.0
bokeh 2.4.3
boto3 1.24.28
botocore 1.27.28
Bottleneck 1.3.5
brotlipy 0.7.0
catboost 1.1.1
certifi 2022.9.14
cffi 1.15.1
chardet 4.0.0
charset-normalizer 2.0.4
click 8.0.4
cliff 4.1.0
cloudpickle 2.0.0
clyent 1.2.2
cmaes 0.9.1
cmd2 2.4.3
cmdstanpy 1.1.0
colorama 0.4.5
colorcet 3.0.0
colorlog 6.7.0
comtypes 1.1.10
conda-content-trust 0.1.3
conda-pack 0.6.0
conda-package-handling 1.9.0
conda-repo-cli 1.0.20
conda-verify 3.4.2
constantly 15.1.0
convertdate 2.4.0
cookiecutter 1.7.3
cryptography 37.0.1
cssselect 1.1.0
cycler 0.11.0
Cython 0.29.32
cytoolz 0.11.0
daal4py 2021.6.0
dask 2022.7.0
datashader 0.14.1
datashape 0.5.4
debugpy 1.5.1
decorator 5.1.1
defusedxml 0.7.1
Deprecated 1.2.13
diff-match-patch 20200713
dill 0.3.4
distributed 2022.7.0
docker-pycreds 0.4.0
docutils 0.18.1
entrypoints 0.4
ephem 4.1.4
et-xmlfile 1.1.0
etna 1.15.0
fastjsonschema 2.16.2
filelock 3.6.0
flake8 4.0.1
Flask 1.1.2
fonttools 4.25.0
frozenlist 1.3.3
fsspec 2022.7.1
future 0.18.2
gensim 4.1.2
gitdb 4.0.10
GitPython 3.1.30
glob2 0.7
graphviz 0.20.1
greenlet 1.1.1
h5py 3.7.0
HeapDict 1.0.1
hijri-converter 2.2.4
holidays 0.13
holoviews 1.15.0
hvplot 0.8.0
hydra-core 1.3.1
hydra-slayer 0.2.0
hyperlink 21.0.0
idna 3.3
imagecodecs 2021.8.26
imageio 2.19.3
imagesize 1.4.1
importlib-metadata 4.11.3
incremental 21.3.0
inflection 0.5.1
iniconfig 1.1.1
intake 0.6.5
intervaltree 3.1.0
ipykernel 6.15.2
ipython 7.31.1
ipython-genutils 0.2.0
ipywidgets 7.6.5
isort 5.9.3
itemadapter 0.3.0
itemloaders 1.0.4
itsdangerous 2.0.1
jdcal 1.4.1
jedi 0.18.1
jellyfish 0.9.0
Jinja2 2.11.3
jinja2-time 0.2.0
jmespath 0.10.0
joblib 1.1.0
json5 0.9.6
jsonschema 4.16.0
jupyter 1.0.0
jupyter_client 7.3.4
jupyter-console 6.4.3
jupyter_core 4.11.1
jupyter-server 1.18.1
jupyterlab 3.4.4
jupyterlab-pygments 0.1.2
jupyterlab-server 2.10.3
jupyterlab-widgets 1.0.0
keyring 23.4.0
kiwisolver 1.4.2
korean-lunar-calendar 0.3.1
lazy-object-proxy 1.6.0
libarchive-c 2.9
lightning-utilities 0.6.0.post0
llvmlite 0.38.0
locket 1.0.0
loguru 0.5.3
LunarCalendar 0.0.9
lxml 4.9.1
lz4 3.1.3
Mako 1.2.4
Markdown 3.3.4
MarkupSafe 2.0.1
matplotlib 3.5.2
matplotlib-inline 0.1.6
matrixprofile 1.1.10
mccabe 0.6.1
menuinst 1.4.19
mistune 0.8.4
mkl-fft 1.3.1
mkl-random 1.2.2
mkl-service 2.4.0
mock 4.0.3
mpmath 1.2.1
msgpack 1.0.3
multidict 6.0.4
multipledispatch 0.6.0
munkres 1.1.4
mypy-extensions 0.4.3
nbclassic 0.3.5
nbclient 0.5.13
nbconvert 6.4.4
nbformat 5.5.0
nest-asyncio 1.5.5
networkx 2.8.4
nltk 3.7
nose 1.3.7
notebook 6.4.12
numba 0.55.1
numexpr 2.8.3
numpy 1.21.5
numpydoc 1.4.0
olefile 0.46
omegaconf 2.3.0
openpyxl 3.0.10
optuna 2.10.1
packaging 21.3
pandas 1.4.4
pandocfilters 1.5.0
panel 0.13.1
param 1.12.0
paramiko 2.8.1
parsel 1.6.0
parso 0.8.3
partd 1.2.0
pathlib 1.0.1
pathspec 0.9.0
pathtools 0.1.2
patsy 0.5.2
pbr 5.11.1
pep8 1.7.1
pexpect 4.8.0
pickleshare 0.7.5
Pillow 9.2.0
pip 22.2.2
pkginfo 1.8.2
platformdirs 2.5.2
plotly 5.9.0
pluggy 1.0.0
pmdarima 2.0.2
poyo 0.5.0
prettytable 3.6.0
prometheus-client 0.14.1
promise 2.3
prompt-toolkit 3.0.20
prophet 1.1
Protego 0.1.16
protobuf 3.20.1
psutil 5.9.0
ptyprocess 0.7.0
py 1.11.0
pyasn1 0.4.8
pyasn1-modules 0.2.8
pycodestyle 2.8.0
pycosat 0.6.3
pycparser 2.21
pyct 0.4.8
pycurl 7.45.1
PyDispatcher 2.0.5
pydocstyle 6.1.1
pyerfa 2.0.0
pyflakes 2.4.0
Pygments 2.11.2
PyHamcrest 2.0.2
PyJWT 2.4.0
pylint 2.14.5
pyls-spyder 0.4.0
PyMeeus 0.5.12
PyNaCl 1.5.0
pyodbc 4.0.34
pyOpenSSL 22.0.0
pyparsing 3.0.9
pyperclip 1.8.2
pyreadline3 3.4.1
pyrsistent 0.18.0
PySocks 1.7.1
pytest 7.1.2
python-dateutil 2.8.2
python-lsp-black 1.0.0
python-lsp-jsonrpc 1.0.0
python-lsp-server 1.3.3
python-slugify 5.0.2
python-snappy 0.6.0
pytorch-forecasting 0.9.2
pytorch-lightning 1.9.0
pytz 2022.1
pyviz-comms 2.0.2
PyWavelets 1.3.0
pywin32 302
pywin32-ctypes 0.2.0
pywinpty 2.0.2
PyYAML 6.0
pyzmq 23.2.0
QDarkStyle 3.0.2
qstylizer 0.1.10
QtAwesome 1.0.3
qtconsole 5.2.2
QtPy 2.2.0
queuelib 1.5.0
regex 2022.7.9
requests 2.28.1
requests-file 1.5.1
rope 0.22.0
Rtree 0.9.7
ruamel-yaml-conda 0.15.100
ruptures 1.1.5
s3transfer 0.6.0
scikit-image 0.19.2
scikit-learn 1.0.2
scikit-learn-intelex 2021.20221004.171935
scipy 1.7.3
Scrapy 2.6.2
seaborn 0.11.2
Send2Trash 1.8.0
sentry-sdk 1.15.0
service-identity 18.1.0
setproctitle 1.3.2
setuptools 63.4.1
setuptools-git 1.2
shortuuid 1.0.11
sip 4.19.13
six 1.16.0
smart-open 5.2.1
smmap 5.0.0
sniffio 1.2.0
snowballstemmer 2.2.0
sortedcollections 2.1.0
sortedcontainers 2.4.0
soupsieve 2.3.1
Sphinx 5.0.2
sphinxcontrib-applehelp 1.0.2
sphinxcontrib-devhelp 1.0.2
sphinxcontrib-htmlhelp 2.0.0
sphinxcontrib-jsmath 1.0.1
sphinxcontrib-qthelp 1.0.3
sphinxcontrib-serializinghtml 1.1.5
spyder 5.2.2
spyder-kernels 2.2.1
SQLAlchemy 1.4.39
statsmodels 0.13.2
stevedore 4.1.1
stumpy 1.11.1
sympy 1.10.1
tables 3.6.1
tabulate 0.8.10
tbats 1.1.2
TBB 0.2
tblib 1.7.0
tenacity 8.0.1
terminado 0.13.1
testpath 0.6.0
text-unidecode 1.3
textdistance 4.2.1
threadpoolctl 2.2.0
three-merge 0.1.1
tifffile 2021.7.2
tinycss 0.4
tldextract 3.2.0
toml 0.10.2
tomli 2.0.1
tomlkit 0.11.1
toolz 0.11.2
torch 1.11.0
torchmetrics 0.11.1
tornado 6.1
tqdm 4.64.1
traitlets 5.1.1
tsfresh 0.19.0
Twisted 22.2.0
twisted-iocpsupport 1.0.2
typer 0.4.2
types-Deprecated 1.2.9
typing_extensions 4.3.0
ujson 5.4.0
Unidecode 1.2.0
urllib3 1.26.11
w3lib 1.21.0
wandb 0.12.21
watchdog 2.1.6
wcwidth 0.2.5
webencodings 0.5.1
websocket-client 0.58.0
Werkzeug 2.0.3
wheel 0.37.1
widgetsnbextension 3.5.2
win-inet-pton 1.1.0
win-unicode-console 0.5
win32-setctime 1.1.0
wincertstore 0.2
wrapt 1.14.1
xarray 0.20.1
xlrd 2.0.1
XlsxWriter 3.0.3
xlwings 0.27.15
yapf 0.31.0
yarl 1.8.2
zict 2.1.0
zipp 3.8.0
zope.interface 5.4.0

@Mr-Geekman
Contributor

Did you install etna[torch]? Otherwise I don't see how torch got installed.

@DionisMuzenitov
Author

> Haven't you installed etna[torch]? I don't really get how torch was installed.

Yes, I did; I just forgot to mention it, sorry.

@Mr-Geekman
Contributor

Mr-Geekman commented Feb 17, 2023

OK, I also have problems with pytorch_lightning==1.9.2 (though the error is different). A probable temporary fix is to install pytorch_lightning==1.8.6.

Please let us know whether it works for you.

@Mr-Geekman Mr-Geekman moved this from Todo to In Progress in etna board Feb 17, 2023
@Mr-Geekman
Contributor

So far, I have found a problem with pytorch_lightning>=1.9.1 combined with pytorch-forecasting==0.9.2. The error was:

Traceback (most recent call last):
  File "/Users/d.a.binin/Documents/tasks/temp/main.py", line 53, in <module>
    pipeline_tft.fit(ts)
  File "/Users/d.a.binin/Documents/tasks/temp/venv/lib/python3.9/site-packages/etna/pipeline/pipeline.py", line 55, in fit
    self.model.fit(self.ts)
  File "/Users/d.a.binin/Documents/tasks/temp/venv/lib/python3.9/site-packages/etna/models/decorators.py", line 15, in wrapper
    result = f(self, *args, **kwargs)
  File "/Users/d.a.binin/Documents/tasks/temp/venv/lib/python3.9/site-packages/etna/models/nn/tft.py", line 163, in fit
    self.model = self._from_dataset(pf_transform.pf_dataset_train)
  File "/Users/d.a.binin/Documents/tasks/temp/venv/lib/python3.9/site-packages/etna/models/nn/tft.py", line 125, in _from_dataset
    return TemporalFusionTransformer.from_dataset(
  File "/Users/d.a.binin/Documents/tasks/temp/venv/lib/python3.9/site-packages/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py", line 351, in from_dataset
    return super().from_dataset(
  File "/Users/d.a.binin/Documents/tasks/temp/venv/lib/python3.9/site-packages/pytorch_forecasting/models/base_model.py", line 1437, in from_dataset
    return super().from_dataset(dataset, **new_kwargs)
  File "/Users/d.a.binin/Documents/tasks/temp/venv/lib/python3.9/site-packages/pytorch_forecasting/models/base_model.py", line 969, in from_dataset
    net = cls(**kwargs)
  File "/Users/d.a.binin/Documents/tasks/temp/venv/lib/python3.9/site-packages/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py", line 140, in __init__
    super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs)
  File "/Users/d.a.binin/Documents/tasks/temp/venv/lib/python3.9/site-packages/pytorch_forecasting/models/base_model.py", line 238, in __init__
    {name: val for name, val in init_args.items() if name not in self.hparams and name not in ["self"]}
AttributeError: 'tuple' object has no attribute 'items'
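Until this is fixed, the combination reported here could be flagged before fitting. A minimal sketch with hypothetical function names; it parses only the leading numeric release part so it needs no third-party packages. Note that it deliberately returns False for pytorch_lightning 1.9.0, which is the version from the original report: this guard covers the `AttributeError` above, not the original `ValueError` on load.

```python
import re
from typing import Tuple


def release(version: str) -> Tuple[int, int, int]:
    """Parse the leading numeric part of a version string, e.g. "1.9.2" -> (1, 9, 2)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"Unrecognised version string: {version!r}")
    parts = tuple(int(p) for p in match.group(0).split("."))
    return (parts + (0, 0, 0))[:3]  # pad short versions like "1.9" to three parts


def is_reported_bad_combo(lightning_version: str, forecasting_version: str) -> bool:
    """True for the combination reported in this thread:
    pytorch_lightning>=1.9.1 together with pytorch-forecasting==0.9.2."""
    return release(lightning_version) >= (1, 9, 1) and release(forecasting_version) == (0, 9, 2)


print(is_reported_bad_combo("1.9.2", "0.9.2"))  # True
print(is_reported_bad_combo("1.8.6", "0.9.2"))  # False
print(is_reported_bad_combo("1.9.0", "0.9.2"))  # False
```

In practice the two arguments would come from `importlib.metadata.version("pytorch-lightning")` and `importlib.metadata.version("pytorch-forecasting")`.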

@DionisMuzenitov
Author

> Ok, I also have some problems on pytorch_lightning==1.9.2 (but the error is different). Probable temporary fix is to install pytorch_lightning==1.8.6.
>
> Please, give a feedback if it works for you.

I rebuilt etna with different versions of everything (a colleague found a setup that works). I suspect this is not only about pytorch_lightning, because pytorch_lightning==1.9.0 was present both in my previous setup (the one with the bug) and in my current one. Now everything works. It might be useful to provide builds with pinned library versions that are known to work.

@Mr-Geekman
Contributor

We pin package versions in poetry.lock. If you install the library from source using poetry, it will use those pinned versions, which are tested in our CI.
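If you install with pip instead of poetry, you can at least check how far your environment drifts from a set of pins. A minimal sketch: `find_mismatches` is a hypothetical helper, and the pins dictionary below is illustrative only (taken from versions mentioned in this thread, not from the real poetry.lock), so copy the actual values from etna's poetry.lock.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Callable, Dict, Optional


def find_mismatches(
    pins: Dict[str, str],
    get_version: Callable[[str], str] = version,
) -> Dict[str, Optional[str]]:
    """Return {package: installed version or None} for packages that differ from their pin."""
    mismatches: Dict[str, Optional[str]] = {}
    for name, pinned in pins.items():
        try:
            installed: Optional[str] = get_version(name)
        except PackageNotFoundError:
            installed = None  # the package is not installed at all
        if installed != pinned:
            mismatches[name] = installed
    return mismatches


# Hypothetical pins for illustration; use the real ones from poetry.lock.
example_pins = {"pytorch-lightning": "1.8.6", "pytorch-forecasting": "0.9.2"}
print(find_mismatches(example_pins))
```

The `get_version` parameter exists so the comparison can be checked against a fake environment; by default it queries the installed distributions.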

@Mr-Geekman
Contributor

It looks like my bug (not the original one) is related to this one: 'tuple' object has no attribute 'items' in models.

@Mr-Geekman
Copy link
Contributor

Made a separate issue about discovered error: [BUG] DeepARModel and TFTModel don't work on pytorch_lightning>=1.9.1.

@github-project-automation github-project-automation bot moved this from In Progress to Done in etna board Feb 27, 2023