From af129c1ba74ef3e6a8feed8a6357119c32a6d537 Mon Sep 17 00:00:00 2001 From: Dmitry Bunin Date: Wed, 4 Oct 2023 16:02:59 +0300 Subject: [PATCH 1/6] feature: redo saving for dl models --- etna/core/mixins.py | 38 ++-- etna/ensembles/mixins.py | 15 +- .../prediction_intervals/mixins.py | 14 +- etna/models/base.py | 4 +- etna/models/mixins.py | 160 +++++++++++++++-- etna/models/nn/deepar.py | 11 +- etna/models/nn/deepstate/deepstate.py | 1 + etna/models/nn/mlp.py | 1 + etna/models/nn/nbeats/nets.py | 2 + etna/models/nn/patchts.py | 6 +- etna/models/nn/rnn.py | 1 + etna/models/nn/tft.py | 11 +- etna/models/nn/utils.py | 1 + etna/pipeline/mixins.py | 23 +-- poetry.lock | 2 +- pyproject.toml | 2 +- tests/test_models/test_mixins.py | 170 +++++++++++++++--- tests/test_models/{nn => test_nn}/conftest.py | 0 .../{nn => test_nn}/deepstate/test_lds.py | 0 .../{nn => test_nn}/deepstate/test_ssm.py | 0 .../{nn => test_nn}/nbeats/test_blocks.py | 0 .../{nn => test_nn}/nbeats/test_nbeats.py | 0 .../nbeats/test_nbeats_metrics.py | 0 .../nbeats/test_nbeats_nets.py | 0 .../nbeats/test_nbeats_utils.py | 0 .../{nn => test_nn}/test_deepar.py | 0 .../{nn => test_nn}/test_deepstate.py | 12 ++ tests/test_models/{nn => test_nn}/test_mlp.py | 0 .../{nn => test_nn}/test_patchts.py | 6 + tests/test_models/{nn => test_nn}/test_rnn.py | 0 tests/test_models/{nn => test_nn}/test_tft.py | 0 .../test_models/{nn => test_nn}/test_utils.py | 0 32 files changed, 378 insertions(+), 102 deletions(-) rename tests/test_models/{nn => test_nn}/conftest.py (100%) rename tests/test_models/{nn => test_nn}/deepstate/test_lds.py (100%) rename tests/test_models/{nn => test_nn}/deepstate/test_ssm.py (100%) rename tests/test_models/{nn => test_nn}/nbeats/test_blocks.py (100%) rename tests/test_models/{nn => test_nn}/nbeats/test_nbeats.py (100%) rename tests/test_models/{nn => test_nn}/nbeats/test_nbeats_metrics.py (100%) rename tests/test_models/{nn => test_nn}/nbeats/test_nbeats_nets.py (100%) rename 
tests/test_models/{nn => test_nn}/nbeats/test_nbeats_utils.py (100%) rename tests/test_models/{nn => test_nn}/test_deepar.py (100%) rename tests/test_models/{nn => test_nn}/test_deepstate.py (76%) rename tests/test_models/{nn => test_nn}/test_mlp.py (100%) rename tests/test_models/{nn => test_nn}/test_patchts.py (91%) rename tests/test_models/{nn => test_nn}/test_rnn.py (100%) rename tests/test_models/{nn => test_nn}/test_tft.py (100%) rename tests/test_models/{nn => test_nn}/test_utils.py (100%) diff --git a/etna/core/mixins.py b/etna/core/mixins.py index c3504dcda..f2a44fdf6 100644 --- a/etna/core/mixins.py +++ b/etna/core/mixins.py @@ -27,8 +27,8 @@ def __repr__(self): """Get default representation of etna object.""" # TODO: add tests default behaviour for all registered objects args_str_representation = "" - init_args = inspect.signature(self.__init__).parameters - for arg, param in init_args.items(): + init_parameters = self._get_init_parameters() + for arg, param in init_parameters.items(): if param.kind == param.VAR_POSITIONAL: continue elif param.kind == param.VAR_KEYWORD: @@ -43,6 +43,9 @@ def __repr__(self): args_str_representation += f"{arg} = {repr(value)}, " return f"{self.__class__.__name__}({args_str_representation})" + def _get_init_parameters(self): + return inspect.signature(self.__init__).parameters + @staticmethod def _get_target_from_class(value: Any): if value is None: @@ -84,9 +87,9 @@ def _parse_value(value: Any) -> Any: def to_dict(self): """Collect all information about etna object in dict.""" - init_args = inspect.signature(self.__init__).parameters + init_parameters = self._get_init_parameters() params = {} - for arg in init_args.keys(): + for arg in init_parameters.keys(): value = self.__dict__[arg] if value is None: continue @@ -226,9 +229,26 @@ def _save_metadata(self, archive: zipfile.ZipFile): with archive.open("metadata.json", "w") as output_file: output_file.write(metadata_bytes) - def _save_state(self, archive: zipfile.ZipFile): 
- with archive.open("object.pkl", "w", force_zip64=True) as output_file: - dill.dump(self, output_file) + def _save_state(self, archive: zipfile.ZipFile, skip_attributes: Sequence[str] = ()): + saved_attributes = {} + try: + # remove attributes we can't easily save + saved_attributes = {attr: getattr(self, attr) for attr in skip_attributes} + for attr in skip_attributes: + delattr(self, attr) + + # save the remaining part + with archive.open("object.pkl", "w", force_zip64=True) as output_file: + dill.dump(self, output_file) + finally: + # restore the skipped attributes + for attr, value in saved_attributes.items(): + setattr(self, attr, value) + + def _save(self, path: pathlib.Path, skip_attributes: Sequence[str] = ()): + with zipfile.ZipFile(path, "w") as archive: + self._save_metadata(archive) + self._save_state(archive, skip_attributes=skip_attributes) def save(self, path: pathlib.Path): """Save the object. @@ -238,9 +258,7 @@ def save(self, path: pathlib.Path): path: Path to save object to. 
""" - with zipfile.ZipFile(path, "w") as archive: - self._save_metadata(archive) - self._save_state(archive) + self._save(path=path) @classmethod def _load_metadata(cls, archive: zipfile.ZipFile) -> Dict[str, Any]: diff --git a/etna/ensembles/mixins.py b/etna/ensembles/mixins.py index a0445f950..a2772b193 100644 --- a/etna/ensembles/mixins.py +++ b/etna/ensembles/mixins.py @@ -85,18 +85,7 @@ def save(self, path: pathlib.Path): self.pipelines: List[BasePipeline] self.ts: Optional[TSDataset] - pipelines = self.pipelines - ts = self.ts - try: - # extract attributes we can't easily save - delattr(self, "pipelines") - delattr(self, "ts") - - # save the remaining part - super().save(path=path) - finally: - self.pipelines = pipelines - self.ts = ts + self._save(path=path, skip_attributes=["pipelines", "ts"]) with zipfile.ZipFile(path, "a") as archive: with tempfile.TemporaryDirectory() as _temp_dir: @@ -106,7 +95,7 @@ def save(self, path: pathlib.Path): pipelines_dir = temp_dir / "pipelines" pipelines_dir.mkdir() num_digits = 8 - for i, pipeline in enumerate(pipelines): + for i, pipeline in enumerate(self.pipelines): save_name = f"{i:0{num_digits}d}.zip" pipeline_save_path = pipelines_dir / save_name pipeline.save(pipeline_save_path) diff --git a/etna/experimental/prediction_intervals/mixins.py b/etna/experimental/prediction_intervals/mixins.py index 56e5719d5..78f8acaf8 100644 --- a/etna/experimental/prediction_intervals/mixins.py +++ b/etna/experimental/prediction_intervals/mixins.py @@ -33,17 +33,7 @@ def save(self, path: pathlib.Path): """ self.pipeline: BasePipeline - pipeline = self.pipeline - - try: - # extract pipeline to save it with its own method later - delattr(self, "pipeline") - - # save the remaining part - super().save(path=path) - - finally: - self.pipeline = pipeline + self._save(path=path, skip_attributes=["pipeline"]) with zipfile.ZipFile(path, "a") as archive: with tempfile.TemporaryDirectory() as _temp_dir: @@ -51,7 +41,7 @@ def save(self, path: 
pathlib.Path): # save pipeline separately and add to the archive pipeline_save_path = temp_dir / "pipeline.zip" - pipeline.save(path=pipeline_save_path) + self.pipeline.save(path=pipeline_save_path) archive.write(pipeline_save_path, "pipeline.zip") diff --git a/etna/models/base.py b/etna/models/base.py index 41bc083a1..67195b1cb 100644 --- a/etna/models/base.py +++ b/etna/models/base.py @@ -20,7 +20,7 @@ from etna.distributions import BaseDistribution from etna.loggers import tslogger from etna.models.decorators import log_decorator -from etna.models.mixins import SaveNNMixin +from etna.models.mixins import SaveDeepBaseModelMixin if SETTINGS.torch_required: import torch @@ -470,7 +470,7 @@ def validation_step(self, batch: dict, *args, **kwargs): # type: ignore return loss -class DeepBaseModel(DeepBaseAbstractModel, SaveNNMixin, NonPredictionIntervalContextRequiredAbstractModel): +class DeepBaseModel(DeepBaseAbstractModel, SaveDeepBaseModelMixin, NonPredictionIntervalContextRequiredAbstractModel): """Class for partially implemented interfaces for holding deep models.""" def __init__( diff --git a/etna/models/mixins.py b/etna/models/mixins.py index d42ce5985..0d45e8268 100644 --- a/etna/models/mixins.py +++ b/etna/models/mixins.py @@ -1,3 +1,4 @@ +import pathlib import zipfile from abc import ABC from abc import abstractmethod @@ -11,12 +12,20 @@ import dill import numpy as np import pandas as pd +from hydra_slayer import get_factory from typing_extensions import Self +from etna import SETTINGS +from etna.core.mixins import BaseMixin from etna.core.mixins import SaveMixin from etna.datasets.tsdataset import TSDataset from etna.models.decorators import log_decorator +if SETTINGS.torch_required: + import torch + from pytorch_lightning import LightningModule + from pytorch_lightning import Trainer + class ModelForecastingMixin(ABC): """Base class for model mixins.""" @@ -629,25 +638,152 @@ def get_model(self) -> Any: return self._base_model.get_model() -class 
SaveNNMixin(SaveMixin): - """Implementation of ``AbstractSaveable`` torch related classes. +def _load_object(class_name, class_parameters): + cls = get_factory(class_name) + obj = cls(**class_parameters) + return obj + + +def _save_pl_model(archive: zipfile.ZipFile, filename: str, model: "LightningModule"): + with archive.open(filename, "w", force_zip64=True) as output_file: + to_save = { + "class": BaseMixin._get_target_from_class(model), + "hyperparameters": dict(model.hparams), + "state_dict": model.state_dict(), + } + torch.save(to_save, output_file, pickle_module=dill) + + +def _load_pl_model(archive: zipfile.ZipFile, filename: str) -> "LightningModule": + with archive.open(filename, "r") as input_file: + net_loaded = torch.load(input_file, pickle_module=dill) + + # fixes the [issue](https://github.com/Lightning-AI/lightning/issues/18405) with `save_hyperparameters` + net = _load_object(class_name=net_loaded["class"], class_parameters=net_loaded["hyperparameters"]) + + net.load_state_dict(net_loaded["state_dict"]) + + return net - It saves object to the zip archive with 2 files: + +class SaveDeepBaseModelMixin(SaveMixin): + """Implementation of ``AbstractSaveable`` for :py:class:`~etna.models.base.DeepBaseModel` models. + + It saves object to the zip archive with files: * metadata.json: contains library version and class name. - * object.pt: object saved by ``torch.save``. + * object.pkl: pickled without ``self.net`` and ``self.trainer``. + + * net.pt: parameters of ``self.net`` saved by ``torch.save``. """ - def _save_state(self, archive: zipfile.ZipFile): - import torch + def save(self, path: pathlib.Path): + """Save the object. + + Parameters + ---------- + path: + Path to save object to. 
+ """ + from etna.models.base import DeepBaseNet + + self.trainer: Optional[Trainer] + self.net: DeepBaseNet + + self._save(path=path, skip_attributes=["trainer", "net"]) - with archive.open("object.pt", "w", force_zip64=True) as output_file: - torch.save(self, output_file, pickle_module=dill) + with zipfile.ZipFile(path, "a") as archive: + _save_pl_model(archive=archive, filename="net.pt", model=self.net) @classmethod - def _load_state(cls, archive: zipfile.ZipFile) -> Self: - import torch + def load(cls, path: pathlib.Path, ts: Optional[TSDataset] = None) -> Self: + """Load an object. + + Warning + ------- + This method uses :py:mod:`dill` module which is not secure. + It is possible to construct malicious data which will execute arbitrary code during loading. + Never load data that could have come from an untrusted source, or that could have been tampered with. + + Parameters + ---------- + path: + Path to load object from. + ts: + TSDataset to set into loaded pipeline. + + Returns + ------- + : + Loaded object. + """ + obj = super().load(path=path) + + with zipfile.ZipFile(path, "r") as archive: + obj.net = _load_pl_model(archive=archive, filename="net.pt") + obj.trainer = None + + return obj + + +class SavePytorchForecastingMixin(SaveMixin): + """Implementation of ``AbstractSaveable`` for :py:mod:`pytorch_forecasting` models. + + It saves object to the zip archive with files: + + * metadata.json: contains library version and class name. + + * object.pkl: pickled without ``self.model`` and ``self.trainer``. + + * model.pt: parameters of ``self.model`` saved by ``torch.save`` if model was fitted. + """ + + def save(self, path: pathlib.Path): + """Save the object. + + Parameters + ---------- + path: + Path to save object to. 
+ """ + self.trainer: Optional[Trainer] + self.model: Optional[LightningModule] + + if self.model is None: + self._save(path=path, skip_attributes=["trainer"]) + else: + self._save(path=path, skip_attributes=["trainer", "model"]) + with zipfile.ZipFile(path, "a") as archive: + _save_pl_model(archive=archive, filename="model.pt", model=self.model) + + @classmethod + def load(cls, path: pathlib.Path, ts: Optional[TSDataset] = None) -> Self: + """Load an object. + + Warning + ------- + This method uses :py:mod:`dill` module which is not secure. + It is possible to construct malicious data which will execute arbitrary code during loading. + Never load data that could have come from an untrusted source, or that could have been tampered with. + + Parameters + ---------- + path: + Path to load object from. + ts: + TSDataset to set into loaded pipeline. + + Returns + ------- + : + Loaded object. + """ + obj = super().load(path=path) + obj.trainer = None + + if not hasattr(obj, "model"): + with zipfile.ZipFile(path, "r") as archive: + obj.model = _load_pl_model(archive=archive, filename="model.pt") - with archive.open("object.pt", "r") as input_file: - return torch.load(input_file, pickle_module=dill) + return obj diff --git a/etna/models/nn/deepar.py b/etna/models/nn/deepar.py index 6ae68e2b6..037b439af 100644 --- a/etna/models/nn/deepar.py +++ b/etna/models/nn/deepar.py @@ -2,7 +2,6 @@ from typing import Dict from typing import Optional from typing import Sequence -from typing import Union import pandas as pd @@ -13,7 +12,7 @@ from etna.distributions import IntDistribution from etna.models.base import PredictionIntervalContextRequiredAbstractModel from etna.models.base import log_decorator -from etna.models.mixins import SaveNNMixin +from etna.models.mixins import SavePytorchForecastingMixin from etna.models.nn.utils import PytorchForecastingDatasetBuilder from etna.models.nn.utils import PytorchForecastingMixin from etna.models.nn.utils import _DeepCopyMixin @@ -25,9 
+24,12 @@ from pytorch_forecasting.metrics import NormalDistributionLoss from pytorch_forecasting.models import DeepAR from pytorch_lightning import LightningModule + from pytorch_lightning import Trainer -class DeepARModel(_DeepCopyMixin, PytorchForecastingMixin, SaveNNMixin, PredictionIntervalContextRequiredAbstractModel): +class DeepARModel( + _DeepCopyMixin, PytorchForecastingMixin, SavePytorchForecastingMixin, PredictionIntervalContextRequiredAbstractModel +): """Wrapper for :py:class:`pytorch_forecasting.models.deepar.DeepAR`. Note @@ -123,7 +125,8 @@ def __init__( self.loss = loss self.trainer_params = trainer_params if trainer_params is not None else dict() self.quantiles_kwargs = quantiles_kwargs if quantiles_kwargs is not None else dict() - self.model: Optional[Union[LightningModule, DeepAR]] = None + self.model: Optional[DeepAR] = None + self.trainer: Optional[Trainer] = None self._last_train_timestamp = None def _from_dataset(self, ts_dataset: TimeSeriesDataSet) -> LightningModule: diff --git a/etna/models/nn/deepstate/deepstate.py b/etna/models/nn/deepstate/deepstate.py index c1dcfb965..a6b0f21ae 100644 --- a/etna/models/nn/deepstate/deepstate.py +++ b/etna/models/nn/deepstate/deepstate.py @@ -54,6 +54,7 @@ def __init__( Parameters for optimizer for Adam optimizer (api reference :py:class:`torch.optim.Adam`) """ super().__init__() + self.save_hyperparameters() self.ssm = ssm self.input_size = input_size self.num_layers = num_layers diff --git a/etna/models/nn/mlp.py b/etna/models/nn/mlp.py index 6d9b8a7ff..b13452068 100644 --- a/etna/models/nn/mlp.py +++ b/etna/models/nn/mlp.py @@ -55,6 +55,7 @@ def __init__( parameters for optimizer for Adam optimizer (api reference :py:class:`torch.optim.Adam`) """ super().__init__() + self.save_hyperparameters() self.input_size = input_size self.hidden_size = hidden_size self.lr = lr diff --git a/etna/models/nn/nbeats/nets.py b/etna/models/nn/nbeats/nets.py index b07034556..725c77fb5 100644 --- 
a/etna/models/nn/nbeats/nets.py +++ b/etna/models/nn/nbeats/nets.py @@ -210,6 +210,7 @@ def __init__( lr=lr, optimizer_params=optimizer_params, ) + self.save_hyperparameters() class NBeatsGenericNet(NBeatsBaseNet): @@ -270,3 +271,4 @@ def __init__( lr=lr, optimizer_params=optimizer_params, ) + self.save_hyperparameters() diff --git a/etna/models/nn/patchts.py b/etna/models/nn/patchts.py index 332718347..478b49831 100644 --- a/etna/models/nn/patchts.py +++ b/etna/models/nn/patchts.py @@ -96,12 +96,14 @@ def __init__( parameters for optimizer for Adam optimizer (api reference :py:class:`torch.optim.Adam`) """ super().__init__() + self.save_hyperparameters() + self.encoder_length = encoder_length self.patch_len = patch_len + self.stride = stride self.num_layers = num_layers - self.feedforward_size = feedforward_size self.hidden_size = hidden_size + self.feedforward_size = feedforward_size self.nhead = nhead - self.stride = stride self.loss = loss encoder_layers = nn.TransformerEncoderLayer( diff --git a/etna/models/nn/rnn.py b/etna/models/nn/rnn.py index 2576de7d4..49bc37f2f 100644 --- a/etna/models/nn/rnn.py +++ b/etna/models/nn/rnn.py @@ -59,6 +59,7 @@ def __init__( parameters for optimizer for Adam optimizer (api reference :py:class:`torch.optim.Adam`) """ super().__init__() + self.save_hyperparameters() self.num_layers = num_layers self.input_size = input_size self.hidden_size = hidden_size diff --git a/etna/models/nn/tft.py b/etna/models/nn/tft.py index 57a3f1f4f..a38aea595 100644 --- a/etna/models/nn/tft.py +++ b/etna/models/nn/tft.py @@ -3,7 +3,6 @@ from typing import Dict from typing import Optional from typing import Sequence -from typing import Union import pandas as pd @@ -14,7 +13,7 @@ from etna.distributions import IntDistribution from etna.models.base import PredictionIntervalContextRequiredAbstractModel from etna.models.base import log_decorator -from etna.models.mixins import SaveNNMixin +from etna.models.mixins import SavePytorchForecastingMixin from 
etna.models.nn.utils import PytorchForecastingDatasetBuilder from etna.models.nn.utils import PytorchForecastingMixin from etna.models.nn.utils import _DeepCopyMixin @@ -25,9 +24,12 @@ from pytorch_forecasting.metrics import QuantileLoss from pytorch_forecasting.models import TemporalFusionTransformer from pytorch_lightning import LightningModule + from pytorch_lightning import Trainer -class TFTModel(_DeepCopyMixin, PytorchForecastingMixin, SaveNNMixin, PredictionIntervalContextRequiredAbstractModel): +class TFTModel( + _DeepCopyMixin, PytorchForecastingMixin, SavePytorchForecastingMixin, PredictionIntervalContextRequiredAbstractModel +): """Wrapper for :py:class:`pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer`. Note @@ -126,7 +128,8 @@ def __init__( self.loss = loss self.trainer_params = trainer_params if trainer_params is not None else dict() self.quantiles_kwargs = quantiles_kwargs if quantiles_kwargs is not None else dict() - self.model: Optional[Union[LightningModule, TemporalFusionTransformer]] = None + self.model: Optional[TemporalFusionTransformer] = None + self.trainer: Optional[Trainer] = None self._last_train_timestamp = None self.kwargs = kwargs diff --git a/etna/models/nn/utils.py b/etna/models/nn/utils.py index 1bf1cd16d..dd1a0e95d 100644 --- a/etna/models/nn/utils.py +++ b/etna/models/nn/utils.py @@ -228,6 +228,7 @@ class PytorchForecastingMixin: train_batch_size: int test_batch_size: int encoder_length: int + trainer: Optional[pl.Trainer] @log_decorator def fit(self, ts: TSDataset): diff --git a/etna/pipeline/mixins.py b/etna/pipeline/mixins.py index b2ab2426d..72acb12b5 100644 --- a/etna/pipeline/mixins.py +++ b/etna/pipeline/mixins.py @@ -144,7 +144,7 @@ def params_to_tune(self) -> Dict[str, BaseDistribution]: class SaveModelPipelineMixin(SaveMixin): """Implementation of ``AbstractSaveable`` abstract class for pipelines with model inside. 
- It saves object to the zip archive with 4 entities: + It saves object to the zip archive with entities: * metadata.json: contains library version and class name. @@ -167,22 +167,7 @@ def save(self, path: pathlib.Path): self.transforms: Sequence[Transform] self.ts: Optional[TSDataset] - model = self.model - transforms = self.transforms - ts = self.ts - - try: - # extract attributes we can't easily save - delattr(self, "model") - delattr(self, "transforms") - delattr(self, "ts") - - # save the remaining part - super().save(path=path) - finally: - self.model = model - self.transforms = transforms - self.ts = ts + self._save(path=path, skip_attributes=["model", "transforms", "ts"]) with zipfile.ZipFile(path, "a") as archive: with tempfile.TemporaryDirectory() as _temp_dir: @@ -190,14 +175,14 @@ def save(self, path: pathlib.Path): # save model separately model_save_path = temp_dir / "model.zip" - model.save(model_save_path) + self.model.save(model_save_path) archive.write(model_save_path, "model.zip") # save transforms separately transforms_dir = temp_dir / "transforms" transforms_dir.mkdir() num_digits = 8 - for i, transform in enumerate(transforms): + for i, transform in enumerate(self.transforms): save_name = f"{i:0{num_digits}d}.zip" transform_save_path = transforms_dir / save_name transform.save(transform_save_path) diff --git a/poetry.lock b/poetry.lock index 317436a08..59bd6dd70 100644 --- a/poetry.lock +++ b/poetry.lock @@ -6260,4 +6260,4 @@ wandb = ["wandb"] [metadata] lock-version = "2.0" python-versions = ">=3.8.0, <3.11.0" -content-hash = "0f6dd03c7a5aa30597ef2aca5fc02fa4e6ca410a11f4b834b20dc630dcbd1384" +content-hash = "7e7273757b7cd622f19810a5d12363fa7438d6054def4090a9a91777fab0123b" diff --git a/pyproject.toml b/pyproject.toml index 7271706e3..2fdd87703 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,7 +73,7 @@ types-Deprecated = "1.2.9" prophet = {version = "^1.0", optional = true} -torch = {version = ">=1.8.0,<1.12.0", optional = true} +torch 
= {version = ">=1.8.0,<3", optional = true} pytorch-forecasting = {version = "^0.9.0", optional = true} pytorch-lightning = {version = "*", optional = true} diff --git a/tests/test_models/test_mixins.py b/tests/test_models/test_mixins.py index 099040720..eb164aff4 100644 --- a/tests/test_models/test_mixins.py +++ b/tests/test_models/test_mixins.py @@ -1,18 +1,21 @@ import json import pathlib +import pickle +from copy import deepcopy from unittest.mock import MagicMock from unittest.mock import patch from zipfile import ZipFile -import dill import numpy as np import pytest from etna import SETTINGS from etna.datasets import TSDataset +from etna.models.nn.mlp import MLPNet if SETTINGS.torch_required: import torch + from pytorch_lightning import Trainer import pandas as pd @@ -23,7 +26,8 @@ from etna.models.mixins import PerSegmentModelMixin from etna.models.mixins import PredictionIntervalContextIgnorantModelMixin from etna.models.mixins import PredictionIntervalContextRequiredModelMixin -from etna.models.mixins import SaveNNMixin +from etna.models.mixins import SaveDeepBaseModelMixin +from etna.models.mixins import SavePytorchForecastingMixin class DummyPredictAdapter(BaseAdapter): @@ -145,55 +149,177 @@ def test_calling_private_prediction( ) -class DummyNN(SaveNNMixin): - def __init__(self, a, b): - self.a = torch.tensor(a) - self.b = torch.tensor(b) +class DummyDeepBaseModel(SaveDeepBaseModelMixin): + def __init__(self, size: int): + self.size = size + self.net = MLPNet(input_size=size, hidden_size=[size], lr=0.01, loss=torch.nn.MSELoss(), optimizer_params=None) + self.trainer = Trainer() -def test_save_nn_mixin_save(tmp_path): - dummy = DummyNN(a=1, b=2) +class DummyPytorchForecastingModel(SavePytorchForecastingMixin): + def __init__(self, size: int, init_model: bool): + self.size = size + self.init_model = init_model + if init_model: + self.model = MLPNet( + input_size=size, hidden_size=[size], lr=0.01, loss=torch.nn.MSELoss(), optimizer_params=None + ) + else: 
+ self.model = None + self.trainer = Trainer() + + +def test_save_native_mixin_save(tmp_path): + dummy = DummyDeepBaseModel(size=1) dir_path = pathlib.Path(tmp_path) path = dir_path.joinpath("dummy.zip") + initial_dummy = deepcopy(dummy) dummy.save(path) with ZipFile(path, "r") as zip_file: files = zip_file.namelist() - assert sorted(files) == ["metadata.json", "object.pt"] + assert sorted(files) == ["metadata.json", "net.pt", "object.pkl"] with zip_file.open("metadata.json", "r") as input_file: metadata_bytes = input_file.read() metadata_str = metadata_bytes.decode("utf-8") metadata = json.loads(metadata_str) assert sorted(metadata.keys()) == ["class", "etna_version"] - assert metadata["class"] == "tests.test_models.test_mixins.DummyNN" + assert metadata["class"] == "tests.test_models.test_mixins.DummyDeepBaseModel" - with zip_file.open("object.pt", "r") as input_file: - loaded_dummy = torch.load(input_file, pickle_module=dill) - assert loaded_dummy.a == dummy.a - assert loaded_dummy.b == dummy.b + with zip_file.open("object.pkl", "r") as input_file: + loaded_obj = pickle.load(input_file) + assert loaded_obj.size == dummy.size + # basic check that we didn't break dummy object itself + assert dummy.size == initial_dummy.size + assert isinstance(dummy.net, MLPNet) + assert isinstance(dummy.trainer, Trainer) -def test_save_mixin_load_ok(recwarn, tmp_path): - dummy = DummyNN(a=1, b=2) + +def test_save_pf_mixin_save_without_model(tmp_path): + dummy = DummyPytorchForecastingModel(size=1, init_model=False) dir_path = pathlib.Path(tmp_path) path = dir_path.joinpath("dummy.zip") + initial_dummy = deepcopy(dummy) dummy.save(path) - loaded_dummy = DummyNN.load(path) - assert loaded_dummy.a == dummy.a - assert loaded_dummy.b == dummy.b + with ZipFile(path, "r") as zip_file: + files = zip_file.namelist() + assert sorted(files) == ["metadata.json", "object.pkl"] + + with zip_file.open("metadata.json", "r") as input_file: + metadata_bytes = input_file.read() + metadata_str = 
metadata_bytes.decode("utf-8") + metadata = json.loads(metadata_str) + assert sorted(metadata.keys()) == ["class", "etna_version"] + assert metadata["class"] == "tests.test_models.test_mixins.DummyPytorchForecastingModel" + + with zip_file.open("object.pkl", "r") as input_file: + loaded_obj = pickle.load(input_file) + assert loaded_obj.size == dummy.size + + # basic check that we didn't break dummy object itself + assert dummy.size == initial_dummy.size + assert dummy.model is None + assert isinstance(dummy.trainer, Trainer) + + +def test_save_pf_mixin_save_with_model(tmp_path): + dummy = DummyPytorchForecastingModel(size=1, init_model=True) + dir_path = pathlib.Path(tmp_path) + path = dir_path.joinpath("dummy.zip") + + initial_dummy = deepcopy(dummy) + dummy.save(path) + + with ZipFile(path, "r") as zip_file: + files = zip_file.namelist() + assert sorted(files) == ["metadata.json", "model.pt", "object.pkl"] + + with zip_file.open("metadata.json", "r") as input_file: + metadata_bytes = input_file.read() + metadata_str = metadata_bytes.decode("utf-8") + metadata = json.loads(metadata_str) + assert sorted(metadata.keys()) == ["class", "etna_version"] + assert metadata["class"] == "tests.test_models.test_mixins.DummyPytorchForecastingModel" + + with zip_file.open("object.pkl", "r") as input_file: + loaded_obj = pickle.load(input_file) + assert loaded_obj.size == dummy.size + + # basic check that we didn't break dummy object itself + assert dummy.size == initial_dummy.size + assert isinstance(dummy.model, MLPNet) + assert isinstance(dummy.trainer, Trainer) + + +@pytest.mark.parametrize("cls", [DummyDeepBaseModel, DummyPytorchForecastingModel]) +def test_save_mixin_load_fail_file_not_found(cls): + non_existent_path = pathlib.Path("archive.zip") + with pytest.raises(FileNotFoundError): + cls.load(non_existent_path) + + +def test_save_native_mixin_load_ok(recwarn, tmp_path): + dummy = DummyDeepBaseModel(size=1) + dir_path = pathlib.Path(tmp_path) + path = 
dir_path.joinpath("dummy.zip") + + dummy.save(path) + loaded_dummy = DummyDeepBaseModel.load(path) + + assert loaded_dummy.size == dummy.size + assert isinstance(loaded_dummy.net, MLPNet) + assert loaded_dummy.trainer is None + # one false positive warning + assert len(recwarn) == 1 + + +def test_save_pf_mixin_without_model_load_ok(recwarn, tmp_path): + dummy = DummyPytorchForecastingModel(size=1, init_model=False) + dir_path = pathlib.Path(tmp_path) + path = dir_path.joinpath("dummy.zip") + + dummy.save(path) + loaded_dummy = DummyPytorchForecastingModel.load(path) + + assert loaded_dummy.size == dummy.size + assert loaded_dummy.model is None + assert loaded_dummy.trainer is None assert len(recwarn) == 0 +def test_save_pf_mixin_with_model_load_ok(recwarn, tmp_path): + dummy = DummyPytorchForecastingModel(size=1, init_model=True) + dir_path = pathlib.Path(tmp_path) + path = dir_path.joinpath("dummy.zip") + + dummy.save(path) + loaded_dummy = DummyPytorchForecastingModel.load(path) + + assert loaded_dummy.size == dummy.size + assert isinstance(loaded_dummy.model, MLPNet) + assert loaded_dummy.trainer is None + # one false positive warning + assert len(recwarn) == 1 + + +@pytest.mark.parametrize( + "dummy", + [ + DummyDeepBaseModel(size=1), + DummyPytorchForecastingModel(size=1, init_model=False), + DummyPytorchForecastingModel(size=1, init_model=True), + ], +) @pytest.mark.parametrize( "save_version, load_version", [((1, 5, 0), (2, 5, 0)), ((2, 5, 0), (1, 5, 0)), ((1, 5, 0), (1, 3, 0))] ) @patch("etna.core.mixins.get_etna_version") -def test_save_mixin_load_warning(get_version_mock, save_version, load_version, tmp_path): - dummy = DummyNN(a=1, b=2) +def test_save_mixin_load_warning(get_version_mock, save_version, load_version, dummy, tmp_path): dir_path = pathlib.Path(tmp_path) path = dir_path.joinpath("dummy.zip") @@ -207,7 +333,7 @@ def test_save_mixin_load_warning(get_version_mock, save_version, load_version, t match=f"The object was saved under etna version 
{save_version_str} but running version is {load_version_str}", ): get_version_mock.return_value = load_version - _ = DummyNN.load(path) + _ = dummy.load(path) @pytest.mark.parametrize( diff --git a/tests/test_models/nn/conftest.py b/tests/test_models/test_nn/conftest.py similarity index 100% rename from tests/test_models/nn/conftest.py rename to tests/test_models/test_nn/conftest.py diff --git a/tests/test_models/nn/deepstate/test_lds.py b/tests/test_models/test_nn/deepstate/test_lds.py similarity index 100% rename from tests/test_models/nn/deepstate/test_lds.py rename to tests/test_models/test_nn/deepstate/test_lds.py diff --git a/tests/test_models/nn/deepstate/test_ssm.py b/tests/test_models/test_nn/deepstate/test_ssm.py similarity index 100% rename from tests/test_models/nn/deepstate/test_ssm.py rename to tests/test_models/test_nn/deepstate/test_ssm.py diff --git a/tests/test_models/nn/nbeats/test_blocks.py b/tests/test_models/test_nn/nbeats/test_blocks.py similarity index 100% rename from tests/test_models/nn/nbeats/test_blocks.py rename to tests/test_models/test_nn/nbeats/test_blocks.py diff --git a/tests/test_models/nn/nbeats/test_nbeats.py b/tests/test_models/test_nn/nbeats/test_nbeats.py similarity index 100% rename from tests/test_models/nn/nbeats/test_nbeats.py rename to tests/test_models/test_nn/nbeats/test_nbeats.py diff --git a/tests/test_models/nn/nbeats/test_nbeats_metrics.py b/tests/test_models/test_nn/nbeats/test_nbeats_metrics.py similarity index 100% rename from tests/test_models/nn/nbeats/test_nbeats_metrics.py rename to tests/test_models/test_nn/nbeats/test_nbeats_metrics.py diff --git a/tests/test_models/nn/nbeats/test_nbeats_nets.py b/tests/test_models/test_nn/nbeats/test_nbeats_nets.py similarity index 100% rename from tests/test_models/nn/nbeats/test_nbeats_nets.py rename to tests/test_models/test_nn/nbeats/test_nbeats_nets.py diff --git a/tests/test_models/nn/nbeats/test_nbeats_utils.py 
b/tests/test_models/test_nn/nbeats/test_nbeats_utils.py similarity index 100% rename from tests/test_models/nn/nbeats/test_nbeats_utils.py rename to tests/test_models/test_nn/nbeats/test_nbeats_utils.py diff --git a/tests/test_models/nn/test_deepar.py b/tests/test_models/test_nn/test_deepar.py similarity index 100% rename from tests/test_models/nn/test_deepar.py rename to tests/test_models/test_nn/test_deepar.py diff --git a/tests/test_models/nn/test_deepstate.py b/tests/test_models/test_nn/test_deepstate.py similarity index 76% rename from tests/test_models/nn/test_deepstate.py rename to tests/test_models/test_nn/test_deepstate.py index 12f3e10ca..da7d0dcf8 100644 --- a/tests/test_models/nn/test_deepstate.py +++ b/tests/test_models/test_nn/test_deepstate.py @@ -5,6 +5,7 @@ from etna.models.nn.deepstate import CompositeSSM from etna.models.nn.deepstate import WeeklySeasonalitySSM from etna.transforms import StandardScalerTransform +from tests.test_models.utils import assert_model_equals_loaded_original @pytest.mark.parametrize( @@ -43,3 +44,14 @@ def test_deepstate_model_run_weekly_overfit_with_scaler(ts_dataset_weekly_functi mae = MAE("macro") assert mae(ts_test, future) < 0.001 + + +def test_save_load(example_tsds): + model = DeepStateModel( + ssm=CompositeSSM(seasonal_ssms=[WeeklySeasonalitySSM()], nonseasonal_ssm=None), + input_size=0, + encoder_length=14, + decoder_length=14, + trainer_params=dict(max_epochs=1), + ) + assert_model_equals_loaded_original(model=model, ts=example_tsds, transforms=[], horizon=3) diff --git a/tests/test_models/nn/test_mlp.py b/tests/test_models/test_nn/test_mlp.py similarity index 100% rename from tests/test_models/nn/test_mlp.py rename to tests/test_models/test_nn/test_mlp.py diff --git a/tests/test_models/nn/test_patchts.py b/tests/test_models/test_nn/test_patchts.py similarity index 91% rename from tests/test_models/nn/test_patchts.py rename to tests/test_models/test_nn/test_patchts.py index b7efb05a2..312d0b37d 100644 --- 
a/tests/test_models/nn/test_patchts.py +++ b/tests/test_models/test_nn/test_patchts.py @@ -7,6 +7,7 @@ from etna.models.nn import PatchTSModel from etna.models.nn.patchts import PatchTSNet from etna.transforms import StandardScalerTransform +from tests.test_models.utils import assert_model_equals_loaded_original from tests.test_models.utils import assert_sampling_is_valid @@ -76,6 +77,11 @@ def test_patchts_make_samples(example_df): np.testing.assert_equal(example_df[["target"]].iloc[1 : encoder_length + 1], second_sample["encoder_real"]) +def test_save_load(example_tsds): + model = PatchTSModel(encoder_length=14, decoder_length=14, trainer_params=dict(max_epochs=1)) + assert_model_equals_loaded_original(model=model, ts=example_tsds, transforms=[], horizon=3) + + def test_params_to_tune(example_tsds): ts = example_tsds model = PatchTSModel(encoder_length=14, decoder_length=14, trainer_params=dict(max_epochs=1)) diff --git a/tests/test_models/nn/test_rnn.py b/tests/test_models/test_nn/test_rnn.py similarity index 100% rename from tests/test_models/nn/test_rnn.py rename to tests/test_models/test_nn/test_rnn.py diff --git a/tests/test_models/nn/test_tft.py b/tests/test_models/test_nn/test_tft.py similarity index 100% rename from tests/test_models/nn/test_tft.py rename to tests/test_models/test_nn/test_tft.py diff --git a/tests/test_models/nn/test_utils.py b/tests/test_models/test_nn/test_utils.py similarity index 100% rename from tests/test_models/nn/test_utils.py rename to tests/test_models/test_nn/test_utils.py From 7e01b43140aa24b1c3b68fd8ca88f463c3f6a3d9 Mon Sep 17 00:00:00 2001 From: Dmitry Bunin Date: Wed, 4 Oct 2023 16:07:13 +0300 Subject: [PATCH 2/6] chore: update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2d0f83d42..113ec73cd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Rework `get_started` 
notebook ([#1343](https://github.com/tinkoff-ai/etna/pull/1343)) - Add missing classes from decomposition into API Reference, add modules into page titles in API Reference ([#61](https://github.com/etna-team/etna/pull/61)) - Update `CONTRIBUTING.md` with scenarios of documentation updates and release instruction ([#77](https://github.com/etna-team/etna/pull/77)) +- Rework saving DL models by separating saving model's hyperparameters and model's weights ([#98](https://github.com/etna-team/etna/pull/98)) ### Fixed - Fix `ResampleWithDistributionTransform` working with categorical columns ([#82](https://github.com/etna-team/etna/pull/82)) From a0e0cd36b260e1c311b587bc7e192973e1499bd1 Mon Sep 17 00:00:00 2001 From: Dmitry Bunin Date: Fri, 6 Oct 2023 15:49:35 +0300 Subject: [PATCH 3/6] fix: fix poetry.lock --- poetry.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/poetry.lock b/poetry.lock index c7b417746..cec118353 100644 --- a/poetry.lock +++ b/poetry.lock @@ -6275,4 +6275,4 @@ wandb = ["wandb"] [metadata] lock-version = "2.0" python-versions = ">=3.8.0, <3.11.0" -content-hash = "a7d886473db020164af0b6c357cf2f70d09ed46dba1a30c2e688ba5bd086a307" +content-hash = "1c02972c5a9dfe4446907a8af1105eaee92d99b94cd9b3f4138406cd86c335fb" From 4765546410d9a4a934ee57762cd3bc1f7287e6d1 Mon Sep 17 00:00:00 2001 From: Dmitry Bunin Date: Fri, 6 Oct 2023 15:58:07 +0300 Subject: [PATCH 4/6] docs: add docs for SaveMixin._save method --- etna/core/mixins.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/etna/core/mixins.py b/etna/core/mixins.py index f2a44fdf6..a18589be5 100644 --- a/etna/core/mixins.py +++ b/etna/core/mixins.py @@ -246,6 +246,17 @@ def _save_state(self, archive: zipfile.ZipFile, skip_attributes: Sequence[str] = setattr(self, attr, value) def _save(self, path: pathlib.Path, skip_attributes: Sequence[str] = ()): + """Save the object with more options. + + This method is intended to be used to implement the ``save`` method during inheritance.
+ + Parameters + ---------- + path: + Path to save object to. + skip_attributes: + Attributes to be skipped during saving state. These attributes are intended to be saved manually. + """ with zipfile.ZipFile(path, "w") as archive: self._save_metadata(archive) self._save_state(archive, skip_attributes=skip_attributes) From b9ff5b72a99557eaaf48dfbae4c3f1f4a701e92d Mon Sep 17 00:00:00 2001 From: Dmitry Bunin Date: Fri, 6 Oct 2023 16:36:18 +0300 Subject: [PATCH 5/6] fix: remove _load_object because it isn't needed --- etna/models/mixins.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/etna/models/mixins.py b/etna/models/mixins.py index 0d45e8268..2e9aacd7d 100644 --- a/etna/models/mixins.py +++ b/etna/models/mixins.py @@ -638,12 +638,6 @@ def get_model(self) -> Any: return self._base_model.get_model() -def _load_object(class_name, class_parameters): - cls = get_factory(class_name) - obj = cls(**class_parameters) - return obj - - def _save_pl_model(archive: zipfile.ZipFile, filename: str, model: "LightningModule"): with archive.open(filename, "w", force_zip64=True) as output_file: to_save = { @@ -658,9 +652,8 @@ def _load_pl_model(archive: zipfile.ZipFile, filename: str) -> "LightningModule" with archive.open(filename, "r") as input_file: net_loaded = torch.load(input_file, pickle_module=dill) - # fixes the [issue](https://github.com/Lightning-AI/lightning/issues/18405) with `save_hyperparameters` - net = _load_object(class_name=net_loaded["class"], class_parameters=net_loaded["hyperparameters"]) - + cls = get_factory(net_loaded["class"]) + net = cls(**net_loaded["hyperparameters"]) net.load_state_dict(net_loaded["state_dict"]) return net From 45b26990f22c4d5780381e65ffa3a01d5d0e3116 Mon Sep 17 00:00:00 2001 From: Dmitry Bunin Date: Mon, 9 Oct 2023 11:15:21 +0300 Subject: [PATCH 6/6] docs: add info about save_hyperparameters --- etna/models/base.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/etna/models/base.py 
b/etna/models/base.py index 67195b1cb..f648a6810 100644 --- a/etna/models/base.py +++ b/etna/models/base.py @@ -429,7 +429,11 @@ def get_model(self) -> "DeepBaseNet": class DeepBaseNet(DeepAbstractNet, LightningModule): - """Class for partially implemented LightningModule interface.""" + """Class for partially implemented LightningModule interface. + + During inheritance don't forget to add ``self.save_hyperparameters()`` to the ``__init__``. + Otherwise, methods ``save`` and ``load`` won't work properly for your implementation of :py:class:`~etna.models.base.DeepBaseModel`. + """ def __init__(self): """Init DeepBaseNet."""