diff --git a/src/dvclive/dvc.py b/src/dvclive/dvc.py index bfb46785..321383cf 100644 --- a/src/dvclive/dvc.py +++ b/src/dvclive/dvc.py @@ -174,13 +174,14 @@ def get_dvc_stage_template(live): "deps": [""], } outs = [] - rel_path = Path(os.path.relpath(os.getcwd(), live._dvc_repo.root_dir)) + root_dir = live._dvc_repo.root_dir if live._images and live._cache_images: - images_path = (rel_path / live.plots_dir / Image.subfolder).as_posix() + plots_path = Path(os.path.relpath(live.plots_dir, root_dir)) + images_path = (plots_path / Image.subfolder).as_posix() outs.append(images_path) for o, cache in live._outs.items(): artifact_path = Path(os.getcwd()) / o - artifact_path = artifact_path.relative_to(live._dvc_repo.root_dir).as_posix() + artifact_path = artifact_path.relative_to(root_dir).as_posix() if cache: outs.append(artifact_path) else: diff --git a/src/dvclive/live.py b/src/dvclive/live.py index 3ca8f9b0..c863b6e5 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -531,10 +531,16 @@ def make_report(self): def make_dvcyaml(self): make_dvcyaml(self) - @catch_and_warn(Exception, logger) + @catch_and_warn(DvcException, logger) def make_stage_template(self): - stage_content = get_dvc_stage_template(self) - dump_yaml(stage_content, self._stage_template_file) + if self._dvc_repo: + stage_content = get_dvc_stage_template(self) + dump_yaml(stage_content, self._stage_template_file) + else: + logger.warning( + "Can't make stage template without a DVC repo." + "\nRun `dvc init` to initialize a DVC repo." + ) def end(self): if self._inside_with: diff --git a/tests/conftest.py b/tests/conftest.py index 1c856e45..d9a6e8ff 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -36,6 +36,26 @@ def dvc_repo(tmp_dir): return repo +@pytest.fixture() +def dvc_repo_subdir(tmp_dir): + from dvc.repo import Repo + from scmrepo.git import Git + + Git.init(tmp_dir) + subdir = tmp_dir / "subdir" + subdir.mkdir() + repo = Repo.init(subdir, subdir=True) + repo.scm.add_commit(".", "init") + return repo + + +@pytest.fixture() +def dvc_repo_no_scm(tmp_dir): + from dvc.repo import Repo + + return Repo.init(tmp_dir, no_scm=True) + + @pytest.fixture(autouse=True) def _capture_wrap(): # https://github.com/pytest-dev/pytest/issues/5502#issuecomment-678368525 diff --git a/tests/test_dvc.py b/tests/test_dvc.py index 2191362a..b5ca2cac 100644 --- a/tests/test_dvc.py +++ b/tests/test_dvc.py @@ -342,6 +342,10 @@ def test_warn_on_dvcyaml_output_overlap(tmp_dir, mocker, mocked_dvc_repo, dvcyam def test_get_dvc_stage_template_empty(tmp_dir, mocked_dvc_repo): live = Live() + live.log_param("foo", 1) + live.log_metric("bar", 1) + live.log_image("img.png", Image.new("RGB", (10, 10), (250, 250, 250))) + live.log_sklearn_plot("confusion_matrix", [0, 0, 1, 1], [0, 1, 1, 0]) template = get_dvc_stage_template(live) assert template == { @@ -425,10 +429,6 @@ def test_get_dvc_stage_template_chdir(tmp_dir, mocked_dvc_repo, monkeypatch): d.mkdir(parents=True) monkeypatch.chdir(d) live = Live("live") - live.log_param("foo", 1) - live.log_metric("bar", 1) - live.log_image("img.png", Image.new("RGB", (10, 10), (250, 250, 250))) - live.log_sklearn_plot("confusion_matrix", [0, 0, 1, 1], [0, 1, 1, 0]) live.log_artifact("artifact.txt") template = get_dvc_stage_template(live) @@ -441,3 +441,37 @@ def test_get_dvc_stage_template_chdir(tmp_dir, mocked_dvc_repo, monkeypatch): } } } + + +def test_get_dvc_stage_template_subdir(tmp_dir, dvc_repo_subdir): + d = dvc_repo_subdir.root_dir + os.chdir(d) + live = Live("live") + live.log_artifact("artifact.txt") + template = get_dvc_stage_template(live) + + assert template == { + "stages": { + "dvclive": { + "cmd": "", + "deps": [""], + "outs": ["artifact.txt"], + } + } + } + + +def test_get_dvc_stage_template_no_scm(tmp_dir, dvc_repo_no_scm): + live = Live("live") + live.log_artifact("artifact.txt") + template = get_dvc_stage_template(live) + + assert template == { + "stages": { + "dvclive": { + "cmd": "", + "deps": [""], + "outs": ["artifact.txt"], + } + } + } diff --git a/tests/test_main.py b/tests/test_main.py index 43d2513f..cf03615f 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -484,6 +484,23 @@ def test_make_stage_template(tmp_dir, dvc_repo): assert stage_template_path.is_file() +def test_make_stage_template_no_repo(tmp_dir, mocker): + logger = mocker.patch("dvclive.live.logger") + dvclive = Live("logs") + dvclive.log_metric("m1", 1) + + dvclive.make_stage_template() + + logger.warning.assert_called_with( + "Can't make stage template without a DVC repo." + "\nRun `dvc init` to initialize a DVC repo." + ) + + stage_template_path = tmp_dir / dvclive.dir / "stage_template.yaml" + + assert not stage_template_path.is_file() + + @pytest.mark.parametrize("report", ["html", None]) @pytest.mark.parametrize("dvcyaml", [True, False]) def test_end(tmp_dir, dvc_repo, report, dvcyaml):