
Rework saving for DL models #98

Merged: 9 commits into master from issue-32 on Oct 10, 2023

Conversation

d-a-bunin (Collaborator) commented Oct 4, 2023

Before submitting (must do checklist)

  • Did you read the contribution guide?
  • Did you update the docs? We use NumPy format for all the methods and classes.
  • Did you write any new necessary tests?
  • Did you update the CHANGELOG?

Proposed Changes

  • Separate saving of state, hyperparameters and weights
  • Add missing tests on save/load for all DL models

Closing issues

Closes #32.

@d-a-bunin d-a-bunin self-assigned this Oct 4, 2023
codecov bot commented Oct 4, 2023

Codecov Report

All modified lines are covered by tests ✅

Comparison: base (799ccb1) at 9.35% coverage vs head (a21e204) at 89.24%.

Additional details and impacted files
@@             Coverage Diff             @@
##           master      #98       +/-   ##
===========================================
+ Coverage    9.35%   89.24%   +79.88%     
===========================================
  Files         195      195               
  Lines       12545    12596       +51     
===========================================
+ Hits         1174    11241    +10067     
+ Misses      11371     1355    -10016     
Files Coverage Δ
etna/core/mixins.py 96.11% <100.00%> (+69.48%) ⬆️
etna/ensembles/mixins.py 100.00% <100.00%> (+100.00%) ⬆️
etna/models/base.py 86.77% <100.00%> (+34.92%) ⬆️
etna/models/mixins.py 96.95% <100.00%> (+62.72%) ⬆️
etna/models/nn/deepar.py 97.43% <100.00%> (+97.43%) ⬆️
etna/models/nn/deepstate/deepstate.py 100.00% <100.00%> (+100.00%) ⬆️
etna/models/nn/mlp.py 100.00% <100.00%> (+100.00%) ⬆️
etna/models/nn/nbeats/nets.py 100.00% <100.00%> (+100.00%) ⬆️
etna/models/nn/patchts.py 100.00% <100.00%> (+100.00%) ⬆️
etna/models/nn/rnn.py 100.00% <100.00%> (+100.00%) ⬆️
... and 3 more

... and 171 files with indirect coverage changes

☔ View full report in Codecov by Sentry.

github-actions bot commented Oct 4, 2023

🚀 Deployed on https://deploy-preview-98--etna-docs.netlify.app

@github-actions github-actions bot temporarily deployed to pull request October 4, 2023 13:15 Inactive
def _save_state(self, archive: zipfile.ZipFile):
    with archive.open("object.pkl", "w", force_zip64=True) as output_file:
        dill.dump(self, output_file)
def _save_state(self, archive: zipfile.ZipFile, skip_attributes: Sequence[str] = ()):
d-a-bunin (Collaborator, Author):
I tried to rework _save_state to make other mixins simpler.
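The reworked signature above can be sketched roughly as follows. This is a minimal illustration of the `skip_attributes` idea, not ETNA's actual implementation: it assumes the skipped attributes are temporarily popped from `__dict__` before pickling and restored afterwards, and it uses `pickle` in place of `dill` to stay dependency-free. `SaveSketchMixin` and `Demo` are hypothetical names.

```python
import io
import pickle  # the PR uses dill; pickle keeps this sketch dependency-free
import zipfile
from typing import Sequence


class SaveSketchMixin:
    """Hypothetical mixin illustrating skip_attributes in _save_state."""

    def _save_state(self, archive: zipfile.ZipFile, skip_attributes: Sequence[str] = ()):
        # Temporarily drop heavy attributes (e.g. torch weights) so the
        # pickled state contains only hyperparameters and light state.
        saved = {attr: self.__dict__.pop(attr) for attr in skip_attributes}
        try:
            with archive.open("object.pkl", "w", force_zip64=True) as output_file:
                pickle.dump(self, output_file)
        finally:
            self.__dict__.update(saved)  # restore the skipped attributes


class Demo(SaveSketchMixin):
    def __init__(self):
        self.alpha = 1
        self.heavy_weights = [0.0] * 5


buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as archive:
    Demo()._save_state(archive, skip_attributes=("heavy_weights",))

with zipfile.ZipFile(io.BytesIO(buffer.getvalue())) as archive:
    restored = pickle.loads(archive.read("object.pkl"))
```

With this shape, each mixin only has to declare which attributes to skip instead of reimplementing the whole save routine.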

@@ -73,7 +73,7 @@ types-Deprecated = "1.2.9"

prophet = {version = "^1.0", optional = true}

torch = {version = ">=1.8.0,<1.12.0", optional = true}
torch = {version = ">=1.8.0,<3", optional = true}
d-a-bunin (Collaborator, Author):
I checked torch==1.13.1 and torch==2.0.1.
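The loosened constraint above can be expressed as a small version check. This is an illustrative sketch, not project code; the function names are hypothetical, and it only mirrors the `">=1.8.0,<3"` range from the pyproject.toml diff.

```python
def _parse_version(version: str) -> tuple:
    # Drop local suffixes like "+cu117" and compare the numeric parts.
    return tuple(int(part) for part in version.split("+")[0].split(".")[:3])


def satisfies_torch_constraint(version: str) -> bool:
    # Mirrors the new pyproject.toml constraint ">=1.8.0,<3".
    return (1, 8, 0) <= _parse_version(version) < (3,)
```

Both tested versions (1.13.1 and 2.0.1) fall inside this range, while the old upper bound `<1.12.0` excluded them.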

        self.size = size
        self.init_model = init_model
        if init_model:
            self.model = MLPNet(
d-a-bunin (Collaborator, Author):
It wasn't very easy to create a DeepAR model here to be more accurate with the testing.
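The `init_model` flag in the diff above can be sketched like this. It is a hedged illustration under the assumption that loading from disk wants to skip constructing the inner network and attach the deserialized one instead; `MLPNetStub` and `MLPSketch` are hypothetical stand-ins for the real torch modules.

```python
class MLPNetStub:
    """Hypothetical stand-in for MLPNet (the real one is a torch module)."""

    def __init__(self, size: int):
        self.size = size


class MLPSketch:
    """Hypothetical sketch of the init_model flag from the diff above."""

    def __init__(self, size: int, init_model: bool = True):
        self.size = size
        self.init_model = init_model
        if init_model:
            # Normal construction builds the inner network immediately;
            # loading from disk would pass init_model=False and attach
            # the deserialized network afterwards instead.
            self.model = MLPNetStub(size=size)
```

For example, `MLPSketch(4)` builds the inner network right away, while `MLPSketch(4, init_model=False)` leaves `model` unset for the loader to fill in.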

@@ -43,3 +44,14 @@ def test_deepstate_model_run_weekly_overfit_with_scaler(ts_dataset_weekly_functi

mae = MAE("macro")
assert mae(ts_test, future) < 0.001


def test_save_load(example_tsds):
d-a-bunin (Collaborator, Author) commented Oct 4, 2023:
This test was missing, so I added it.
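A save/load test of this kind usually follows a round-trip pattern: persist the model, load it back, and compare state. Here is a minimal self-contained sketch of that pattern, with a hypothetical `TinyModel` standing in for the DL models; it is not ETNA's actual test, which fits and forecasts on a real dataset.

```python
import pathlib
import pickle
import tempfile


class TinyModel:
    """Hypothetical stand-in for a DL model with save/load methods."""

    def __init__(self, lr: float = 0.01):
        self.lr = lr

    def save(self, path: pathlib.Path):
        path.write_bytes(pickle.dumps(self))

    @classmethod
    def load(cls, path: pathlib.Path):
        return pickle.loads(path.read_bytes())


def test_save_load():
    # Round-trip: save, load, then compare hyperparameters.
    model = TinyModel(lr=0.05)
    with tempfile.TemporaryDirectory() as tmp:
        path = pathlib.Path(tmp) / "model.zip"
        model.save(path)
        loaded = TinyModel.load(path)
    assert loaded.lr == model.lr


test_save_load()
```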

@@ -43,6 +43,9 @@ def __repr__(self):
        args_str_representation += f"{arg} = {repr(value)}, "
        return f"{self.__class__.__name__}({args_str_representation})"

    def _get_init_parameters(self):
d-a-bunin (Collaborator, Author):
Added for simplification.
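A helper like `_get_init_parameters` typically introspects the constructor signature so that `__repr__` (and saving code) can iterate over constructor arguments. A minimal sketch, assuming it is built on `inspect.signature`; the class names here are hypothetical:

```python
import inspect


class ReprSketch:
    """Hypothetical sketch of a _get_init_parameters helper."""

    def _get_init_parameters(self):
        # Read the signature of the concrete class's __init__ so callers
        # can iterate over the constructor's parameters and defaults.
        return inspect.signature(self.__class__.__init__).parameters


class Child(ReprSketch):
    def __init__(self, alpha: float = 0.5, beta: int = 3):
        self.alpha = alpha
        self.beta = beta


params = Child()._get_init_parameters()
```

Because the signature is resolved from `self.__class__`, subclasses get their own parameter lists without overriding anything.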

def _save_pl_model(archive: zipfile.ZipFile, filename: str, model: "LightningModule"):
    with archive.open(filename, "w", force_zip64=True) as output_file:
        to_save = {
            "class": BaseMixin._get_target_from_class(model),
d-a-bunin (Collaborator, Author):
This saving can potentially be improved.
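The "class path plus weights" payload visible in the snippet above can be sketched as follows. This is an illustration only: the real code stores torch weights (e.g. via `torch.save`) inside the zip entry, while this sketch uses JSON to stay dependency-free, and `save_model_sketch` and `DummyModule` are hypothetical names.

```python
import io
import json
import zipfile


def _get_target_from_class(obj) -> str:
    # Fully qualified import path of the object's class, mirroring the
    # role of BaseMixin._get_target_from_class in the snippet above.
    cls = obj.__class__
    return f"{cls.__module__}.{cls.__qualname__}"


def save_model_sketch(archive: zipfile.ZipFile, filename: str, model, weights: dict):
    # Store the class path alongside the weights so loading can first
    # import the class, then restore its parameters.
    to_save = {"class": _get_target_from_class(model), "weights": weights}
    with archive.open(filename, "w", force_zip64=True) as output_file:
        output_file.write(json.dumps(to_save).encode())


class DummyModule:
    pass


buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as archive:
    save_model_sketch(archive, "model.json", DummyModule(), {"w": [1.0, 2.0]})

with zipfile.ZipFile(io.BytesIO(buffer.getvalue())) as archive:
    payload = json.loads(archive.read("model.json"))
```

Recording the class path makes the archive self-describing, which is one reason the comment notes the scheme could still be improved (for example, pinning versions alongside the class).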

@brsnw250 brsnw250 self-requested a review October 5, 2023 07:49
Resolved review threads: etna/core/mixins.py (2), etna/models/mixins.py (2, one outdated)
@github-actions github-actions bot temporarily deployed to pull request October 6, 2023 13:43 Inactive
@d-a-bunin d-a-bunin requested review from brsnw250 and removed request for brsnw250 October 9, 2023 08:35
@github-actions github-actions bot temporarily deployed to pull request October 9, 2023 08:39 Inactive
brsnw250
brsnw250 previously approved these changes Oct 9, 2023
@github-actions github-actions bot temporarily deployed to pull request October 9, 2023 15:45 Inactive
@d-a-bunin d-a-bunin merged commit 2651829 into master Oct 10, 2023
15 checks passed
@d-a-bunin d-a-bunin deleted the issue-32 branch October 10, 2023 06:55
Labels: None yet
Projects: None yet
Development

Successfully merging this pull request may close these issues.

Fix SaveNNMixin to work on torch-1.13 and torch-2.0
2 participants