From ef260eaf2e2c207edc26505f510e4f79b11ba6a4 Mon Sep 17 00:00:00 2001 From: Dave Berenbaum Date: Thu, 24 Aug 2023 10:31:56 -0400 Subject: [PATCH] exp show: show metrics that include separator (#9819) * exp show: show metrics that include separator * sort using partition instead of lpartition * handle ':' in file and metric names * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix sorting with leading : * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- dvc/repo/experiments/show.py | 21 +++++++++++++-------- tests/func/experiments/test_show.py | 14 ++++++++++++++ 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/dvc/repo/experiments/show.py b/dvc/repo/experiments/show.py index 155bc75b7a..04bb193a59 100644 --- a/dvc/repo/experiments/show.py +++ b/dvc/repo/experiments/show.py @@ -177,26 +177,31 @@ def _build_rows( ) -def _sort_column( +def _sort_column( # noqa: C901 sort_by: str, metric_names: Mapping[str, Iterable[str]], param_names: Mapping[str, Iterable[str]], ) -> Tuple[str, str, str]: - path, _, sort_name = sort_by.rpartition(":") + sep = ":" + parts = sort_by.split(sep) matches: Set[Tuple[str, str, str]] = set() - if path: + for split_num in range(len(parts)): + path = sep.join(parts[:split_num]) + sort_name = sep.join(parts[split_num:]) + if not path: # handles ':metric_name' case + sort_by = sort_name if path in metric_names and sort_name in metric_names[path]: matches.add((path, sort_name, "metrics")) if path in param_names and sort_name in param_names[path]: matches.add((path, sort_name, "params")) - else: + if not matches: for path in metric_names: - if sort_name in metric_names[path]: - matches.add((path, sort_name, "metrics")) + if sort_by in metric_names[path]: + matches.add((path, sort_by, "metrics")) for path in param_names: - if sort_name in param_names[path]: - matches.add((path, sort_name, "params")) + if sort_by in param_names[path]: + matches.add((path, sort_by, "params")) if len(matches) == 1: return matches.pop() diff --git a/tests/func/experiments/test_show.py b/tests/func/experiments/test_show.py index 4a3ffbf34c..6ecb481170 100644 --- a/tests/func/experiments/test_show.py +++ b/tests/func/experiments/test_show.py @@ -352,6 +352,20 @@ def test_show_sort(tmp_dir, scm, dvc, exp_stage, caplog): assert main(["exp", "show", "--no-pager", "--sort-by=metrics.yaml:foo"]) == 0 +def test_show_sort_metric_sep(tmp_dir, scm, dvc, caplog): + metrics_path = tmp_dir / "metrics:1.json" + metrics_path.write_text('{"my::metric": 1, "other_metric": 0.5}') + metrics_path = tmp_dir / "metrics:2.json" + metrics_path.write_text('{"my::metric": 2}') + dvcyaml_path = tmp_dir / "dvc.yaml" + dvcyaml_path.write_text("metrics: ['metrics:1.json', 'metrics:2.json']") + dvc.experiments.save() + assert ( + main(["exp", "show", "--no-pager", "--sort-by=metrics:1.json:my::metric"]) == 0 + ) + assert main(["exp", "show", "--no-pager", "--sort-by=:other_metric"]) == 0 + + @pytest.mark.vscode @pytest.mark.parametrize( "status, pid_exists",