Skip to content

Commit

Permalink
Add step completed signal file for VS Code (#688)
Browse files Browse the repository at this point in the history
* Add step completed signal file for VS Code

* Use new DVC environment variable

* Write file if no env variable present

* Bump `dvc>=3.17.0`.

---------

Co-authored-by: daavoo <[email protected]>
  • Loading branch information
mattseddon and daavoo authored Aug 29, 2023
1 parent 914a7e6 commit bf38034
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ classifiers = [
]
dynamic = ["version"]
dependencies = [
"dvc>=2.58.0",
"dvc>=3.17.0",
"dvc-render>=0.5.0,<1.0",
"dvc-studio-client>=0.10.0,<1",
"funcy",
Expand Down
44 changes: 44 additions & 0 deletions src/dvclive/dvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from dvclive.serialize import dump_yaml
from dvclive.utils import StrPath

from . import env

if TYPE_CHECKING:
from dvc.repo import Repo
from dvc.stage import Stage
Expand All @@ -27,6 +29,11 @@ def _dvclive_only_signal_file(root_dir: StrPath) -> str:
return os.path.join(dvc_exps_run_dir, "DVCLIVE_ONLY")


def _dvclive_step_completed_signal_file(root_dir: StrPath) -> str:
dvc_exps_run_dir = _dvc_exps_run_dir(root_dir)
return os.path.join(dvc_exps_run_dir, "DVCLIVE_STEP_COMPLETED")


def _find_dvc_root(root: Optional[StrPath] = None) -> Optional[str]:
if not root:
root = os.getcwd()
Expand All @@ -46,6 +53,10 @@ def _find_dvc_root(root: Optional[StrPath] = None) -> Optional[str]:
return None


def _find_non_queue_root() -> Optional[str]:
return os.getenv(env.DVC_ROOT) or _find_dvc_root()


def _write_file(file: str, contents: Dict[str, Union[str, int]]):
import builtins

Expand Down Expand Up @@ -106,6 +117,39 @@ def make_dvcyaml(live) -> None:
dump_yaml(dvcyaml, live.dvc_file)


def mark_dvclive_step_completed(step: int) -> None:
"""
https://github.com/iterative/vscode-dvc/issues/4528
Signal DVC VS Code extension that
a step has been completed for an experiment running in the queue
"""
non_queue_root_dir = _find_non_queue_root()

if not non_queue_root_dir:
return

exp_run_dir = _dvc_exps_run_dir(non_queue_root_dir)
os.makedirs(exp_run_dir, exist_ok=True)

signal_file = _dvclive_step_completed_signal_file(non_queue_root_dir)

_write_file(signal_file, {"pid": os.getpid(), "step": step})


def cleanup_dvclive_step_completed() -> None:
non_queue_root_dir = _find_non_queue_root()

if not non_queue_root_dir:
return

signal_file = _dvclive_step_completed_signal_file(non_queue_root_dir)

if not os.path.exists(signal_file):
return

os.remove(signal_file)


def mark_dvclive_only_started(exp_name: str) -> None:
"""
Signal DVC VS Code extension that
Expand Down
1 change: 1 addition & 0 deletions src/dvclive/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
DVC_CHECKPOINT = "DVC_CHECKPOINT"
DVC_EXP_BASELINE_REV = "DVC_EXP_BASELINE_REV"
DVC_EXP_NAME = "DVC_EXP_NAME"
DVC_ROOT = "DVC_ROOT"
5 changes: 5 additions & 0 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@

from . import env
from .dvc import (
cleanup_dvclive_step_completed,
ensure_dir_is_tracked,
find_overlapping_stage,
get_dvc_repo,
get_random_exp_name,
make_dvcyaml,
mark_dvclive_only_ended,
mark_dvclive_only_started,
mark_dvclive_step_completed,
)
from .error import (
InvalidDataTypeError,
Expand Down Expand Up @@ -295,6 +297,7 @@ def next_step(self):
self.make_dvcyaml()

self.make_report()
mark_dvclive_step_completed(self.step)
self.step += 1

def log_metric(
Expand Down Expand Up @@ -570,6 +573,8 @@ def end(self):
else:
self.make_report()

cleanup_dvclive_step_completed()

def read_step(self):
if Path(self.metrics_file).exists():
latest = self.read_latest()
Expand Down
52 changes: 52 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,58 @@ def test_context_manager_skips_end_calls(tmp_dir):
assert (tmp_dir / live.metrics_file).exists()


@pytest.mark.vscode()
@pytest.mark.parametrize("dvc_root", [True, False])
def test_vscode_dvclive_step_completed_signal_file(
tmp_dir, dvc_root, mocker, monkeypatch
):
signal_file = os.path.join(
tmp_dir, ".dvc", "tmp", "exps", "run", "DVCLIVE_STEP_COMPLETED"
)
cwd = tmp_dir
test_pid = 12345

if dvc_root:
cwd = tmp_dir / ".dvc" / "tmp" / "exps" / "asdasasf"
monkeypatch.setenv(env.DVC_ROOT, tmp_dir)
(cwd / ".dvc").mkdir(parents=True)

assert not os.path.exists(signal_file)

dvc_repo = mocker.MagicMock()
dvc_repo.index.stages = []
dvc_repo.config = {}
dvc_repo.scm.get_rev.return_value = "current_rev"
dvc_repo.scm.get_ref.return_value = None
dvc_repo.scm.no_commits = False
with mocker.patch("dvclive.live.get_dvc_repo", return_value=dvc_repo), mocker.patch(
"dvclive.live.os.getpid", return_value=test_pid
):
dvclive = Live(save_dvc_exp=True)
assert not os.path.exists(signal_file)
dvclive.next_step()
assert dvclive.step == 1

if dvc_root:
assert os.path.exists(signal_file)
with open(signal_file, encoding="utf-8") as f:
assert json.load(f) == {"pid": test_pid, "step": 0}

else:
assert not os.path.exists(signal_file)

dvclive.next_step()
assert dvclive.step == 2

if dvc_root:
with open(signal_file, encoding="utf-8") as f:
assert json.load(f) == {"pid": test_pid, "step": 1}

dvclive.end()

assert not os.path.exists(signal_file)


@pytest.mark.vscode()
@pytest.mark.parametrize("dvc_root", [True, False])
def test_vscode_dvclive_only_signal_file(tmp_dir, dvc_root, mocker):
Expand Down

0 comments on commit bf38034

Please sign in to comment.