From 872b98dd683cca594109d5c7614a1e1707a8b9cc Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Mon, 14 Aug 2023 11:51:32 -0400 Subject: [PATCH] lightning: copy best model --- src/dvclive/lightning.py | 2 +- tests/test_frameworks/test_lightning.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/dvclive/lightning.py b/src/dvclive/lightning.py index 069bb6df..5f0c1fff 100644 --- a/src/dvclive/lightning.py +++ b/src/dvclive/lightning.py @@ -157,7 +157,7 @@ def finalize(self, status: str) -> None: self._save_checkpoints(self._checkpoint_callback) best_model_path = self._checkpoint_callback.best_model_path self.experiment.log_artifact( - best_model_path, name="best", type="model", cache=False + best_model_path, name="best", type="model", copy=True ) self.experiment.end() diff --git a/tests/test_frameworks/test_lightning.py b/tests/test_frameworks/test_lightning.py index e572b235..08c9c73a 100644 --- a/tests/test_frameworks/test_lightning.py +++ b/tests/test_frameworks/test_lightning.py @@ -192,6 +192,9 @@ def test_lightning_log_model(tmp_dir, mocker, log_model, save_top_k): if log_model in [True, "all"]: trainer.fit(model) assert len(os.listdir(tmp_dir / "model")) == num_checkpoints + log_artifact.assert_any_call( + checkpoint.best_model_path, name="best", type="model", copy=True + ) def test_lightning_steps(tmp_dir, mocker):