Skip to content

Commit

Permalink
unify rel_path util func
Browse files Browse the repository at this point in the history
  • Loading branch information
dberenbaum committed Aug 28, 2023
1 parent 07a64b1 commit d72bbea
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 22 deletions.
18 changes: 7 additions & 11 deletions src/dvclive/dvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from dvclive.plots import Image, Metric
from dvclive.serialize import dump_yaml
from dvclive.utils import StrPath
from dvclive.utils import StrPath, rel_path

if TYPE_CHECKING:
from dvc.repo import Repo
Expand Down Expand Up @@ -74,35 +74,31 @@ def get_dvc_repo() -> Optional["Repo"]:
def make_dvcyaml(live) -> None: # noqa: C901
dvcyaml_dir = Path(live.dvc_file).parent.absolute().as_posix()

def _get_relpath(path):
path = Path(path).absolute().as_posix()
return os.path.relpath(path, dvcyaml_dir)

dvcyaml = {}
if live._params:
dvcyaml["params"] = [_get_relpath(live.params_file)]
dvcyaml["params"] = [rel_path(live.params_file, dvcyaml_dir)]
if live._metrics or live.summary:
dvcyaml["metrics"] = [_get_relpath(live.metrics_file)]
dvcyaml["metrics"] = [rel_path(live.metrics_file, dvcyaml_dir)]
plots: List[Any] = []
plots_path = Path(live.plots_dir)
plots_metrics_path = plots_path / Metric.subfolder
if plots_metrics_path.exists():
metrics_config = {_get_relpath(plots_metrics_path): {"x": "step"}}
metrics_config = {rel_path(plots_metrics_path, dvcyaml_dir): {"x": "step"}}
plots.append(metrics_config)
if live._images:
images_path = _get_relpath(plots_path / Image.subfolder)
images_path = rel_path(plots_path / Image.subfolder, dvcyaml_dir)
plots.append(images_path)
if live._plots:
for plot in live._plots.values():
plot_path = _get_relpath(plot.output_path)
plot_path = rel_path(plot.output_path, dvcyaml_dir)
plots.append({plot_path: plot.plot_config})
if plots:
dvcyaml["plots"] = plots

if live._artifacts:
dvcyaml["artifacts"] = copy.deepcopy(live._artifacts)
for artifact in dvcyaml["artifacts"].values(): # type: ignore
artifact["path"] = _get_relpath(artifact["path"])
artifact["path"] = rel_path(artifact["path"], dvcyaml_dir)

if not os.path.exists(live.dvc_file):
dump_yaml(dvcyaml, live.dvc_file)
Expand Down
16 changes: 5 additions & 11 deletions src/dvclive/studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
import base64
import math
import os
from pathlib import Path

from dvc_studio_client.post_live_metrics import get_studio_config

from dvclive.serialize import load_yaml
from dvclive.utils import parse_metrics
from dvclive.utils import parse_metrics, rel_path


def _get_unsent_datapoints(plot, latest_step):
Expand All @@ -30,18 +29,13 @@ def _cast_to_numbers(datapoints):
return datapoints


def _rel_path(path, dvc_root_path):
absolute_path = Path(path).resolve()
return str(absolute_path.relative_to(dvc_root_path).as_posix())


def _adapt_plot_name(live, name):
if live._dvc_repo is not None:
name = _rel_path(name, live._dvc_repo.root_dir)
name = rel_path(name, live._dvc_repo.root_dir)
if os.path.isfile(live.dvc_file):
dvc_file = live.dvc_file
if live._dvc_repo is not None:
dvc_file = _rel_path(live.dvc_file, live._dvc_repo.root_dir)
dvc_file = rel_path(live.dvc_file, live._dvc_repo.root_dir)
name = f"{dvc_file}::{name}"
return name

Expand Down Expand Up @@ -70,7 +64,7 @@ def get_studio_updates(live):
if os.path.isfile(live.params_file):
params_file = live.params_file
if live._dvc_repo is not None:
params_file = _rel_path(params_file, live._dvc_repo.root_dir)
params_file = rel_path(params_file, live._dvc_repo.root_dir)
params = {params_file: load_yaml(live.params_file)}
else:
params = {}
Expand All @@ -79,7 +73,7 @@ def get_studio_updates(live):

metrics_file = live.metrics_file
if live._dvc_repo is not None:
metrics_file = _rel_path(metrics_file, live._dvc_repo.root_dir)
metrics_file = rel_path(metrics_file, live._dvc_repo.root_dir)
metrics = {metrics_file: {"data": metrics}}

plots = {
Expand Down
5 changes: 5 additions & 0 deletions src/dvclive/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,8 @@ def wrapper(*args, **kwargs):
return wrapper

return decorator


def rel_path(path, dvc_root_path):
absolute_path = Path(path).absolute()
return str(Path(os.path.relpath(absolute_path, dvc_root_path)).as_posix())

0 comments on commit d72bbea

Please sign in to comment.