Skip to content

Commit

Permalink
make_stage_template: more tests and tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
dberenbaum committed Aug 18, 2023
1 parent 70afb40 commit 2c164ca
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 10 deletions.
7 changes: 4 additions & 3 deletions src/dvclive/dvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,13 +174,14 @@ def get_dvc_stage_template(live):
"deps": ["<my_code_file.py>"],
}
outs = []
rel_path = Path(os.path.relpath(os.getcwd(), live._dvc_repo.root_dir))
root_dir = live._dvc_repo.root_dir
if live._images and live._cache_images:
images_path = (rel_path / live.plots_dir / Image.subfolder).as_posix()
plots_path = Path(os.path.relpath(live.plots_dir, root_dir))
images_path = (plots_path / Image.subfolder).as_posix()
outs.append(images_path)
for o, cache in live._outs.items():
artifact_path = Path(os.getcwd()) / o
artifact_path = artifact_path.relative_to(live._dvc_repo.root_dir).as_posix()
artifact_path = artifact_path.relative_to(root_dir).as_posix()
if cache:
outs.append(artifact_path)
else:
Expand Down
12 changes: 9 additions & 3 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,10 +531,16 @@ def make_report(self):
def make_dvcyaml(self):
make_dvcyaml(self)

@catch_and_warn(Exception, logger)
@catch_and_warn(DvcException, logger)
def make_stage_template(self):
stage_content = get_dvc_stage_template(self)
dump_yaml(stage_content, self._stage_template_file)
if self._dvc_repo:
stage_content = get_dvc_stage_template(self)
dump_yaml(stage_content, self._stage_template_file)
else:
logger.warning(
"Can't make stage template without a DVC repo."
"\nRun `dvc init` to initialize a DVC repo."
)

def end(self):
if self._inside_with:
Expand Down
20 changes: 20 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,26 @@ def dvc_repo(tmp_dir):
return repo


@pytest.fixture()
def dvc_repo_subdir(tmp_dir):
from dvc.repo import Repo
from scmrepo.git import Git

Git.init(tmp_dir)
subdir = tmp_dir / "subdir"
subdir.mkdir()
repo = Repo.init(subdir, subdir=True)
repo.scm.add_commit(".", "init")
return repo


@pytest.fixture()
def dvc_repo_no_scm(tmp_dir):
from dvc.repo import Repo

return Repo.init(tmp_dir, no_scm=True)


@pytest.fixture(autouse=True)
def _capture_wrap():
# https://github.com/pytest-dev/pytest/issues/5502#issuecomment-678368525
Expand Down
42 changes: 38 additions & 4 deletions tests/test_dvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,10 @@ def test_warn_on_dvcyaml_output_overlap(tmp_dir, mocker, mocked_dvc_repo, dvcyam

def test_get_dvc_stage_template_empty(tmp_dir, mocked_dvc_repo):
live = Live()
live.log_param("foo", 1)
live.log_metric("bar", 1)
live.log_image("img.png", Image.new("RGB", (10, 10), (250, 250, 250)))
live.log_sklearn_plot("confusion_matrix", [0, 0, 1, 1], [0, 1, 1, 0])
template = get_dvc_stage_template(live)

assert template == {
Expand Down Expand Up @@ -425,10 +429,6 @@ def test_get_dvc_stage_template_chdir(tmp_dir, mocked_dvc_repo, monkeypatch):
d.mkdir(parents=True)
monkeypatch.chdir(d)
live = Live("live")
live.log_param("foo", 1)
live.log_metric("bar", 1)
live.log_image("img.png", Image.new("RGB", (10, 10), (250, 250, 250)))
live.log_sklearn_plot("confusion_matrix", [0, 0, 1, 1], [0, 1, 1, 0])
live.log_artifact("artifact.txt")
template = get_dvc_stage_template(live)

Expand All @@ -441,3 +441,37 @@ def test_get_dvc_stage_template_chdir(tmp_dir, mocked_dvc_repo, monkeypatch):
}
}
}


def test_get_dvc_stage_template_subdir(tmp_dir, dvc_repo_subdir):
d = dvc_repo_subdir.root_dir
os.chdir(d)
live = Live("live")
live.log_artifact("artifact.txt")
template = get_dvc_stage_template(live)

assert template == {
"stages": {
"dvclive": {
"cmd": "<python my_code_file.py my_args>",
"deps": ["<my_code_file.py>"],
"outs": ["artifact.txt"],
}
}
}


def test_get_dvc_stage_template_no_scm(tmp_dir, dvc_repo_no_scm):
live = Live("live")
live.log_artifact("artifact.txt")
template = get_dvc_stage_template(live)

assert template == {
"stages": {
"dvclive": {
"cmd": "<python my_code_file.py my_args>",
"deps": ["<my_code_file.py>"],
"outs": ["artifact.txt"],
}
}
}
17 changes: 17 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,23 @@ def test_make_stage_template(tmp_dir, dvc_repo):
assert stage_template_path.is_file()


def test_make_stage_template_no_repo(tmp_dir, mocker):
logger = mocker.patch("dvclive.live.logger")
dvclive = Live("logs")
dvclive.log_metric("m1", 1)

dvclive.make_stage_template()

logger.warning.assert_called_with(
"Can't make stage template without a DVC repo."
"\nRun `dvc init` to initialize a DVC repo."
)

stage_template_path = tmp_dir / dvclive.dir / "stage_template.yaml"

assert not stage_template_path.is_file()


@pytest.mark.parametrize("report", ["html", None])
@pytest.mark.parametrize("dvcyaml", [True, False])
def test_end(tmp_dir, dvc_repo, report, dvcyaml):
Expand Down

0 comments on commit 2c164ca

Please sign in to comment.