From bf3803456b763819e0584a3e90d8d4f9a6eaa826 Mon Sep 17 00:00:00 2001
From: Matt Seddon <37993418+mattseddon@users.noreply.github.com>
Date: Wed, 30 Aug 2023 03:05:59 +1000
Subject: [PATCH] Add step completed signal file for VS Code (#688)

* Add step completed signal file for VS Code

* Use new DVC environment variable

* Write file if no env variable present

* Bump `dvc>=3.17.0`.

---------

Co-authored-by: daavoo
---
Reviewer notes (commentary only; not part of the commit message or the diff):

- Layout: line breaks and indentation were restored from a copy of this patch
  whose whitespace had been collapsed. Hunk line counts were checked against
  the `@@` headers, but the leading whitespace of context lines is inferred
  (notably in pyproject.toml and src/dvclive/live.py). If a hunk is rejected,
  retry with `git apply --ignore-whitespace` and compare with the commit.
- src/dvclive/dvc.py, cleanup_dvclive_step_completed(): os.path.exists() is
  checked and os.remove() is called afterwards. If the file is removed in
  between, os.remove() raises FileNotFoundError. Wrapping os.remove() in
  `try/except FileNotFoundError` would make the cleanup race-free.
- src/dvclive/dvc.py, _find_non_queue_root(): because of the `or`, an empty
  DVC_ROOT value is treated like an unset one and _find_dvc_root() is used.
- src/dvclive/live.py, next_step(): mark_dvclive_step_completed(self.step) runs
  before `self.step += 1`, so the file holds the step that just finished; the
  new test expects {"step": 0} after the first call.
- tests/test_main.py: `cwd` is only used to create a nested `.dvc` directory;
  the test never changes into it. TODO confirm that is intended.
- tests/test_main.py: getpid is patched as "dvclive.live.os.getpid" while the
  call is made in dvclive/dvc.py. NOTE(review): presumably this works because
  both modules share the same `os` module object - confirm.

 pyproject.toml      |  2 +-
 src/dvclive/dvc.py  | 44 ++++++++++++++++++++++++++++++++++++++
 src/dvclive/env.py  |  1 +
 src/dvclive/live.py |  5 +++++
 tests/test_main.py  | 52 +++++++++++++++++++++++++++++++++++++++++++++
 5 files changed, 103 insertions(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 9f972aec..5aca0b71 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,7 +31,7 @@ classifiers = [
 ]
 dynamic = ["version"]
 dependencies = [
-  "dvc>=2.58.0",
+  "dvc>=3.17.0",
   "dvc-render>=0.5.0,<1.0",
   "dvc-studio-client>=0.10.0,<1",
   "funcy",
diff --git a/src/dvclive/dvc.py b/src/dvclive/dvc.py
index 08ed9939..6cc95b94 100644
--- a/src/dvclive/dvc.py
+++ b/src/dvclive/dvc.py
@@ -9,6 +9,8 @@
 from dvclive.serialize import dump_yaml
 from dvclive.utils import StrPath
 
+from . import env
+
 if TYPE_CHECKING:
     from dvc.repo import Repo
     from dvc.stage import Stage
@@ -27,6 +29,11 @@ def _dvclive_only_signal_file(root_dir: StrPath) -> str:
     return os.path.join(dvc_exps_run_dir, "DVCLIVE_ONLY")
 
 
+def _dvclive_step_completed_signal_file(root_dir: StrPath) -> str:
+    dvc_exps_run_dir = _dvc_exps_run_dir(root_dir)
+    return os.path.join(dvc_exps_run_dir, "DVCLIVE_STEP_COMPLETED")
+
+
 def _find_dvc_root(root: Optional[StrPath] = None) -> Optional[str]:
     if not root:
         root = os.getcwd()
@@ -46,6 +53,10 @@ def _find_dvc_root(root: Optional[StrPath] = None) -> Optional[str]:
     return None
 
 
+def _find_non_queue_root() -> Optional[str]:
+    return os.getenv(env.DVC_ROOT) or _find_dvc_root()
+
+
 def _write_file(file: str, contents: Dict[str, Union[str, int]]):
     import builtins
 
@@ -106,6 +117,39 @@ def make_dvcyaml(live) -> None:
     dump_yaml(dvcyaml, live.dvc_file)
 
 
+def mark_dvclive_step_completed(step: int) -> None:
+    """
+    https://github.com/iterative/vscode-dvc/issues/4528
+    Signal DVC VS Code extension that
+    a step has been completed for an experiment running in the queue
+    """
+    non_queue_root_dir = _find_non_queue_root()
+
+    if not non_queue_root_dir:
+        return
+
+    exp_run_dir = _dvc_exps_run_dir(non_queue_root_dir)
+    os.makedirs(exp_run_dir, exist_ok=True)
+
+    signal_file = _dvclive_step_completed_signal_file(non_queue_root_dir)
+
+    _write_file(signal_file, {"pid": os.getpid(), "step": step})
+
+
+def cleanup_dvclive_step_completed() -> None:
+    non_queue_root_dir = _find_non_queue_root()
+
+    if not non_queue_root_dir:
+        return
+
+    signal_file = _dvclive_step_completed_signal_file(non_queue_root_dir)
+
+    if not os.path.exists(signal_file):
+        return
+
+    os.remove(signal_file)
+
+
 def mark_dvclive_only_started(exp_name: str) -> None:
     """
     Signal DVC VS Code extension that
diff --git a/src/dvclive/env.py b/src/dvclive/env.py
index db29d1fc..4b26b651 100644
--- a/src/dvclive/env.py
+++ b/src/dvclive/env.py
@@ -4,3 +4,4 @@
 DVC_CHECKPOINT = "DVC_CHECKPOINT"
 DVC_EXP_BASELINE_REV = "DVC_EXP_BASELINE_REV"
 DVC_EXP_NAME = "DVC_EXP_NAME"
+DVC_ROOT = "DVC_ROOT"
diff --git a/src/dvclive/live.py b/src/dvclive/live.py
index 1f894eaf..670a199b 100644
--- a/src/dvclive/live.py
+++ b/src/dvclive/live.py
@@ -13,6 +13,7 @@
 
 from . import env
 from .dvc import (
+    cleanup_dvclive_step_completed,
     ensure_dir_is_tracked,
     find_overlapping_stage,
     get_dvc_repo,
@@ -20,6 +21,7 @@
     make_dvcyaml,
     mark_dvclive_only_ended,
     mark_dvclive_only_started,
+    mark_dvclive_step_completed,
 )
 from .error import (
     InvalidDataTypeError,
@@ -295,6 +297,7 @@ def next_step(self):
             self.make_dvcyaml()
 
         self.make_report()
+        mark_dvclive_step_completed(self.step)
         self.step += 1
 
     def log_metric(
@@ -570,6 +573,8 @@ def end(self):
         else:
             self.make_report()
 
+        cleanup_dvclive_step_completed()
+
     def read_step(self):
         if Path(self.metrics_file).exists():
             latest = self.read_latest()
diff --git a/tests/test_main.py b/tests/test_main.py
index f7d38fe4..3ee70653 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -400,6 +400,58 @@ def test_context_manager_skips_end_calls(tmp_dir):
     assert (tmp_dir / live.metrics_file).exists()
 
 
+@pytest.mark.vscode()
+@pytest.mark.parametrize("dvc_root", [True, False])
+def test_vscode_dvclive_step_completed_signal_file(
+    tmp_dir, dvc_root, mocker, monkeypatch
+):
+    signal_file = os.path.join(
+        tmp_dir, ".dvc", "tmp", "exps", "run", "DVCLIVE_STEP_COMPLETED"
+    )
+    cwd = tmp_dir
+    test_pid = 12345
+
+    if dvc_root:
+        cwd = tmp_dir / ".dvc" / "tmp" / "exps" / "asdasasf"
+        monkeypatch.setenv(env.DVC_ROOT, tmp_dir)
+        (cwd / ".dvc").mkdir(parents=True)
+
+    assert not os.path.exists(signal_file)
+
+    dvc_repo = mocker.MagicMock()
+    dvc_repo.index.stages = []
+    dvc_repo.config = {}
+    dvc_repo.scm.get_rev.return_value = "current_rev"
+    dvc_repo.scm.get_ref.return_value = None
+    dvc_repo.scm.no_commits = False
+    with mocker.patch("dvclive.live.get_dvc_repo", return_value=dvc_repo), mocker.patch(
+        "dvclive.live.os.getpid", return_value=test_pid
+    ):
+        dvclive = Live(save_dvc_exp=True)
+        assert not os.path.exists(signal_file)
+        dvclive.next_step()
+        assert dvclive.step == 1
+
+        if dvc_root:
+            assert os.path.exists(signal_file)
+            with open(signal_file, encoding="utf-8") as f:
+                assert json.load(f) == {"pid": test_pid, "step": 0}
+
+        else:
+            assert not os.path.exists(signal_file)
+
+        dvclive.next_step()
+        assert dvclive.step == 2
+
+        if dvc_root:
+            with open(signal_file, encoding="utf-8") as f:
+                assert json.load(f) == {"pid": test_pid, "step": 1}
+
+        dvclive.end()
+
+    assert not os.path.exists(signal_file)
+
+
 @pytest.mark.vscode()
 @pytest.mark.parametrize("dvc_root", [True, False])
 def test_vscode_dvclive_only_signal_file(tmp_dir, dvc_root, mocker):