Skip to content

Commit

Permalink
cache: cache in dvc exp if not stage output (#660)
Browse files (browse the repository at this point in the history)
* cache: cache in dvc exp if not stage output

* fix tests
  • Loading branch information
Dave Berenbaum authored Aug 16, 2023
1 parent f1b8e2a commit 75006eb
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 36 deletions.
26 changes: 14 additions & 12 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,31 +462,33 @@ def log_artifact(
def cache(self, path):
try:
if self._inside_dvc_exp:
msg = f"Skipping dvc add {path} because `dvc exp run` is running."
path_stage = None
for stage in self._dvc_repo.index.stages:
for out in stage.outs:
if out.fspath == str(Path(path).absolute()):
path_stage = stage
break
if not path_stage:
msg += (
"\nTo track it automatically during `dvc exp run`, "
"add it as an output of the pipeline stage."
if path_stage and path_stage.cmd:
msg = (
f"Skipping `dvc add {path}` because it is already being tracked"
" automatically as an output of `dvc exp run`."
)
logger.warning(msg)
elif path_stage.cmd:
msg += "\nIt is already being tracked automatically."
logger.info(msg)
else:
msg += (
"\nTo track it automatically during `dvc exp run`:"
return # skip caching
if path_stage:
msg = (
f"\nTo track '{path}' automatically during `dvc exp run`:"
f"\n1. Run `dvc exp remove {path_stage.addressing}` "
"to stop tracking it outside the pipeline."
"\n2. Add it as an output of the pipeline stage."
)
logger.warning(msg)
return
else:
msg = (
f"\nTo track '{path}' automatically during `dvc exp run`, "
"add it as an output of the pipeline stage."
)
logger.warning(msg)

stage = self._dvc_repo.add(str(path))

Expand Down
45 changes: 21 additions & 24 deletions tests/test_log_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,44 +225,41 @@ def test_log_artifact_type_model_when_dvc_add_fails(tmp_dir, mocker, mocked_dvc_
}


def test_log_artifact_inside_exp(tmp_dir, mocked_dvc_repo):
data = tmp_dir / "data"
data.touch()
with Live() as live:
live._inside_dvc_exp = True
live.log_artifact("data")
mocked_dvc_repo.add.assert_not_called()


@pytest.mark.parametrize("tracked", ["data_source", "stage", None])
def test_log_artifact_inside_exp_logger(tmp_dir, mocker, dvc_repo, tracked):
def test_log_artifact_inside_exp(tmp_dir, mocker, dvc_repo, tracked):
logger = mocker.patch("dvclive.live.logger")
data = tmp_dir / "data"
data.touch()
if tracked == "data_source":
data = tmp_dir / "data"
data.touch()
dvc_repo.add(data)
elif tracked == "stage":
dvcyaml_path = tmp_dir / "dvc.yaml"
with open(dvcyaml_path, "w") as f:
f.write(dvcyaml)
with Live() as live:
live._inside_dvc_exp = True
live.log_artifact("data")
msg = "Skipping dvc add data because `dvc exp run` is running."
if tracked == "data_source":
msg += (
"\nTo track it automatically during `dvc exp run`:"
live = Live()
spy = mocker.spy(live._dvc_repo, "add")
live._inside_dvc_exp = True
live.log_artifact("data")
if tracked == "stage":
msg = (
"Skipping `dvc add data` because it is already being tracked"
" automatically as an output of `dvc exp run`."
)
logger.info.assert_called_with(msg)
spy.assert_not_called()
elif tracked == "data_source":
msg = (
"\nTo track 'data' automatically during `dvc exp run`:"
"\n1. Run `dvc exp remove data.dvc` "
"to stop tracking it outside the pipeline."
"\n2. Add it as an output of the pipeline stage."
)
logger.warning.assert_called_with(msg)
elif tracked == "stage":
msg += "\nIt is already being tracked automatically."
logger.info.assert_called_with(msg)
spy.assert_called_once()
else:
msg += (
"\nTo track it automatically during `dvc exp run`, "
msg = (
"\nTo track 'data' automatically during `dvc exp run`, "
"add it as an output of the pipeline stage."
)
logger.warning.assert_called_with(msg)
spy.assert_called_once()

0 comments on commit 75006eb

Please sign in to comment.