Skip to content

Commit

Permalink
exp: Unify dvc exp list and show output (#9808)
Browse files Browse the repository at this point in the history
* use _describe in exp list

* fix pre-commit checks

* simplistic fix

* single _describe, too complex

* slightly simpler

* even simpler

* simple, but with regex

* less substitutions

* describe public

* name handling in UI (tests need updating)

* fixed old tests

* lazy imports

---------

Co-authored-by: Tibor Mach <[email protected]>
  • Loading branch information
tibor-mach and Tibor Mach authored Aug 10, 2023
1 parent f5dda2d commit 04e891c
Show file tree
Hide file tree
Showing 8 changed files with 97 additions and 86 deletions.
21 changes: 14 additions & 7 deletions dvc/commands/experiments/ls.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,22 @@ def run(self):
git_remote=git_remote,
)

from dvc.repo.experiments.utils import describe
from dvc.scm import Git

if name_only or sha_only:
names = {}
else:
assert isinstance(self.repo.scm, Git)
names = describe(
self.repo.scm,
(baseline for baseline in exps),
logger=logger,
)

for baseline in exps:
if not (name_only or sha_only):
tag_base = "refs/tags/"
branch_base = "refs/heads/"
name = baseline[:7]
if baseline.startswith(tag_base):
name = baseline[len(tag_base) :]
elif baseline.startswith(branch_base):
name = baseline[len(branch_base) :]
name = names.get(baseline) or baseline[:7]
ui.write(f"{name}:")
for exp_name, rev in exps[baseline]:
if name_only:
Expand Down
54 changes: 4 additions & 50 deletions dvc/repo/experiments/collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .exceptions import InvalidExpRefError
from .refs import EXEC_BRANCH, ExpRefInfo
from .serialize import ExpRange, ExpState, SerializableError, SerializableExp
from .utils import describe

if TYPE_CHECKING:
from dvc.repo import Repo
Expand Down Expand Up @@ -308,7 +309,9 @@ def collect(
if sha_only:
baseline_names: Dict[str, Optional[str]] = {}
else:
baseline_names = _describe(repo.scm, baseline_revs, refs=cached_refs)
baseline_names = describe(
repo.scm, baseline_revs, refs=cached_refs, logger=logger
)

workspace_data = collect_rev(repo, "workspace", **kwargs)
result: List["ExpState"] = [workspace_data]
Expand Down Expand Up @@ -337,55 +340,6 @@ def collect(
return result


def _describe(
scm: "Git",
revs: Iterable[str],
refs: Optional[Iterable[str]] = None,
) -> Dict[str, Optional[str]]:
"""Describe revisions using a tag, branch.
The first matching name will be returned for each rev. Names are preferred in this
order:
- current branch (if rev matches HEAD and HEAD is a branch)
- tags
- branches
Returns:
Dict mapping revisions from revs to a name.
"""

head_rev = scm.get_rev()
head_ref = scm.get_ref("HEAD", follow=False)
if head_ref and head_ref.startswith("refs/heads/"):
head_branch = head_ref[len("refs/heads/") :]
else:
head_branch = None

tags = {}
branches = {}
ref_it = iter(refs) if refs else scm.iter_refs()
for ref in ref_it:
is_tag = ref.startswith("refs/tags/")
is_branch = ref.startswith("refs/heads/")
if not (is_tag or is_branch):
continue
rev = scm.get_ref(ref)
if not rev:
logger.debug("unresolved ref %s", ref)
continue
if is_tag and rev not in tags:
tags[rev] = ref[len("refs/tags/") :]
if is_branch and rev not in branches:
branches[rev] = ref[len("refs/heads/") :]
names: Dict[str, Optional[str]] = {}
for rev in revs:
if rev == head_rev and head_branch:
names[rev] = head_branch
else:
names[rev] = tags.get(rev) or branches.get(rev)
return names


def _sorted_ranges(exp_ranges: Iterable["ExpRange"]) -> List["ExpRange"]:
"""Return list of ExpRange sorted by (timestamp, rev)."""

Expand Down
12 changes: 2 additions & 10 deletions dvc/repo/experiments/ls.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,15 @@ def ls(
rev = [rev]
revs = iter_revs(repo.scm, rev, num)
rev_set = set(revs.keys())
ref_info_dict = exp_refs_by_baseline(repo.scm, rev_set, git_remote)

tags = repo.scm.describe(ref_info_dict.keys())
remained = {baseline for baseline, tag in tags.items() if tag is None}
base = "refs/heads"
ref_heads = repo.scm.describe(remained, base=base)

ref_info_dict = exp_refs_by_baseline(repo.scm, rev_set, git_remote)
results = defaultdict(list)
for baseline in ref_info_dict:
name = baseline
if tags[baseline] or ref_heads[baseline]:
name = tags[baseline] or ref_heads[baseline]
for info in ref_info_dict[baseline]:
if git_remote:
exp_rev = None
else:
exp_rev = repo.scm.get_ref(str(info))
results[name].append((info.name, exp_rev))
results[baseline].append((info.name, exp_rev))

return results
52 changes: 52 additions & 0 deletions dvc/repo/experiments/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,3 +346,55 @@ def to_studio_params(dvc_params):
result[file_name] = file_data.get("data", {})

return result


def describe(
    scm: "Git",
    revs: Iterable[str],
    logger,
    refs: Optional[Iterable[str]] = None,
) -> Dict[str, Optional[str]]:
    """Describe revisions using a tag or branch name.

    The first matching name is used for each rev. Names are preferred in this
    order:
        - current branch (if rev matches HEAD and HEAD is a branch)
        - tags
        - branches

    Args:
        scm: Git SCM instance used to resolve HEAD and the individual refs.
        revs: Revisions to describe.
        logger: Logger that receives a debug message for every tag or branch
            ref that does not resolve to a revision.
        refs: Ref names to consider. When None, the refs yielded by
            ``scm.iter_refs()`` are used. An explicitly supplied empty
            iterable means "no refs" and does not trigger a rescan.

    Returns:
        Dict mapping each rev from ``revs`` to a name, or to None when no tag
        or branch resolves to that rev.
    """
    tag_prefix = "refs/tags/"
    branch_prefix = "refs/heads/"

    head_rev = scm.get_rev()
    # follow=False is passed so a symbolic HEAD comes back as the branch ref
    # it points at; anything not under refs/heads/ is treated as "no branch".
    head_ref = scm.get_ref("HEAD", follow=False)
    head_branch = None
    if head_ref and head_ref.startswith(branch_prefix):
        head_branch = head_ref[len(branch_prefix) :]

    tags: Dict[str, str] = {}
    branches: Dict[str, str] = {}
    # Compare against None rather than relying on truthiness: an empty list
    # of refs is a valid (cached) answer and must not cause another scan.
    ref_it = scm.iter_refs() if refs is None else iter(refs)
    for ref in ref_it:
        if ref.startswith(tag_prefix):
            bucket, short_name = tags, ref[len(tag_prefix) :]
        elif ref.startswith(branch_prefix):
            bucket, short_name = branches, ref[len(branch_prefix) :]
        else:
            continue
        rev = scm.get_ref(ref)
        if not rev:
            logger.debug("unresolved ref %s", ref)
            continue
        # First ref seen for a given rev wins.
        bucket.setdefault(rev, short_name)

    names: Dict[str, Optional[str]] = {}
    for rev in revs:
        if rev == head_rev and head_branch:
            names[rev] = head_branch
        else:
            names[rev] = tags.get(rev) or branches.get(rev)

    return names
25 changes: 14 additions & 11 deletions tests/func/experiments/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,8 @@ def test_packed_args_exists(tmp_dir, scm, dvc, exp_stage, caplog):


def test_list(tmp_dir, scm, dvc, exp_stage):
baseline_a = scm.get_rev()
baseline_old = scm.get_rev()

results = dvc.experiments.run(exp_stage.addressing, params=["foo=2"])
exp_a = first(results)
ref_info_a = first(exp_refs_by_rev(scm, exp_a))
Expand All @@ -358,34 +359,36 @@ def test_list(tmp_dir, scm, dvc, exp_stage):
ref_info_b = first(exp_refs_by_rev(scm, exp_b))

tmp_dir.scm_gen("new", "new", commit="new")
baseline_new = scm.get_rev()

results = dvc.experiments.run(exp_stage.addressing, params=["foo=4"])
exp_c = first(results)
ref_info_c = first(exp_refs_by_rev(scm, exp_c))

assert dvc.experiments.ls() == {"refs/heads/master": [(ref_info_c.name, exp_c)]}
assert dvc.experiments.ls() == {baseline_new: [(ref_info_c.name, exp_c)]}

exp_list = dvc.experiments.ls(rev=ref_info_a.baseline_sha)
assert {key: set(val) for key, val in exp_list.items()} == {
baseline_a: {(ref_info_a.name, exp_a), (ref_info_b.name, exp_b)}
baseline_old: {(ref_info_a.name, exp_a), (ref_info_b.name, exp_b)}
}

exp_list = dvc.experiments.ls(rev=[baseline_a, scm.get_rev()])
exp_list = dvc.experiments.ls(rev=[baseline_old, baseline_new])
assert {key: set(val) for key, val in exp_list.items()} == {
baseline_a: {(ref_info_a.name, exp_a), (ref_info_b.name, exp_b)},
"refs/heads/master": {(ref_info_c.name, exp_c)},
baseline_old: {(ref_info_a.name, exp_a), (ref_info_b.name, exp_b)},
baseline_new: {(ref_info_c.name, exp_c)},
}

exp_list = dvc.experiments.ls(all_commits=True)
assert {key: set(val) for key, val in exp_list.items()} == {
baseline_a: {(ref_info_a.name, exp_a), (ref_info_b.name, exp_b)},
"refs/heads/master": {(ref_info_c.name, exp_c)},
baseline_old: {(ref_info_a.name, exp_a), (ref_info_b.name, exp_b)},
baseline_new: {(ref_info_c.name, exp_c)},
}

scm.checkout("branch", True)
exp_list = dvc.experiments.ls(all_commits=True)
assert {key: set(val) for key, val in exp_list.items()} == {
baseline_a: {(ref_info_a.name, exp_a), (ref_info_b.name, exp_b)},
"refs/heads/branch": {(ref_info_c.name, exp_c)},
baseline_old: {(ref_info_a.name, exp_a), (ref_info_b.name, exp_b)},
baseline_new: {(ref_info_c.name, exp_c)},
}


Expand Down Expand Up @@ -666,7 +669,7 @@ def test_experiment_unchanged(tmp_dir, scm, dvc, exp_stage):
dvc.experiments.run(exp_stage.addressing)
dvc.experiments.run(exp_stage.addressing)

assert len(dvc.experiments.ls()["refs/heads/master"]) == 2
assert len(dvc.experiments.ls()[scm.get_rev()]) == 2


def test_experiment_run_dry(tmp_dir, scm, dvc, exp_stage):
Expand Down
12 changes: 7 additions & 5 deletions tests/func/experiments/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def test_push_ambiguous_name(tmp_dir, scm, dvc, git_upstream, exp_stage):

@pytest.mark.parametrize("use_url", [True, False])
def test_list_remote(tmp_dir, scm, dvc, git_downstream, exp_stage, use_url):
baseline_a = scm.get_rev()
baseline_old = scm.get_rev()
results = dvc.experiments.run(exp_stage.addressing, params=["foo=2"])
exp_a = first(results)
ref_info_a = first(exp_refs_by_rev(scm, exp_a))
Expand All @@ -149,6 +149,8 @@ def test_list_remote(tmp_dir, scm, dvc, git_downstream, exp_stage, use_url):
ref_info_b = first(exp_refs_by_rev(scm, exp_b))

tmp_dir.scm_gen("new", "new", commit="new")
baseline_new = scm.get_rev()

results = dvc.experiments.run(exp_stage.addressing, params=["foo=4"])
exp_c = first(results)
ref_info_c = first(exp_refs_by_rev(scm, exp_c))
Expand All @@ -160,15 +162,15 @@ def test_list_remote(tmp_dir, scm, dvc, git_downstream, exp_stage, use_url):
assert downstream_exp.ls(git_remote=remote) == {}

git_downstream.tmp_dir.scm.fetch_refspecs(remote, ["master:master"])
exp_list = downstream_exp.ls(rev=baseline_a, git_remote=remote)
exp_list = downstream_exp.ls(rev=baseline_old, git_remote=remote)
assert {key: set(val) for key, val in exp_list.items()} == {
baseline_a: {(ref_info_a.name, None), (ref_info_b.name, None)}
baseline_old: {(ref_info_a.name, None), (ref_info_b.name, None)}
}

exp_list = downstream_exp.ls(all_commits=True, git_remote=remote)
assert {key: set(val) for key, val in exp_list.items()} == {
baseline_a: {(ref_info_a.name, None), (ref_info_b.name, None)},
"refs/heads/master": {(ref_info_c.name, None)},
baseline_old: {(ref_info_a.name, None), (ref_info_b.name, None)},
baseline_new: {(ref_info_c.name, None)},
}


Expand Down
3 changes: 2 additions & 1 deletion tests/func/experiments/test_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,12 @@ def test_exp_save_after_commit(tmp_dir, dvc, scm):
dvc.experiments.save(name="exp-1", force=True)

tmp_dir.scm_gen({"new_file": "new_file"}, commit="new baseline")
baseline_new = scm.get_rev()
dvc.experiments.save(name="exp-2", force=True)

all_exps = dvc.experiments.ls(all_commits=True)
assert all_exps[baseline][0][0] == "exp-1"
assert all_exps["refs/heads/master"][0][0] == "exp-2"
assert all_exps[baseline_new][0][0] == "exp-2"


def test_exp_save_with_staged_changes(tmp_dir, dvc, scm):
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/command/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def test_experiments_list(dvc, scm, mocker):
@pytest.mark.parametrize(
"args,expected",
[
([], "main:\n\tsha-a [exp-a]\n"),
([], "master:\n\tsha-a [exp-a]\n"),
(["--name-only"], "exp-a\n"),
(["--sha-only"], "sha-a\n"),
],
Expand All @@ -198,7 +198,7 @@ def test_experiments_list_format(mocker, capsys, args, expected, dvc, scm):
mocker.patch(
"dvc.repo.experiments.ls.ls",
return_value={
"refs/heads/main": [
scm.get_rev(): [
("exp-a", "sha-a"),
]
},
Expand Down

0 comments on commit 04e891c

Please sign in to comment.