Skip to content

Commit

Permalink
artifacts: add support for monorepo (#10386)
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry authored Apr 16, 2024
1 parent c52a25e commit 8517ccc
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 15 deletions.
15 changes: 13 additions & 2 deletions dvc/api/artifacts.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from typing import Any, Optional

from dvc.repo import Repo
Expand Down Expand Up @@ -36,12 +37,22 @@ def artifacts_show(
if version and stage:
raise ValueError("Artifact version and stage are mutually exclusive.")

from dvc.repo.artifacts import Artifacts
from dvc.utils import as_posix

repo_kwargs: dict[str, Any] = {
"subrepos": True,
"uninitialized": True,
}

dirname, _ = Artifacts.parse_path(name)
with Repo.open(repo, **repo_kwargs) as _repo:
rev = _repo.artifacts.get_rev(name, version=version, stage=stage)
with _repo.switch(rev):
path = _repo.artifacts.get_path(name)
return {"rev": rev, "path": path}
root = _repo.fs.root_marker
_dirname = _repo.fs.join(root, dirname) if dirname else root
with Repo(_dirname, fs=_repo.fs, scm=_repo.scm) as r:
path = r.artifacts.get_path(name)
path = _repo.fs.join(_repo.fs.root_marker, as_posix(path))
parts = _repo.fs.relparts(path, _repo.root_dir)
return {"rev": rev, "path": os.path.join(*parts)}
59 changes: 48 additions & 11 deletions dvc/repo/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
InvalidArgumentError,
)
from dvc.log import logger
from dvc.utils import relpath, resolve_output
from dvc.utils import as_posix, relpath, resolve_output
from dvc.utils.objects import cached_property
from dvc.utils.serialize import modify_yaml

Expand Down Expand Up @@ -99,7 +99,7 @@ def read(self) -> dict[str, dict[str, Artifact]]:
"""Read artifacts from dvc.yaml."""
artifacts: dict[str, dict[str, Artifact]] = {}
for dvcfile, dvcfile_artifacts in self.repo.index._artifacts.items():
dvcyaml = relpath(dvcfile, self.repo.root_dir)
dvcyaml = self.repo.fs.relpath(dvcfile, self.repo.root_dir)
artifacts[dvcyaml] = {}
for name, value in dvcfile_artifacts.items():
try:
Expand Down Expand Up @@ -147,8 +147,8 @@ def get_rev(
gto_tags: list["GTOTag"] = sort_versions(parse_tag(tag) for tag in tags)
return gto_tags[0].tag.target

def get_path(self, name: str):
"""Return repo fspath for the given artifact."""
@classmethod
def parse_path(cls, name: str) -> tuple[Optional[str], str]:
from gto.constants import SEPARATOR_IN_NAME, fullname_re

name = _reformat_name(name)
Expand All @@ -158,13 +158,37 @@ def get_path(self, name: str):
dirname = m.group("dirname")
if dirname:
dirname = dirname.rstrip(SEPARATOR_IN_NAME)
dvcyaml = os.path.join(dirname, PROJECT_FILE) if dirname else PROJECT_FILE
artifact_name = m.group("name")

return dirname, m.group("name")

def get_path(self, name: str):
"""Return fspath for the given artifact relative to the git root."""
from dvc.fs import GitFileSystem

dirname, artifact_name = self.parse_path(name)
# `name`/`dirname` are expected to be a git root relative.
# We convert it to dvc-root relative path so that we can read artifacts
# from dvc.yaml file.
# But we return dirname intact, as we want to return a git-root relative path.
# This is useful when reading from `dvcfs` from remote.
fs = self.repo.fs
assert self.scm
if isinstance(fs, GitFileSystem):
scm_root = fs.root_marker
else:
scm_root = self.scm.root_dir

dirparts = posixpath.normpath(dirname).split(posixpath.sep) if dirname else ()
abspath = fs.join(scm_root, *dirparts, PROJECT_FILE)
rela = fs.relpath(abspath, self.repo.root_dir)
try:
artifact = self.read()[dvcyaml][artifact_name]
artifact = self.read()[rela][artifact_name]
except KeyError as exc:
raise ArtifactNotFoundError(name) from exc
return os.path.join(dirname, artifact.path) if dirname else artifact.path

path = posixpath.join(dirname or "", artifact.path)
parts = posixpath.normpath(path).split(posixpath.sep)
return os.path.join(*parts)

def download(
self,
Expand All @@ -177,15 +201,28 @@ def download(
) -> tuple[int, str]:
"""Download the specified artifact."""
from dvc.fs import download as fs_download
from dvc.repo import Repo

logger.debug("Trying to download artifact '%s' via DVC", name)
rev = self.get_rev(name, version=version, stage=stage)

dirname, _ = self.parse_path(name)
with self.repo.switch(rev):
path = self.get_path(name)
root = self.repo.fs.root_marker
_dirname = self.repo.fs.join(root, dirname) if dirname else root
with Repo(_dirname, fs=self.repo.fs, scm=self.repo.scm) as r:
path = r.artifacts.get_path(name)
path = self.repo.fs.join(root, as_posix(path))
path = self.repo.fs.relpath(path, self.repo.root_dir)
# when the `repo` is a subrepo, the path `/subrepo/myart.pkl` for dvcfs
# should be translated as `/myart.pkl`,
# i.e. relative to the root of the subrepo
path = self.repo.fs.join(root, path)
path = self.repo.fs.normpath(path)

out = resolve_output(path, out, force=force)
fs = self.repo.dvcfs
fs_path = fs.from_os_path(path)
count = fs_download(fs, fs_path, os.path.abspath(out), jobs=jobs)
count = fs_download(fs, path, os.path.abspath(out), jobs=jobs)
return count, out

@staticmethod
Expand Down
74 changes: 74 additions & 0 deletions tests/func/api/test_artifacts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from os.path import join, normpath

import pytest

from dvc.api import artifacts_show
from dvc.testing.tmp_dir import make_subrepo
from dvc.utils import as_posix
from tests.func.artifacts.test_artifacts import get_tag_and_name, make_artifact


@pytest.mark.parametrize("sub", ["sub", ""])
def test_artifacts_show(tmp_dir, dvc, scm, sub):
subdir = tmp_dir / sub

dirname = str(subdir.relative_to(tmp_dir))
tag, name = get_tag_and_name(as_posix(dirname), "myart", "v2.0.0")
make_artifact(tmp_dir, "myart", tag, subdir / "myart.pkl")

assert artifacts_show(name) == {
"path": normpath(join(dirname, "myart.pkl")),
"rev": scm.get_rev(),
}
assert artifacts_show(name, repo=tmp_dir.fs_path) == {
"path": normpath(join(dirname, "myart.pkl")),
"rev": scm.get_rev(),
}
assert artifacts_show(name, repo=f"file://{tmp_dir.as_posix()}") == {
"path": normpath(join(dirname, "myart.pkl")),
"rev": scm.get_rev(),
}

assert artifacts_show(name, repo=subdir.fs_path) == {
"path": normpath(join(dirname, "myart.pkl")),
"rev": scm.get_rev(),
}
with subdir.chdir():
assert artifacts_show(name) == {
"path": normpath(join(dirname, "myart.pkl")),
"rev": scm.get_rev(),
}


@pytest.mark.parametrize("sub", ["sub", ""])
def test_artifacts_show_subrepo(tmp_dir, scm, sub):
subrepo = tmp_dir / "subrepo"
make_subrepo(subrepo, scm)
subdir = subrepo / sub

dirname = str(subdir.relative_to(tmp_dir))
tag, name = get_tag_and_name(as_posix(dirname), "myart", "v2.0.0")
make_artifact(subrepo, "myart", tag, subdir / "myart.pkl")

assert artifacts_show(name) == {
"path": join(dirname, "myart.pkl"),
"rev": scm.get_rev(),
}
assert artifacts_show(name, repo=tmp_dir.fs_path) == {
"path": join(dirname, "myart.pkl"),
"rev": scm.get_rev(),
}
assert artifacts_show(name, repo=f"file://{tmp_dir.as_posix()}") == {
"path": join(dirname, "myart.pkl"),
"rev": scm.get_rev(),
}

assert artifacts_show(name, repo=subdir.fs_path) == {
"path": str((subdir / "myart.pkl").relative_to(subrepo)),
"rev": scm.get_rev(),
}
with subdir.chdir():
assert artifacts_show(name) == {
"path": str((subdir / "myart.pkl").relative_to(subrepo)),
"rev": scm.get_rev(),
}
79 changes: 77 additions & 2 deletions tests/func/artifacts/test_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@

from dvc.annotations import Artifact
from dvc.exceptions import ArtifactNotFoundError, InvalidArgumentError
from dvc.repo.artifacts import check_name_format
from dvc.repo.artifacts import Artifacts, check_name_format
from dvc.testing.tmp_dir import make_subrepo
from dvc.utils import as_posix
from dvc.utils.strictyaml import YAMLSyntaxError, YAMLValidationError

dvcyaml = {
Expand Down Expand Up @@ -185,7 +187,7 @@ def test_get_rev(tmp_dir, dvc, scm):
dvc.artifacts.get_rev("myart", stage="prod")


def test_get_path(tmp_dir, dvc):
def test_get_path(tmp_dir, dvc, scm):
(tmp_dir / "dvc.yaml").dump(dvcyaml)
subdir = tmp_dir / "subdir"
subdir.mkdir()
Expand All @@ -206,3 +208,76 @@ def test_parametrized(tmp_dir, dvc):
assert tmp_dir.dvc.artifacts.read() == {
"dvc.yaml": {"myart": Artifact(path="myart.pkl", type="model")}
}


def test_get_path_subrepo(tmp_dir, scm, dvc):
subrepo = tmp_dir / "subrepo"
make_subrepo(subrepo, scm)
(subrepo / "dvc.yaml").dump(dvcyaml)

assert dvc.artifacts.get_path("subrepo:myart") == os.path.join(
"subrepo", "myart.pkl"
)
assert dvc.artifacts.get_path("subrepo/dvc.yaml:myart") == os.path.join(
"subrepo", "myart.pkl"
)

assert subrepo.dvc.artifacts.get_path("subrepo:myart") == os.path.join(
"subrepo", "myart.pkl"
)
assert subrepo.dvc.artifacts.get_path("subrepo/dvc.yaml:myart") == os.path.join(
"subrepo", "myart.pkl"
)


def get_tag_and_name(dirname, name, version):
tagname = f"{name}@{version}"
if dirname in (os.curdir, ""):
return tagname, name
return f"{dirname}={tagname}", f"{dirname}:{name}"


def make_artifact(tmp_dir, name, tag, path) -> Artifact:
artifact = Artifact(path=path.name, type="model")
dvcfile = path.with_name("dvc.yaml")

tmp_dir.scm_gen(path, "hello_world", commit="add myart.pkl")
tmp_dir.dvc.artifacts.add(name, artifact, dvcfile=os.fspath(dvcfile))
tmp_dir.scm.add_commit([dvcfile], message="add dvc.yaml")
tmp_dir.scm.tag(tag, annotated=True, message="foo")
return artifact


@pytest.mark.parametrize("sub", ["sub", ""])
def test_artifacts_download(tmp_dir, dvc, scm, sub):
subdir = tmp_dir / sub
dirname = str(subdir.relative_to(tmp_dir))
tag, name = get_tag_and_name(as_posix(dirname), "myart", "v2.0.0")
make_artifact(tmp_dir, "myart", tag, subdir / "myart.pkl")

result = (1, "myart.pkl")
assert Artifacts.get(".", name, force=True) == result
assert Artifacts.get(tmp_dir.fs_path, name, force=True) == result
assert Artifacts.get(f"file://{tmp_dir.as_posix()}", name, force=True) == result
assert Artifacts.get(subdir.fs_path, name, force=True) == result
with subdir.chdir():
assert Artifacts.get(".", name, force=True) == result


@pytest.mark.parametrize("sub", ["sub", ""])
def test_artifacts_download_subrepo(tmp_dir, scm, sub):
subrepo = tmp_dir / "subrepo"
make_subrepo(subrepo, scm)
subdir = subrepo / sub

dirname = str(subdir.relative_to(tmp_dir))
tag, name = get_tag_and_name(as_posix(dirname), "myart", "v2.0.0")
make_artifact(subrepo, "myart", tag, subdir / "myart.pkl")

result = (1, "myart.pkl")
assert Artifacts.get(".", name) == result
assert Artifacts.get(tmp_dir.fs_path, name, force=True) == result
assert Artifacts.get(f"file://{tmp_dir.as_posix()}", name, force=True) == result
assert Artifacts.get(subdir.fs_path, name, force=True) == result
with subdir.chdir():
assert Artifacts.get(".", name, force=True) == result

0 comments on commit 8517ccc

Please sign in to comment.