-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Rework saving for DL models #98
Conversation
Codecov ReportAll modified lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #98 +/- ##
===========================================
+ Coverage 9.35% 89.24% +79.88%
===========================================
Files 195 195
Lines 12545 12596 +51
===========================================
+ Hits 1174 11241 +10067
+ Misses 11371 1355 -10016
☔ View full report in Codecov by Sentry. |
🚀 Deployed on https://deploy-preview-98--etna-docs.netlify.app |
def _save_state(self, archive: zipfile.ZipFile): | ||
with archive.open("object.pkl", "w", force_zip64=True) as output_file: | ||
dill.dump(self, output_file) | ||
def _save_state(self, archive: zipfile.ZipFile, skip_attributes: Sequence[str] = ()): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried to rework _save_state
to make other mixins simpler.
@@ -73,7 +73,7 @@ types-Deprecated = "1.2.9" | |||
|
|||
prophet = {version = "^1.0", optional = true} | |||
|
|||
torch = {version = ">=1.8.0,<1.12.0", optional = true} | |||
torch = {version = ">=1.8.0,<3", optional = true} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I checked torch==1.13.1
and torch==2.0.1
.
self.size = size | ||
self.init_model = init_model | ||
if init_model: | ||
self.model = MLPNet( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wasn't very easy to create DeepAR
model here to be more accurate with the testing.
@@ -43,3 +44,14 @@ def test_deepstate_model_run_weekly_overfit_with_scaler(ts_dataset_weekly_functi | |||
|
|||
mae = MAE("macro") | |||
assert mae(ts_test, future) < 0.001 | |||
|
|||
|
|||
def test_save_load(example_tsds): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This test was forgotten and I added it.
@@ -43,6 +43,9 @@ def __repr__(self): | |||
args_str_representation += f"{arg} = {repr(value)}, " | |||
return f"{self.__class__.__name__}({args_str_representation})" | |||
|
|||
def _get_init_parameters(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added for simplification.
def _save_pl_model(archive: zipfile.ZipFile, filename: str, model: "LightningModule"): | ||
with archive.open(filename, "w", force_zip64=True) as output_file: | ||
to_save = { | ||
"class": BaseMixin._get_target_from_class(model), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This saving can potentially be improved.
# Conflicts: # CHANGELOG.md # poetry.lock
Before submitting (must do checklist)
Proposed Changes
save
/load
for all DL modelsClosing issues
Closes #32.