forked from tinkoff-ai/etna
-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Rework saving for DL models #98
Merged
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
af129c1
feature: redo saving for dl models
d-a-bunin 2ae685c
Merge remote-tracking branch 'origin/master' into issue-32
d-a-bunin 7e01b43
chore: update changelog
d-a-bunin 94fbf4d
Merge remote-tracking branch 'origin/master' into issue-32
d-a-bunin a0e0cd3
fix: fix poetry.lock
d-a-bunin 4765546
docs: add docs for SaveMixin._save method
d-a-bunin b9ff5b7
fix: remove _load_object because it isn't needed
d-a-bunin 45b2699
docs: add info about save_hyperparameters
d-a-bunin a21e204
Merge remote-tracking branch 'origin/master' into issue-32
d-a-bunin File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,8 +27,8 @@ | |
"""Get default representation of etna object.""" | ||
# TODO: add tests default behaviour for all registered objects | ||
args_str_representation = "" | ||
init_args = inspect.signature(self.__init__).parameters | ||
for arg, param in init_args.items(): | ||
init_parameters = self._get_init_parameters() | ||
for arg, param in init_parameters.items(): | ||
if param.kind == param.VAR_POSITIONAL: | ||
continue | ||
elif param.kind == param.VAR_KEYWORD: | ||
|
@@ -43,6 +43,9 @@ | |
args_str_representation += f"{arg} = {repr(value)}, " | ||
return f"{self.__class__.__name__}({args_str_representation})" | ||
|
||
def _get_init_parameters(self): | ||
return inspect.signature(self.__init__).parameters | ||
|
||
@staticmethod | ||
def _get_target_from_class(value: Any): | ||
if value is None: | ||
|
@@ -84,9 +87,9 @@ | |
|
||
def to_dict(self): | ||
"""Collect all information about etna object in dict.""" | ||
init_args = inspect.signature(self.__init__).parameters | ||
init_parameters = self._get_init_parameters() | ||
params = {} | ||
for arg in init_args.keys(): | ||
for arg in init_parameters.keys(): | ||
value = self.__dict__[arg] | ||
if value is None: | ||
continue | ||
|
@@ -226,21 +229,47 @@ | |
with archive.open("metadata.json", "w") as output_file: | ||
output_file.write(metadata_bytes) | ||
|
||
def _save_state(self, archive: zipfile.ZipFile): | ||
with archive.open("object.pkl", "w", force_zip64=True) as output_file: | ||
dill.dump(self, output_file) | ||
def _save_state(self, archive: zipfile.ZipFile, skip_attributes: Sequence[str] = ()):
    """Pickle the object into ``object.pkl`` inside the archive, omitting selected attributes.

    Parameters
    ----------
    archive:
        Opened zip archive to write ``object.pkl`` into.
    skip_attributes:
        Names of attributes that are removed from the object while it is pickled
        and are put back afterwards.
    """
    saved_attributes = {}
    try:
        # remove attributes we can't easily save
        saved_attributes = {attr: getattr(self, attr) for attr in skip_attributes}
        for attr in skip_attributes:
            delattr(self, attr)

        # save the remaining part
        with archive.open("object.pkl", "w", force_zip64=True) as output_file:
            dill.dump(self, output_file)
    finally:
        # restore the skipped attributes; runs even if deleting or pickling failed
        for attr, value in saved_attributes.items():
            setattr(self, attr, value)
|
||
def _save(self, path: pathlib.Path, skip_attributes: Sequence[str] = ()):
    """Save the object with more options.

    This method is intended to be used to implement the ``save`` method in subclasses.

    Parameters
    ----------
    path:
        Path to save object to.
    skip_attributes:
        Attributes to be skipped during saving state. These attributes are intended to be saved manually.
    """
    # mode "w" creates the archive anew, so anything previously stored at ``path`` is replaced
    with zipfile.ZipFile(path, "w") as archive:
        self._save_metadata(archive)
        self._save_state(archive, skip_attributes=skip_attributes)
|
||
def save(self, path: pathlib.Path):
    """Save the object.

    Delegates to ``_save`` without skipping any attributes.

    Parameters
    ----------
    path:
        Path to save object to.
    """
    self._save(path=path)
|
||
@classmethod | ||
def _load_metadata(cls, archive: zipfile.ZipFile) -> Dict[str, Any]: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
import pathlib | ||
import zipfile | ||
from abc import ABC | ||
from abc import abstractmethod | ||
|
@@ -11,13 +12,21 @@ | |
import dill | ||
import numpy as np | ||
import pandas as pd | ||
from hydra_slayer import get_factory | ||
from typing_extensions import Self | ||
|
||
from etna import SETTINGS | ||
from etna.core.mixins import BaseMixin | ||
from etna.core.mixins import SaveMixin | ||
from etna.datasets.tsdataset import TSDataset | ||
from etna.datasets.utils import match_target_quantiles | ||
from etna.models.decorators import log_decorator | ||
|
||
if SETTINGS.torch_required: | ||
import torch | ||
from pytorch_lightning import LightningModule | ||
from pytorch_lightning import Trainer | ||
|
||
|
||
class ModelForecastingMixin(ABC): | ||
"""Base class for model mixins.""" | ||
|
@@ -640,25 +649,145 @@ | |
return self._base_model.get_model() | ||
|
||
|
||
class SaveNNMixin(SaveMixin): | ||
"""Implementation of ``AbstractSaveable`` torch related classes. | ||
def _save_pl_model(archive: zipfile.ZipFile, filename: str, model: "LightningModule"):
    """Write a lightning module into the archive as a dict saved by ``torch.save``.

    The stored dict has three keys: ``"class"``, ``"hyperparameters"`` and ``"state_dict"``.

    Parameters
    ----------
    archive:
        Opened zip archive to write into.
    filename:
        Name of the archive member to create.
    model:
        Module to save.
    """
    # NOTE(review): a reviewer remarked that this saving can potentially be improved.
    with archive.open(filename, "w", force_zip64=True) as output_file:
        to_save = {
            "class": BaseMixin._get_target_from_class(model),
            # presumably ``model.hparams`` is filled by ``save_hyperparameters()``
            # in the module's ``__init__`` -- confirm for every saved module
            "hyperparameters": dict(model.hparams),
            "state_dict": model.state_dict(),
        }
        torch.save(to_save, output_file, pickle_module=dill)
|
||
|
||
def _load_pl_model(archive: zipfile.ZipFile, filename: str) -> "LightningModule":
    """Rebuild a lightning module from an archive member written by ``_save_pl_model``.

    The class is resolved from the stored ``"class"`` entry, instantiated with the
    stored ``"hyperparameters"`` as keyword arguments, and then given the stored
    ``"state_dict"``.

    Parameters
    ----------
    archive:
        Opened zip archive to read from.
    filename:
        Name of the archive member to read.

    Returns
    -------
    :
        Restored module.
    """
    # SECURITY: unpickling with dill can execute arbitrary code;
    # only archives from a trusted source should be loaded.
    with archive.open(filename, "r") as input_file:
        net_loaded = torch.load(input_file, pickle_module=dill)

    cls = get_factory(net_loaded["class"])
    net = cls(**net_loaded["hyperparameters"])
    net.load_state_dict(net_loaded["state_dict"])

    return net
|
||
|
||
class SaveDeepBaseModelMixin(SaveMixin):
    """Implementation of ``AbstractSaveable`` for :py:class:`~etna.models.base.DeepBaseModel` models.

    It saves object to the zip archive with files:

    * metadata.json: contains library version and class name.

    * object.pkl: pickled without ``self.net`` and ``self.trainer``.

    * net.pt: parameters of ``self.net`` saved by ``torch.save``.
    """

    def save(self, path: pathlib.Path):
        """Save the object.

        Parameters
        ----------
        path:
            Path to save object to.
        """
        # imported here rather than at module level; presumably to avoid a circular import
        from etna.models.base import DeepBaseNet

        # declarations for type checkers only; the attributes are set by the class using this mixin
        self.trainer: Optional[Trainer]
        self.net: DeepBaseNet

        self._save(path=path, skip_attributes=["trainer", "net"])

        # the archive already exists at this point, so the network is appended to it
        with zipfile.ZipFile(path, "a") as archive:
            _save_pl_model(archive=archive, filename="net.pt", model=self.net)

    @classmethod
    def load(cls, path: pathlib.Path, ts: Optional[TSDataset] = None) -> Self:
        """Load an object.

        Warning
        -------
        This method uses :py:mod:`dill` module which is not secure.
        It is possible to construct malicious data which will execute arbitrary code during loading.
        Never load data that could have come from an untrusted source, or that could have been tampered with.

        Parameters
        ----------
        path:
            Path to load object from.
        ts:
            Accepted by the signature but not used by this method.

        Returns
        -------
        :
            Loaded object. Its ``net`` is restored from ``net.pt`` and its ``trainer`` is set to ``None``.
        """
        obj = super().load(path=path)

        with zipfile.ZipFile(path, "r") as archive:
            obj.net = _load_pl_model(archive=archive, filename="net.pt")
            # the trainer is never saved, so the loaded object starts without one
            obj.trainer = None

        return obj
|
||
|
||
It saves object to the zip archive with 2 files: | ||
class SavePytorchForecastingMixin(SaveMixin):
    """Implementation of ``AbstractSaveable`` for :py:mod:`pytorch_forecasting` models.

    It saves object to the zip archive with files:

    * metadata.json: contains library version and class name.

    * object.pkl: pickled without ``self.model`` and ``self.trainer``.

    * model.pt: parameters of ``self.model`` saved by ``torch.save`` if model was fitted.
    """

    def save(self, path: pathlib.Path):
        """Save the object.

        Parameters
        ----------
        path:
            Path to save object to.
        """
        # declarations for type checkers only; the attributes are set by the class using this mixin
        self.trainer: Optional[Trainer]
        self.model: Optional[LightningModule]

        if self.model is None:
            # no model to store separately: ``model`` stays in the pickle as ``None``
            self._save(path=path, skip_attributes=["trainer"])
        else:
            self._save(path=path, skip_attributes=["trainer", "model"])
            # the archive already exists at this point, so the model is appended to it
            with zipfile.ZipFile(path, "a") as archive:
                _save_pl_model(archive=archive, filename="model.pt", model=self.model)

    @classmethod
    def load(cls, path: pathlib.Path, ts: Optional[TSDataset] = None) -> Self:
        """Load an object.

        Warning
        -------
        This method uses :py:mod:`dill` module which is not secure.
        It is possible to construct malicious data which will execute arbitrary code during loading.
        Never load data that could have come from an untrusted source, or that could have been tampered with.

        Parameters
        ----------
        path:
            Path to load object from.
        ts:
            Accepted by the signature but not used by this method.

        Returns
        -------
        :
            Loaded object. Its ``trainer`` is set to ``None``; ``model`` is restored from ``model.pt``
            when the unpickled object has no ``model`` attribute.
        """
        obj = super().load(path=path)
        obj.trainer = None

        # ``model`` is missing from the pickle only when ``save`` skipped it and wrote ``model.pt``
        if not hasattr(obj, "model"):
            with zipfile.ZipFile(path, "r") as archive:
                obj.model = _load_pl_model(archive=archive, filename="model.pt")

        return obj
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added for simplification.