diff --git a/dvc/api/__init__.py b/dvc/api/__init__.py
index b3ba6bebe08..64959a02543 100644
--- a/dvc/api/__init__.py
+++ b/dvc/api/__init__.py
@@ -1,5 +1,6 @@
 from dvc.fs.dvc import _DVCFileSystem as DVCFileSystem
 
+from . import artifacts
 from .data import open  # pylint: disable=redefined-builtin
 from .data import get_url, read
 from .experiments import exp_save, exp_show
@@ -10,6 +11,7 @@
     "all_branches",
     "all_commits",
     "all_tags",
+    "artifacts",
     "exp_save",
     "exp_show",
     "get_url",
diff --git a/dvc/cli/parser.py b/dvc/cli/parser.py
index 10e1cce5402..a71b9fb9467 100644
--- a/dvc/cli/parser.py
+++ b/dvc/cli/parser.py
@@ -7,6 +7,7 @@
 from dvc import __version__
 from dvc.commands import (
     add,
+    artifacts,
     cache,
     check_ignore,
     checkout,
@@ -89,6 +90,7 @@
     experiments,
     check_ignore,
     data,
+    artifacts,
 ]
 
 
diff --git a/dvc/commands/artifacts.py b/dvc/commands/artifacts.py
new file mode 100644
index 00000000000..d3cdb1e33ec
--- /dev/null
+++ b/dvc/commands/artifacts.py
@@ -0,0 +1,116 @@
+import argparse
+import logging
+
+from dvc.cli import completion
+from dvc.cli.command import CmdBaseNoRepo
+from dvc.cli.utils import append_doc_link, fix_subparsers
+from dvc.exceptions import DvcException
+
+logger = logging.getLogger(__name__)
+
+
+class CmdArtifactsGet(CmdBaseNoRepo):
+    """Implement `dvc artifacts get`: download an artifact from a DVC project."""
+
+    def run(self):
+        """Download the requested artifact; return 0 on success, 1 on failure."""
+        from dvc.repo.artifacts import Artifacts
+        from dvc.scm import CloneError
+
+        try:
+            Artifacts.get(
+                self.args.url,
+                name=self.args.name,
+                version=self.args.rev,
+                stage=self.args.stage,
+                force=self.args.force,
+                config=self.args.config,
+                out=self.args.out,
+                # Forward -j/--jobs so the requested parallelism is honored.
+                jobs=self.args.jobs,
+            )
+            return 0
+        except CloneError:
+            logger.exception("failed to get '%s'", self.args.name)
+            return 1
+        except DvcException:
+            logger.exception(
+                "failed to get '%s' from '%s'", self.args.name, self.args.url
+            )
+            return 1
+
+
+def add_parser(subparsers, parent_parser):
+    """Register the `artifacts` command group and its `get` subcommand."""
+    ARTIFACTS_HELP = "DVC model registry artifact commands."
+
+    artifacts_parser = subparsers.add_parser(
+        "artifacts",
+        parents=[parent_parser],
+        description=append_doc_link(ARTIFACTS_HELP, "artifacts"),
+        help=ARTIFACTS_HELP,
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+    )
+    artifacts_subparsers = artifacts_parser.add_subparsers(
+        dest="cmd",
+        help="Use `dvc artifacts CMD --help` to display command-specific help.",
+    )
+    fix_subparsers(artifacts_subparsers)
+
+    ARTIFACTS_GET_HELP = "Download an artifact from a DVC project."
+    get_parser = artifacts_subparsers.add_parser(
+        "get",
+        parents=[parent_parser],
+        description=append_doc_link(ARTIFACTS_GET_HELP, "artifacts/get"),
+        help=ARTIFACTS_GET_HELP,
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+    )
+    get_parser.add_argument("url", help="Location of DVC repository to download from")
+    get_parser.add_argument(
+        "name", help="Name of artifact in the repository"
+    ).complete = completion.FILE
+    get_parser.add_argument(
+        "--rev",
+        nargs="?",
+        help="Artifact version",
+        metavar="<version>",
+    )
+    get_parser.add_argument(
+        "--stage",
+        nargs="?",
+        help="Artifact stage",
+        metavar="<stage>",
+    )
+    get_parser.add_argument(
+        "-o",
+        "--out",
+        nargs="?",
+        help="Destination path to download artifact to",
+        metavar="<path>",
+    ).complete = completion.DIR
+    get_parser.add_argument(
+        "-j",
+        "--jobs",
+        type=int,
+        help=(
+            "Number of jobs to run simultaneously. "
+            "The default value is 4 * cpu_count(). "
+        ),
+        metavar="<number>",
+    )
+    get_parser.add_argument(
+        "-f",
+        "--force",
+        action="store_true",
+        default=False,
+        help="Override local file or folder if exists.",
+    )
+    get_parser.add_argument(
+        "--config",
+        type=str,
+        help=(
+            "Path to a config file that will be merged with the config "
+            "in the target repository."
+        ),
+    )
+    get_parser.set_defaults(func=CmdArtifactsGet)
diff --git a/dvc/exceptions.py b/dvc/exceptions.py
index fd0a0b26e13..6ddae112450 100644
--- a/dvc/exceptions.py
+++ b/dvc/exceptions.py
@@ -1,6 +1,6 @@
 """Exceptions raised by the dvc."""
 import errno
-from typing import Dict, List
+from typing import Dict, List, Optional
 
 from dvc.utils import format_link
 
@@ -332,3 +332,26 @@ def __init__(self, fs_paths):
 class PrettyDvcException(DvcException):
     def __pretty_exc__(self, **kwargs):
         """Print prettier exception message."""
+
+
+class ArtifactNotFoundError(DvcException):
+    """Thrown if an artifact is not found in the DVC repo.
+
+    Args:
+        name (str): artifact name.
+        version (str, optional): artifact version that was requested.
+        stage (str, optional): artifact stage that was requested.
+    """
+
+    def __init__(
+        self,
+        name: str,
+        version: Optional[str] = None,
+        stage: Optional[str] = None,
+    ):
+        self.name = name
+        self.version = version
+        self.stage = stage
+
+        desc = f" @ {stage or version}" if (stage or version) else ""
+        super().__init__(f"Unable to find artifact '{name}{desc}'")
diff --git a/dvc/repo/artifacts.py b/dvc/repo/artifacts.py
index 0615486c293..b684b291183 100644
--- a/dvc/repo/artifacts.py
+++ b/dvc/repo/artifacts.py
@@ -1,15 +1,24 @@
 import logging
+import os
 import re
 from pathlib import Path
-from typing import Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 from dvc.annotations import Artifact
 from dvc.dvcfile import PROJECT_FILE
-from dvc.exceptions import InvalidArgumentError
-from dvc.repo import Repo
-from dvc.utils import relpath
+from dvc.exceptions import (
+    ArtifactNotFoundError,
+    FileExistsLocallyError,
+    InvalidArgumentError,
+)
+from dvc.utils import relpath, resolve_output
+from dvc.utils.objects import cached_property
 from dvc.utils.serialize import modify_yaml
 
+if TYPE_CHECKING:
+    from dvc.repo import Repo
+    from dvc.scm import Git
+
 logger = logging.getLogger(__name__)
 
 
@@ -37,6 +46,8 @@ def check_name_format(name: str) -> None:
 
 
 def check_for_nested_dvc_repo(dvcfile: Path):
+    from dvc.repo import Repo
+
     if dvcfile.is_absolute():
         raise InvalidArgumentError("Use relative path to dvc.yaml.")
     path = dvcfile.parent
@@ -52,7 +63,17 @@ class Artifacts:
     def __init__(self, repo: "Repo") -> None:
         self.repo = repo
 
+    @cached_property
+    def scm(self) -> Optional["Git"]:
+        """Return the repo's SCM when it is a Git instance, otherwise None."""
+        from dvc.scm import Git
+
+        if isinstance(self.repo.scm, Git):
+            return self.repo.scm
+        return None
+
     def read(self) -> Dict[str, Dict[str, Artifact]]:
+        """Read artifacts from dvc.yaml."""
         artifacts: Dict[str, Dict[str, Artifact]] = {}
         for (
             dvcfile,
@@ -69,6 +90,7 @@ def read(self) -> Dict[str, Dict[str, Artifact]]:
         return artifacts
 
     def add(self, name: str, artifact: Artifact, dvcfile: Optional[str] = None):
+        """Add artifact to dvc.yaml."""
         with self.repo.scm_context(quiet=True):
             check_name_format(name)
             dvcyaml = Path(dvcfile or PROJECT_FILE)
@@ -85,3 +107,200 @@ def add(self, name: str, artifact: Artifact, dvcfile: Optional[str] = None):
             self.repo.scm_context.track_file(dvcfile)
 
         return artifacts.get(name)
+
+    def get_rev(
+        self, name: str, version: Optional[str] = None, stage: Optional[str] = None
+    ):
+        """Return revision containing the given artifact.
+
+        Raises InvalidArgumentError if both version and stage are given, and
+        ArtifactNotFoundError if no matching tag is found.
+        """
+        from gto.tag import find as find_tags
+
+        # Validate explicitly: an assert would be stripped under `python -O`.
+        if version and stage:
+            raise InvalidArgumentError(
+                "Artifact version and stage are mutually exclusive."
+            )
+        tags = find_tags(name=name, version=version, stage=stage, scm=self.scm)
+        if not tags:
+            raise ArtifactNotFoundError(name, version=version, stage=stage)
+        return tags[-1].target
+
+    def get_path(self, name: str):
+        """Return repo fspath for the given artifact.
+
+        Raises ArtifactNotFoundError if the name does not parse or the
+        artifact is not declared in the corresponding dvc.yaml.
+        """
+        from gto.constants import fullname_re
+
+        m = fullname_re.match(name)
+        if not m:
+            raise ArtifactNotFoundError(name)
+        dirname = m.group("dirname")
+        if dirname:
+            dirname = dirname.rstrip(SEPARATOR_IN_NAME)
+        dvcyaml = os.path.join(dirname, PROJECT_FILE) if dirname else PROJECT_FILE
+        artifact_name = m.group("name")
+        try:
+            artifact = self.read()[dvcyaml][artifact_name]
+        except KeyError as exc:
+            raise ArtifactNotFoundError(name) from exc
+        return os.path.join(dirname, artifact.path) if dirname else artifact.path
+
+    def download(
+        self,
+        name: str,
+        version: Optional[str] = None,
+        stage: Optional[str] = None,
+        out: Optional[str] = None,
+        force: bool = False,
+        jobs: Optional[int] = None,
+    ):
+        """Download the specified artifact."""
+        from dvc.fs import download as fs_download
+
+        logger.debug("Trying to download artifact '%s' via DVC", name)
+        rev = self.get_rev(name, version=version, stage=stage)
+        with self.repo.switch(rev):
+            path = self.get_path(name)
+            out = resolve_output(path, out, force=force)
+            fs = self.repo.dvcfs
+            fs_path = fs.from_os_path(path)
+            fs_download(
+                fs,
+                fs_path,
+                os.path.abspath(out),
+                jobs=jobs,
+            )
+
+    @staticmethod
+    def _download_studio(
+        repo_url: str,
+        name: str,
+        version: Optional[str] = None,
+        stage: Optional[str] = None,
+        out: Optional[str] = None,
+        force: bool = False,
+        jobs: Optional[int] = None,
+        dvc_studio_config: Optional[Dict[str, Any]] = None,
+        **kwargs,
+    ):
+        """Download artifact files via Studio; the passed config is not mutated."""
+        from dvc_studio_client.model_registry import get_download_uris
+
+        from dvc.fs import Callback, HTTPFileSystem, generic, localfs
+        from dvc.utils.studio import env_to_config
+
+        logger.debug("Trying to download artifact '%s' via studio", name)
+        out = out or os.getcwd()
+        to_infos: List[str] = []
+        from_infos: List[str] = []
+        dvc_studio_config = dict(dvc_studio_config or {})
+        dvc_studio_config.update(env_to_config(dict(os.environ)))
+        dvc_studio_config["repo_url"] = repo_url
+        for path, url in get_download_uris(
+            repo_url,
+            name,
+            version=version,
+            stage=stage,
+            dvc_studio_config=dvc_studio_config,
+            **kwargs,
+        ).items():
+            to_info = localfs.path.join(out, path)
+            if localfs.exists(to_info) and not force:
+                hint = "\nTo override it, re-run with '--force'."
+                raise FileExistsLocallyError(to_info, hint=hint)
+            to_infos.append(to_info)
+            from_infos.append(url)
+        fs = HTTPFileSystem()
+        jobs = jobs or fs.jobs
+        with Callback.as_tqdm_callback(
+            desc=f"Downloading '{name}' from '{repo_url}'",
+            unit="files",
+        ) as cb:
+            cb.set_size(len(from_infos))
+            generic.copy(
+                fs, from_infos, localfs, to_infos, callback=cb, batch_size=jobs
+            )
+
+    @classmethod
+    def get(  # noqa: C901
+        cls,
+        url: str,
+        name: str,
+        version: Optional[str] = None,
+        stage: Optional[str] = None,
+        config: Optional[Union[str, Dict[str, Any]]] = None,
+        out: Optional[str] = None,
+        force: bool = False,
+        jobs: Optional[int] = None,
+    ):
+        """Download the named artifact from the repo at `url` into `out`."""
+        from dvc.config import Config
+        from dvc.env import DVC_STUDIO_TOKEN
+        from dvc.repo import Repo
+
+        if version and stage:
+            raise InvalidArgumentError(
+                "Artifact version and stage are mutually exclusive."
+            )
+
+        # NOTE: We try to download the artifact up to three times
+        # 1. via studio with studio config loaded from environment
+        # 2. via studio with the repo's 'studio' config section + environment
+        # 3. via DVC remote
+
+        saved_exc: Optional[Exception] = None
+        if DVC_STUDIO_TOKEN in os.environ:
+            try:
+                return cls._download_studio(
+                    url,
+                    name,
+                    version=version,
+                    stage=stage,
+                    out=out,
+                    force=force,
+                    jobs=jobs,
+                )
+            except Exception as exc:  # noqa: BLE001, pylint: disable=broad-except
+                saved_exc = exc
+
+        if config and not isinstance(config, dict):
+            config = Config.load_file(config)
+        with Repo.open(
+            url=url,
+            subrepos=True,
+            uninitialized=True,
+            config=config,
+        ) as repo:
+            dvc_studio_config = dict(repo.config.get("studio") or {})  # may be unset
+            try:
+                return cls._download_studio(
+                    url,
+                    name,
+                    version=version,
+                    stage=stage,
+                    out=out,
+                    force=force,
+                    jobs=jobs,
+                    dvc_studio_config=dvc_studio_config,
+                )
+            except Exception as exc:  # noqa: BLE001, pylint: disable=broad-except
+                saved_exc = exc
+
+            try:
+                return repo.artifacts.download(
+                    name,
+                    version=version,
+                    stage=stage,
+                    out=out,
+                    force=force,
+                    jobs=jobs,
+                )
+            except Exception as exc:
+                if saved_exc:
+                    raise exc from saved_exc
+                raise
diff --git a/dvc/utils/__init__.py b/dvc/utils/__init__.py
index 4022691d49e..a37325c37f5 100644
--- a/dvc/utils/__init__.py
+++ b/dvc/utils/__init__.py
@@ -245,7 +245,7 @@ def env2bool(var, undefined=False):
     return bool(re.search("1|y|yes|true", var, flags=re.I))
 
 
-def resolve_output(inp, out, force=False):
+def resolve_output(inp: str, out: Optional[str], force=False) -> str:
     from urllib.parse import urlparse
 
     from dvc.exceptions import FileExistsLocallyError
diff --git a/tests/func/artifacts/test_artifacts.py b/tests/func/artifacts/test_artifacts.py
index cf6c4220367..ca31215c2ce 100644
--- a/tests/func/artifacts/test_artifacts.py
+++ b/tests/func/artifacts/test_artifacts.py
@@ -5,7 +5,7 @@
 import pytest
 
 from dvc.annotations import Artifact
-from dvc.exceptions import InvalidArgumentError
+from dvc.exceptions import ArtifactNotFoundError, InvalidArgumentError
 from dvc.repo.artifacts import name_is_compatible
 from dvc.utils.strictyaml import YAMLSyntaxError, YAMLValidationError
 
@@ -182,3 +182,29 @@ def test_name_is_compatible(name):
     )
 def test_name_is_compatible_fails(name):
     assert not name_is_compatible(name)
+
+
+def test_get_rev(tmp_dir, dvc, scm):
+    """Check that get_rev resolves tags to the tagged revision or raises."""
+    scm.tag("myart@v1.0.0#1", annotated=True, message="foo")
+    scm.tag("subdir=myart@v2.0.0#1", annotated=True, message="foo")
+    scm.tag("myart#dev#1", annotated=True, message="foo")
+    rev = scm.get_rev()
+
+    assert dvc.artifacts.get_rev("myart") == rev
+    assert dvc.artifacts.get_rev("myart", version="v1.0.0") == rev
+    with pytest.raises(ArtifactNotFoundError):
+        dvc.artifacts.get_rev("myart", version="v3.0.0")
+    with pytest.raises(ArtifactNotFoundError):
+        dvc.artifacts.get_rev("myart", stage="prod")
+
+
+def test_get_path(tmp_dir, dvc):
+    """Check get_path for artifacts in the root and in a subdir dvc.yaml."""
+    (tmp_dir / "dvc.yaml").dump(dvcyaml)
+    subdir = tmp_dir / "subdir"
+    subdir.mkdir()
+    (subdir / "dvc.yaml").dump(dvcyaml)
+
+    assert dvc.artifacts.get_path("myart") == "myart.pkl"
+    assert dvc.artifacts.get_path("subdir:myart") == os.path.join("subdir", "myart.pkl")