Rework saving for DL models #98

Merged
merged 9 commits into from
Oct 10, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -34,6 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Add missing classes from decomposition into API Reference, add modules into page titles in API Reference ([#61](https://github.com/etna-team/etna/pull/61))
 - Update `CONTRIBUTING.md` with scenarios of documentation updates and release instruction ([#77](https://github.com/etna-team/etna/pull/77))
 - Set up sharding for running tests ([#99](https://github.com/etna-team/etna/pull/99))
+- Rework saving DL models by separating saving model's hyperparameters and model's weights ([#98](https://github.com/etna-team/etna/pull/98))

 ### Fixed
 - Fix `ResampleWithDistributionTransform` working with categorical columns ([#82](https://github.com/etna-team/etna/pull/82))
49 changes: 39 additions & 10 deletions etna/core/mixins.py
@@ -27,8 +27,8 @@
         """Get default representation of etna object."""
         # TODO: add tests default behaviour for all registered objects
         args_str_representation = ""
-        init_args = inspect.signature(self.__init__).parameters
-        for arg, param in init_args.items():
+        init_parameters = self._get_init_parameters()
+        for arg, param in init_parameters.items():
             if param.kind == param.VAR_POSITIONAL:
                 continue
             elif param.kind == param.VAR_KEYWORD:
@@ -43,6 +43,9 @@
                 args_str_representation += f"{arg} = {repr(value)}, "
         return f"{self.__class__.__name__}({args_str_representation})"

+    def _get_init_parameters(self):

Collaborator Author:
Added for simplification.

+        return inspect.signature(self.__init__).parameters
+
     @staticmethod
     def _get_target_from_class(value: Any):
         if value is None:
@@ -84,9 +87,9 @@

     def to_dict(self):
         """Collect all information about etna object in dict."""
-        init_args = inspect.signature(self.__init__).parameters
+        init_parameters = self._get_init_parameters()
         params = {}
-        for arg in init_args.keys():
+        for arg in init_parameters.keys():
             value = self.__dict__[arg]
             if value is None:
                 continue
@@ -226,21 +229,47 @@
         with archive.open("metadata.json", "w") as output_file:
             output_file.write(metadata_bytes)

-    def _save_state(self, archive: zipfile.ZipFile):
-        with archive.open("object.pkl", "w", force_zip64=True) as output_file:
-            dill.dump(self, output_file)
+    def _save_state(self, archive: zipfile.ZipFile, skip_attributes: Sequence[str] = ()):

Collaborator Author:
I tried to rework _save_state to make other mixins simpler.

+        saved_attributes = {}
+        try:
+            # remove attributes we can't easily save
+            saved_attributes = {attr: getattr(self, attr) for attr in skip_attributes}
+            for attr in skip_attributes:
+                delattr(self, attr)

-    def save(self, path: pathlib.Path):
-        """Save the object.
+            # save the remaining part
+            with archive.open("object.pkl", "w", force_zip64=True) as output_file:
+                dill.dump(self, output_file)
+        finally:
+            # restore the skipped attributes
+            for attr, value in saved_attributes.items():
+                setattr(self, attr, value)
+
+    def _save(self, path: pathlib.Path, skip_attributes: Sequence[str] = ()):
+        """Save the object with more options.
+
+        This method is intended to use to implement ``save`` method during inheritance.

         Parameters
         ----------
         path:
             Path to save object to.
+        skip_attributes:
+            Attributes to be skipped during saving state. These attributes are intended to be saved manually.
         """
         with zipfile.ZipFile(path, "w") as archive:
             self._save_metadata(archive)
-            self._save_state(archive)
+            self._save_state(archive, skip_attributes=skip_attributes)

+    def save(self, path: pathlib.Path):
+        """Save the object.
+
+        Parameters
+        ----------
+        path:
+            Path to save object to.
+        """
+        self._save(path=path)
+
     @classmethod
     def _load_metadata(cls, archive: zipfile.ZipFile) -> Dict[str, Any]:
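The reworked `_save_state` above detaches the skipped attributes, pickles what is left, and reattaches them in a `finally` block. The following is a minimal, standard-library sketch of that pattern, not the etna implementation. `ToySaveable`, `ToyModel`, and `save_and_reload` are hypothetical names, `pickle` stands in for `dill`, and the toy dumps the attribute dictionary rather than the object itself.

```python
import io
import pickle
import threading
import zipfile
from typing import Sequence


class ToySaveable:
    """Hypothetical stand-in for a saveable mixin; not the etna class."""

    def _save_state(self, archive: zipfile.ZipFile, skip_attributes: Sequence[str] = ()):
        saved_attributes = {}
        try:
            # Detach the attributes that must not end up in the pickle.
            saved_attributes = {attr: getattr(self, attr) for attr in skip_attributes}
            for attr in skip_attributes:
                delattr(self, attr)
            # Pickle what is left (the toy dumps the attribute dict, not the object).
            with archive.open("object.pkl", "w") as output_file:
                pickle.dump(dict(vars(self)), output_file)
        finally:
            # Reattach the attributes even if pickling raised.
            for attr, value in saved_attributes.items():
                setattr(self, attr, value)


class ToyModel(ToySaveable):
    def __init__(self):
        self.horizon = 7
        self.lock = threading.Lock()  # locks cannot be pickled


def save_and_reload(model: ToyModel) -> dict:
    """Save into an in-memory zip and read the pickled state back."""
    buffer = io.BytesIO()
    with zipfile.ZipFile(buffer, "w") as archive:
        model._save_state(archive, skip_attributes=["lock"])
    with zipfile.ZipFile(buffer, "r") as archive:
        with archive.open("object.pkl", "r") as input_file:
            return pickle.load(input_file)
```

Because the restore step sits in `finally`, the in-memory object should keep its skipped attributes whether or not the dump succeeds, which is what lets subclasses save those attributes by other means afterwards.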
15 changes: 2 additions & 13 deletions etna/ensembles/mixins.py
@@ -85,18 +85,7 @@
         self.pipelines: List[BasePipeline]
         self.ts: Optional[TSDataset]

-        pipelines = self.pipelines
-        ts = self.ts
-        try:
-            # extract attributes we can't easily save
-            delattr(self, "pipelines")
-            delattr(self, "ts")
-
-            # save the remaining part
-            super().save(path=path)
-        finally:
-            self.pipelines = pipelines
-            self.ts = ts
+        self._save(path=path, skip_attributes=["pipelines", "ts"])

         with zipfile.ZipFile(path, "a") as archive:
             with tempfile.TemporaryDirectory() as _temp_dir:
@@ -106,7 +95,7 @@
                 pipelines_dir = temp_dir / "pipelines"
                 pipelines_dir.mkdir()
                 num_digits = 8
-                for i, pipeline in enumerate(pipelines):
+                for i, pipeline in enumerate(self.pipelines):
                     save_name = f"{i:0{num_digits}d}.zip"
                     pipeline_save_path = pipelines_dir / save_name
                     pipeline.save(pipeline_save_path)
14 changes: 2 additions & 12 deletions etna/experimental/prediction_intervals/mixins.py
@@ -33,25 +33,15 @@ def save(self, path: pathlib.Path):
         """
         self.pipeline: BasePipeline

-        pipeline = self.pipeline
-
-        try:
-            # extract pipeline to save it with its own method later
-            delattr(self, "pipeline")
-
-            # save the remaining part
-            super().save(path=path)
-
-        finally:
-            self.pipeline = pipeline
+        self._save(path=path, skip_attributes=["pipeline"])

         with zipfile.ZipFile(path, "a") as archive:
             with tempfile.TemporaryDirectory() as _temp_dir:
                 temp_dir = pathlib.Path(_temp_dir)

                 # save pipeline separately and add to the archive
                 pipeline_save_path = temp_dir / "pipeline.zip"
-                pipeline.save(path=pipeline_save_path)
+                self.pipeline.save(path=pipeline_save_path)

                 archive.write(pipeline_save_path, "pipeline.zip")

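Both mixins above follow the same two-phase shape: write the main archive first, then reopen it in append mode and add each separately saved part as a nested zip. A rough standard-library sketch of that shape follows. `save_with_parts` and `read_part` are hypothetical helpers, and a placeholder `metadata.json` stands in for whatever `_save` writes.

```python
import io
import pathlib
import tempfile
import zipfile
from typing import Dict


def save_with_parts(path: pathlib.Path, parts: Dict[str, str]) -> None:
    # Phase 1: create the archive with the main payload (stands in for ``_save``).
    with zipfile.ZipFile(path, "w") as archive:
        archive.writestr("metadata.json", "{}")

    # Phase 2: reopen in append mode and add each separately saved part.
    with zipfile.ZipFile(path, "a") as archive:
        with tempfile.TemporaryDirectory() as _temp_dir:
            temp_dir = pathlib.Path(_temp_dir)
            for name, payload in parts.items():
                part_path = temp_dir / name
                # Each part is itself a small zip, like a pipeline saved by its own ``save``.
                with zipfile.ZipFile(part_path, "w") as part_archive:
                    part_archive.writestr("payload.txt", payload)
                archive.write(part_path, name)


def read_part(path: pathlib.Path, name: str, inner_name: str) -> bytes:
    # Read a file out of a nested zip without extracting it to disk.
    with zipfile.ZipFile(path, "r") as archive:
        with zipfile.ZipFile(io.BytesIO(archive.read(name))) as nested:
            return nested.read(inner_name)
```

Opening with mode `"a"` keeps the entries written in the first phase, so the nested parts can be added without rewriting the main payload.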
10 changes: 7 additions & 3 deletions etna/models/base.py
@@ -20,7 +20,7 @@
 from etna.distributions import BaseDistribution
 from etna.loggers import tslogger
 from etna.models.decorators import log_decorator
-from etna.models.mixins import SaveNNMixin
+from etna.models.mixins import SaveDeepBaseModelMixin

 if SETTINGS.torch_required:
     import torch
@@ -429,7 +429,11 @@ def get_model(self) -> "DeepBaseNet":


 class DeepBaseNet(DeepAbstractNet, LightningModule):
-    """Class for partially implemented LightningModule interface."""
+    """Class for partially implemented LightningModule interface.
+
+    During inheritance don't forget to add ``self.save_hyperparameters()`` to the ``__init__``.
+    Otherwise, methods ``save`` and ``load`` won't work properly for your implementation of :py:class:`~etna.models.base.DeepBaseModel`.
+    """

     def __init__(self):
         """Init DeepBaseNet."""
@@ -470,7 +474,7 @@ def validation_step(self, batch: dict, *args, **kwargs): # type: ignore
         return loss


-class DeepBaseModel(DeepBaseAbstractModel, SaveNNMixin, NonPredictionIntervalContextRequiredAbstractModel):
+class DeepBaseModel(DeepBaseAbstractModel, SaveDeepBaseModelMixin, NonPredictionIntervalContextRequiredAbstractModel):
     """Class for partially implemented interfaces for holding deep models."""

     def __init__(
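The new `DeepBaseNet` docstring asks subclasses to call `self.save_hyperparameters()` in `__init__`. A toy, torch-free sketch of why that matters for loading is below. `ToyNet`, `GoodNet`, `ForgetfulNet`, and `rebuild` are hypothetical, and the frame inspection only imitates the idea of recording constructor arguments; it is not Lightning's implementation.

```python
import inspect


class ToyNet:
    """Hypothetical stand-in for a Lightning-style module; not the real API."""

    def __init__(self):
        self.hparams = {}

    def save_hyperparameters(self):
        # Read the arguments of the calling ``__init__`` from its frame.
        frame = inspect.currentframe().f_back
        init_args = inspect.signature(type(self).__init__).parameters
        self.hparams = {name: frame.f_locals[name] for name in init_args if name != "self"}


class GoodNet(ToyNet):
    def __init__(self, input_size: int, lr: float = 1e-3):
        super().__init__()
        self.save_hyperparameters()


class ForgetfulNet(ToyNet):
    def __init__(self, input_size: int, lr: float = 1e-3):
        super().__init__()  # hyperparameters are never recorded


def rebuild(net: ToyNet) -> ToyNet:
    # What a loader can do when it only knows the class and the recorded hyperparameters.
    return type(net)(**net.hparams)
```

In this sketch `rebuild(GoodNet(4))` succeeds, while `rebuild(ForgetfulNet(4))` raises `TypeError` because the required `input_size` was never recorded. A loader that reconstructs the network from stored hyperparameters would presumably hit the same kind of failure.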
153 changes: 141 additions & 12 deletions etna/models/mixins.py
@@ -1,3 +1,4 @@
+import pathlib
 import zipfile
 from abc import ABC
 from abc import abstractmethod
@@ -11,13 +12,21 @@
 import dill
 import numpy as np
 import pandas as pd
+from hydra_slayer import get_factory
 from typing_extensions import Self

+from etna import SETTINGS
+from etna.core.mixins import BaseMixin
 from etna.core.mixins import SaveMixin
 from etna.datasets.tsdataset import TSDataset
 from etna.datasets.utils import match_target_quantiles
 from etna.models.decorators import log_decorator

+if SETTINGS.torch_required:
+    import torch
+    from pytorch_lightning import LightningModule
+    from pytorch_lightning import Trainer
+

 class ModelForecastingMixin(ABC):
     """Base class for model mixins."""
@@ -640,25 +649,145 @@
         return self._base_model.get_model()


-class SaveNNMixin(SaveMixin):
-    """Implementation of ``AbstractSaveable`` torch related classes.
+def _save_pl_model(archive: zipfile.ZipFile, filename: str, model: "LightningModule"):
+    with archive.open(filename, "w", force_zip64=True) as output_file:
+        to_save = {
+            "class": BaseMixin._get_target_from_class(model),

Collaborator Author:
This saving can potentially be improved.

+            "hyperparameters": dict(model.hparams),
+            "state_dict": model.state_dict(),
+        }
+        torch.save(to_save, output_file, pickle_module=dill)
+
+
+def _load_pl_model(archive: zipfile.ZipFile, filename: str) -> "LightningModule":
+    with archive.open(filename, "r") as input_file:
+        net_loaded = torch.load(input_file, pickle_module=dill)
+
+    cls = get_factory(net_loaded["class"])
+    net = cls(**net_loaded["hyperparameters"])
+    net.load_state_dict(net_loaded["state_dict"])
+
+    return net
+
+
+class SaveDeepBaseModelMixin(SaveMixin):
+    """Implementation of ``AbstractSaveable`` for :py:class:`~etna.models.base.DeepBaseModel` models.
+
+    It saves object to the zip archive with files:
+
+    * metadata.json: contains library version and class name.
+
+    * object.pkl: pickled without ``self.net`` and ``self.trainer``.
+
+    * net.pt: parameters of ``self.net`` saved by ``torch.save``.
+    """
+
+    def save(self, path: pathlib.Path):
+        """Save the object.
+
+        Parameters
+        ----------
+        path:
+            Path to save object to.
+        """
+        from etna.models.base import DeepBaseNet
+
+        self.trainer: Optional[Trainer]
+        self.net: DeepBaseNet
+
+        self._save(path=path, skip_attributes=["trainer", "net"])
+
+        with zipfile.ZipFile(path, "a") as archive:
+            _save_pl_model(archive=archive, filename="net.pt", model=self.net)
+
+    @classmethod
+    def load(cls, path: pathlib.Path, ts: Optional[TSDataset] = None) -> Self:
+        """Load an object.
+
+        Warning
+        -------
+        This method uses :py:mod:`dill` module which is not secure.
+        It is possible to construct malicious data which will execute arbitrary code during loading.
+        Never load data that could have come from an untrusted source, or that could have been tampered with.
+
+        Parameters
+        ----------
+        path:
+            Path to load object from.
+        ts:
+            TSDataset to set into loaded pipeline.
+
+        Returns
+        -------
+        :
+            Loaded object.
+        """
+        obj = super().load(path=path)
+
+        with zipfile.ZipFile(path, "r") as archive:
+            obj.net = _load_pl_model(archive=archive, filename="net.pt")
+            obj.trainer = None
+
+        return obj
+
+

-    It saves object to the zip archive with 2 files:
+class SavePytorchForecastingMixin(SaveMixin):
+    """Implementation of ``AbstractSaveable`` for :py:mod:`pytorch_forecasting` models.
+
+    It saves object to the zip archive with files:

     * metadata.json: contains library version and class name.

-    * object.pt: object saved by ``torch.save``.
+    * object.pkl: pickled without ``self.model`` and ``self.trainer``.
+
+    * model.pt: parameters of ``self.model`` saved by ``torch.save`` if model was fitted.
     """

-    def _save_state(self, archive: zipfile.ZipFile):
-        import torch
+    def save(self, path: pathlib.Path):
+        """Save the object.
+
+        Parameters
+        ----------
+        path:
+            Path to save object to.
+        """
+        self.trainer: Optional[Trainer]
+        self.model: Optional[LightningModule]

-        with archive.open("object.pt", "w", force_zip64=True) as output_file:
-            torch.save(self, output_file, pickle_module=dill)
+        if self.model is None:
+            self._save(path=path, skip_attributes=["trainer"])
+        else:
+            self._save(path=path, skip_attributes=["trainer", "model"])
+            with zipfile.ZipFile(path, "a") as archive:
+                _save_pl_model(archive=archive, filename="model.pt", model=self.model)

     @classmethod
-    def _load_state(cls, archive: zipfile.ZipFile) -> Self:
-        import torch
+    def load(cls, path: pathlib.Path, ts: Optional[TSDataset] = None) -> Self:
+        """Load an object.
+
+        Warning
+        -------
+        This method uses :py:mod:`dill` module which is not secure.
+        It is possible to construct malicious data which will execute arbitrary code during loading.
+        Never load data that could have come from an untrusted source, or that could have been tampered with.
+
+        Parameters
+        ----------
+        path:
+            Path to load object from.
+        ts:
+            TSDataset to set into loaded pipeline.
+
+        Returns
+        -------
+        :
+            Loaded object.
+        """
+        obj = super().load(path=path)
+        obj.trainer = None

+        if not hasattr(obj, "model"):
+            with zipfile.ZipFile(path, "r") as archive:
+                obj.model = _load_pl_model(archive=archive, filename="model.pt")
+
-        with archive.open("object.pt", "r") as input_file:
-            return torch.load(input_file, pickle_module=dill)
+        return obj
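The `_save_pl_model` and `_load_pl_model` helpers in this diff store three things for a network: a class reference, its hyperparameters, and its state dict. On load they rebuild the network from the first two and restore the third. Below is a torch-free sketch of the same round trip, under loud assumptions: `ToyNet`, `REGISTRY`, `save_net`, and `load_net` are hypothetical, `pickle` replaces `torch.save` with `dill`, and a plain registry lookup replaces the dotted-path resolution that `get_factory` performs in the diff.

```python
import io
import pickle
import zipfile


class ToyNet:
    """Hypothetical stand-in for a network with hyperparameters and weights."""

    def __init__(self, input_size: int, hidden_size: int = 2):
        self.hparams = {"input_size": input_size, "hidden_size": hidden_size}
        self.weights = [0.0] * (input_size * hidden_size)

    def state_dict(self) -> dict:
        return {"weights": list(self.weights)}

    def load_state_dict(self, state: dict) -> None:
        self.weights = list(state["weights"])


# Assumption: a plain registry replaces dotted-path class resolution.
REGISTRY = {"ToyNet": ToyNet}


def save_net(archive: zipfile.ZipFile, filename: str, net: ToyNet) -> None:
    to_save = {
        "class": type(net).__name__,
        "hyperparameters": dict(net.hparams),
        "state_dict": net.state_dict(),
    }
    with archive.open(filename, "w") as output_file:
        pickle.dump(to_save, output_file)


def load_net(archive: zipfile.ZipFile, filename: str) -> ToyNet:
    with archive.open(filename, "r") as input_file:
        loaded = pickle.load(input_file)
    cls = REGISTRY[loaded["class"]]
    net = cls(**loaded["hyperparameters"])  # rebuild the architecture first
    net.load_state_dict(loaded["state_dict"])  # then restore the weights
    return net
```

Storing plain dictionaries instead of the live object keeps the saved file independent of everything else attached to the network, at the cost of requiring that the constructor arguments are recorded somewhere.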