From ba66faab4fae87451f409fdeb24532a02c981e6c Mon Sep 17 00:00:00 2001 From: AbdullahMakhdoom Date: Mon, 9 Sep 2024 04:26:34 +0500 Subject: [PATCH 1/2] refactor: Update `Live` class to handle `pathlib.Path` object for `dvcyaml` argument. --- src/dvclive/live.py | 25 +++++++++++++++++-------- tests/test_make_dvcyaml.py | 3 ++- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/src/dvclive/live.py b/src/dvclive/live.py index c0b4aa81..239f16cc 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -82,7 +82,7 @@ def __init__( resume: bool = False, report: Literal["md", "notebook", "html", None] = None, save_dvc_exp: bool = True, - dvcyaml: Optional[str] = "dvc.yaml", + dvcyaml: Union[str, os.PathLike, bool, None] = "dvc.yaml", cache_images: bool = False, exp_name: Optional[str] = None, exp_message: Optional[str] = None, @@ -104,11 +104,11 @@ def __init__( part of `Live.end()`. Defaults to `True`. If you are using DVCLive inside a DVC Pipeline and running with `dvc exp run`, the option will be ignored. - dvcyaml (str | None): where to write dvc.yaml file, which adds DVC + dvcyaml (str | Path | None): where to write dvc.yaml file, which adds DVC configuration for metrics, plots, and parameters as part of `Live.next_step()` and `Live.end()`. If `None`, no dvc.yaml file is written. Defaults to `"dvc.yaml"`. See `Live.make_dvcyaml()`. - If a string like `"subdir/dvc.yaml"`, DVCLive will write the + If a string or Path like `"subdir/dvc.yaml"`, DVCLive will write the configuration to that path (file must be named "dvc.yaml"). If `False`, DVCLive will not write to "dvc.yaml" (useful if you are tracking DVCLive metrics, plots, and parameters independently and @@ -265,11 +265,19 @@ def _init_dvc(self): # noqa: C901 self._include_untracked.append(self.dir) def _init_dvc_file(self) -> str: - if isinstance(self._dvcyaml, str): - if os.path.basename(self._dvcyaml) == "dvc.yaml": - return self._dvcyaml - raise InvalidDvcyamlError - return "dvc.yaml" + if self._dvcyaml is None: + return "dvc.yaml" + if isinstance(self._dvcyaml, bool): + return "dvc.yaml" + + self._dvcyaml = os.fspath(self._dvcyaml) + if ( + isinstance(self._dvcyaml, str) + and os.path.basename(self._dvcyaml) == "dvc.yaml" + ): + return self._dvcyaml + + raise InvalidDvcyamlError def _init_dvc_pipeline(self): if os.getenv(env.DVC_EXP_BASELINE_REV, None): @@ -334,6 +342,7 @@ def _init_test(self): """ with tempfile.TemporaryDirectory() as dirpath: self._dir = os.path.join(dirpath, self._dir) + self._dvcyaml = os.fspath(self._dvcyaml) if isinstance(self._dvcyaml, str): self._dvc_file = os.path.join(dirpath, self._dvcyaml) self._save_dvc_exp = False diff --git a/tests/test_make_dvcyaml.py b/tests/test_make_dvcyaml.py index 7f5da8bf..16397851 100644 --- a/tests/test_make_dvcyaml.py +++ b/tests/test_make_dvcyaml.py @@ -2,6 +2,7 @@ import pytest from PIL import Image +from pathlib import Path from dvclive import Live from dvclive.dvc import make_dvcyaml @@ -423,7 +424,7 @@ def test_warn_on_dvcyaml_output_overlap(tmp_dir, mocker, mocked_dvc_repo, dvcyam @pytest.mark.parametrize( "dvcyaml", - [True, False, "dvc.yaml"], + [True, False, "dvc.yaml", Path("dvc.yaml")], ) def test_make_dvcyaml(tmp_dir, mocked_dvc_repo, dvcyaml): dvclive = Live("logs", dvcyaml=dvcyaml) From 838d9d97a7d93b6a2dd2a6417ffd76fde6cb817d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 12 Sep 2024 18:27:26 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/dvclive/live.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/dvclive/live.py b/src/dvclive/live.py index 239f16cc..754137f3 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -82,7 +82,7 @@ def __init__( resume: bool = False, report: Literal["md", "notebook", "html", None] = None, save_dvc_exp: bool = True, - dvcyaml: Union[str, os.PathLike, bool, None] = "dvc.yaml", + dvcyaml: Union[str, os.PathLike, bool, None] = "dvc.yaml", cache_images: bool = False, exp_name: Optional[str] = None, exp_message: Optional[str] = None, @@ -269,14 +269,14 @@ def _init_dvc_file(self) -> str: return "dvc.yaml" if isinstance(self._dvcyaml, bool): return "dvc.yaml" - + self._dvcyaml = os.fspath(self._dvcyaml) if ( isinstance(self._dvcyaml, str) and os.path.basename(self._dvcyaml) == "dvc.yaml" ): return self._dvcyaml - + raise InvalidDvcyamlError def _init_dvc_pipeline(self):