From 3fee81f1be245d464d4592cc8311bfdce707d4d5 Mon Sep 17 00:00:00 2001 From: Ruslan Kuprieiev Date: Sat, 5 Aug 2023 05:03:48 +0300 Subject: [PATCH] push: use index push --- dvc/repo/fetch.py | 12 +- dvc/repo/index.py | 3 +- dvc/repo/push.py | 190 +++++++++++++++------------ dvc/repo/worktree.py | 87 +----------- pyproject.toml | 2 +- tests/func/test_data_cloud.py | 4 +- tests/func/test_virtual_directory.py | 2 +- 7 files changed, 119 insertions(+), 181 deletions(-) diff --git a/dvc/repo/fetch.py b/dvc/repo/fetch.py index d63bec83aa..52d6f97444 100644 --- a/dvc/repo/fetch.py +++ b/dvc/repo/fetch.py @@ -59,10 +59,9 @@ def _collect_indexes( # noqa: PLR0913 types=types, ) - data = idx.data["repo"] - data.onerror = _make_index_onerror(onerror, rev) + idx.data["repo"].onerror = _make_index_onerror(onerror, rev) - indexes[idx.data_tree.hash_info.value] = data + indexes[rev or "workspace"] = idx except Exception as exc: # pylint: disable=broad-except if onerror: onerror(rev, None, exc) @@ -138,7 +137,10 @@ def fetch( # noqa: C901, PLR0913 onerror=onerror, ) - cache_key = ("fetch", tokenize(sorted(indexes.keys()))) + cache_key = ( + "fetch", + tokenize(sorted(idx.data_tree.hash_info.value for idx in indexes.values())), + ) with ui.progress( desc="Collecting", @@ -146,7 +148,7 @@ def fetch( # noqa: C901, PLR0913 leave=True, ) as pb: data = collect( - indexes.values(), + [idx.data["repo"] for idx in indexes.values()], "remote", cache_index=self.data_index, cache_key=cache_key, diff --git a/dvc/repo/index.py b/dvc/repo/index.py index a32a9c737a..011e751b9e 100644 --- a/dvc/repo/index.py +++ b/dvc/repo/index.py @@ -169,8 +169,9 @@ def _load_storage_from_out(storage_map, key, out): FileStorage( key=key, fs=remote.fs, - path=remote.fs.path.join(remote.path, *key), + path=remote.path, index=remote.index, + prefix=(), ) ) else: diff --git a/dvc/repo/push.py b/dvc/repo/push.py index e43b5c9770..b73babd969 100644 --- a/dvc/repo/push.py +++ b/dvc/repo/push.py @@ -1,17 +1,36 @@ from contextlib import suppress -from typing import TYPE_CHECKING, Optional, Sequence -from dvc.config import NoRemoteError from dvc.exceptions import InvalidArgumentError, UploadError -from dvc.utils import glob_targets +from dvc.ui import ui from . import locked -if TYPE_CHECKING: - from dvc.data_cloud import Remote - from dvc.repo import Repo - from dvc.types import TargetType - from dvc_objects.db import ObjectDB + +def _update_meta(index, **kwargs): + from dvc.repo.index import build_data_index + from dvc.repo.worktree import _merge_push_meta, worktree_view_by_remotes + + stages = set() + for remote_name, idx in worktree_view_by_remotes(index, push=True, **kwargs): + remote = index.repo.cloud.get_remote(remote_name) + + with ui.progress("Collecting", unit="entry") as pb: + new = build_data_index( + idx, + remote.path, + remote.fs, + callback=pb.as_callback(), + ) + + for out in idx.outs: + if not remote.fs.version_aware: + continue + + _merge_push_meta(out, new, remote.name) + stages.add(out.stage) + + for stage in stages: + stage.dump(with_files=True, update_pipeline=False) @locked @@ -28,92 +47,93 @@ def push( # noqa: C901, PLR0913 run_cache=False, revs=None, glob=False, - odb: Optional["ObjectDB"] = None, - include_imports=False, ): - worktree_remote: Optional["Remote"] = None + from fsspec.utils import tokenize + + from dvc.config import NoRemoteError + from dvc.utils import glob_targets + from dvc_data.index.fetch import collect + from dvc_data.index.push import push as ipush + + from .fetch import _collect_indexes + + failed_count = 0 + transferred_count = 0 + with suppress(NoRemoteError): _remote = self.cloud.get_remote(name=remote) - if _remote and (_remote.worktree or _remote.fs.version_aware): - worktree_remote = _remote + if ( + _remote + and (_remote.worktree or _remote.fs.version_aware) + and (revs or all_branches or all_tags or all_commits) + ): + raise InvalidArgumentError( + "Multiple rev push is unsupported for cloud versioned remotes" + ) - pushed = 0 - used_run_cache = self.stage_cache.push(remote, odb=odb) if run_cache else [] - pushed += len(used_run_cache) + used_run_cache = self.stage_cache.push(remote) if run_cache else [] + transferred_count += len(used_run_cache) if isinstance(targets, str): targets = [targets] - expanded_targets = glob_targets(targets, glob=glob) - - if worktree_remote is not None: - pushed += _push_worktree( - self, - worktree_remote, - revs=revs, - all_branches=all_branches, - all_tags=all_tags, - all_commits=all_commits, - targets=expanded_targets, - jobs=jobs, - with_deps=with_deps, - recursive=recursive, - ) - else: - used = self.used_objs( - expanded_targets, - all_branches=all_branches, - all_tags=all_tags, - all_commits=all_commits, - with_deps=with_deps, - force=True, - remote=remote, - jobs=jobs, - recursive=recursive, - used_run_cache=used_run_cache, - revs=revs, + indexes = _collect_indexes( + self, + targets=glob_targets(targets, glob=glob), + remote=remote, + all_branches=all_branches, + with_deps=with_deps, + all_tags=all_tags, + recursive=recursive, + all_commits=all_commits, + revs=revs, + ) + + cache_key = ( + "push", + tokenize(sorted(idx.data_tree.hash_info.value for idx in indexes.values())), + ) + + with ui.progress( + desc="Collecting", + unit="entry", + ) as pb: + data = collect( + [idx.data["repo"] for idx in indexes.values()], + "remote", + cache_index=self.data_index, + cache_key=cache_key, + callback=pb.as_callback(), push=True, ) - if odb: - all_ids = set() - for dest_odb, obj_ids in used.items(): - if not include_imports and dest_odb and dest_odb.read_only: - continue - all_ids.update(obj_ids) - result = self.cloud.push(all_ids, jobs, remote=remote, odb=odb) - if result.failed: - raise UploadError(len(result.failed)) - pushed += len(result.transferred) - else: - for dest_odb, obj_ids in used.items(): - if dest_odb and dest_odb.read_only: - continue - result = self.cloud.push( - obj_ids, jobs, remote=remote, odb=odb or dest_odb - ) - if result.failed: - raise UploadError(len(result.failed)) - pushed += len(result.transferred) - return pushed - - -def _push_worktree( - repo: "Repo", - remote: "Remote", - revs: Optional[Sequence[str]] = None, - all_branches: bool = False, - all_tags: bool = False, - all_commits: bool = False, - targets: Optional["TargetType"] = None, - jobs: Optional[int] = None, - **kwargs, -) -> int: - from dvc.repo.worktree import push_worktree - - if revs or all_branches or all_tags or all_commits: - raise InvalidArgumentError( - "Multiple rev push is unsupported for cloud versioned remotes" - ) + try: + with ui.progress( + desc="Pushing", + unit="file", + ) as pb: + push_transferred, push_failed = ipush( + data, + jobs=jobs, + callback=pb.as_callback(), + ) # pylint: disable=assignment-from-no-return + finally: + ws_idx = indexes.get("workspace") + if ws_idx is not None: + _update_meta( + self.index, + targets=glob_targets(targets, glob=glob), + remote=remote, + with_deps=with_deps, + recursive=recursive, + ) + + for fs_index in data: + fs_index.close() + + transferred_count += push_transferred + failed_count += push_failed + if failed_count: + raise UploadError(failed_count) - return push_worktree(repo, remote, targets=targets, jobs=jobs, **kwargs) + return transferred_count diff --git a/dvc/repo/worktree.py b/dvc/repo/worktree.py index f67327db05..b13ff3eaca 100644 --- a/dvc/repo/worktree.py +++ b/dvc/repo/worktree.py @@ -1,10 +1,9 @@ import logging from functools import partial -from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Set, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple, Union from funcy import first -from dvc.exceptions import DvcException from dvc.fs.callbacks import Callback from dvc.stage.exceptions import StageUpdateError @@ -104,90 +103,6 @@ def _get_remote( return repo.cloud.get_remote(name, command) -def push_worktree( - repo: "Repo", - remote: "Remote", - targets: Optional["TargetType"] = None, - jobs: Optional[int] = None, - **kwargs: Any, -) -> int: - from dvc.repo.index import build_data_index - from dvc_data.index.checkout import VersioningNotSupported, apply, compare - - pushed = 0 - stages: Set["Stage"] = set() - - for remote_name, view in worktree_view_by_remotes( - repo.index, push=True, targets=targets, **kwargs - ): - remote_obj = _get_remote(repo, remote_name, remote, "push") - new_index = view.data["repo"] - if remote_obj.worktree: - logger.debug("indexing latest worktree for '%s'", remote_obj.path) - old_index = build_data_index(view, remote_obj.path, remote_obj.fs) - logger.debug("Pushing worktree changes to '%s'", remote_obj.path) - else: - old_index = None - logger.debug("Pushing version-aware files to '%s'", remote_obj.path) - - if remote_obj.worktree: - diff_kwargs: Dict[str, Any] = { - "meta_only": True, - "meta_cmp_key": partial(_meta_checksum, remote_obj.fs), - } - else: - diff_kwargs = {} - - with Callback.as_tqdm_callback( - unit="entry", - desc=f"Comparing indexes for remote {remote_obj.name!r}", - ) as cb: - diff = compare( - old_index, - new_index, - callback=cb, - delete=remote_obj.worktree, - **diff_kwargs, - ) - - total = len(new_index) - with Callback.as_tqdm_callback( - unit="file", - desc=f"Pushing to remote {remote_obj.name!r}", - disable=total == 0, - ) as cb: - cb.set_size(total) - try: - apply( - diff, - remote_obj.path, - remote_obj.fs, - callback=cb, - latest_only=remote_obj.worktree, - jobs=jobs, - ) - pushed += len(diff.files_create) - except VersioningNotSupported: - logger.exception("") - raise DvcException( - f"remote {remote_obj.name!r} does not support versioning" - ) from None - - if remote_obj.index is not None: - for key, entry in new_index.iteritems(): - remote_obj.index[key] = entry - remote_obj.index.commit() - - for out in view.outs: - workspace, _key = out.index_key - _merge_push_meta(out, repo.index.data[workspace], remote_obj.name) - stages.add(out.stage) - - for stage in stages: - stage.dump(with_files=True, update_pipeline=False) - return pushed - - def _merge_push_meta( out: "Output", index: Union["DataIndex", "DataIndexView"], diff --git a/pyproject.toml b/pyproject.toml index 2915a03487..70a2f1fe4a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ dependencies = [ "configobj>=5.0.6", "distro>=1.3", "dpath<3,>=2.1.0", - "dvc-data>=2.18.1,<2.19.0", + "dvc-data>=2.19.0,<2.20.0", "dvc-http>=2.29.0", "dvc-render>=0.3.1,<1", "dvc-studio-client>=0.13.0,<1", diff --git a/tests/func/test_data_cloud.py b/tests/func/test_data_cloud.py index ebefe848ab..6c745d602b 100644 --- a/tests/func/test_data_cloud.py +++ b/tests/func/test_data_cloud.py @@ -148,7 +148,7 @@ def test_hash_recalculation(mocker, dvc, tmp_dir, local_remote): assert ret == 0 ret = main(["push"]) assert ret == 0 - assert test_file_md5.mock.call_count == 1 + assert test_file_md5.mock.call_count == 3 def test_missing_cache(tmp_dir, dvc, local_remote, caplog): @@ -486,7 +486,7 @@ def test_pull_partial(tmp_dir, dvc, local_remote): clean(["foo"], dvc) stats = dvc.pull(os.path.join("foo", "bar")) - assert stats["fetched"] == 3 + assert stats["fetched"] == 2 assert (tmp_dir / "foo").read_text() == {"bar": {"baz": "baz"}} diff --git a/tests/func/test_virtual_directory.py b/tests/func/test_virtual_directory.py index b609e5718b..a7f8aef53c 100644 --- a/tests/func/test_virtual_directory.py +++ b/tests/func/test_virtual_directory.py @@ -182,7 +182,7 @@ def test_partial_checkout_and_update(M, tmp_dir, dvc, remote): assert dvc.pull("dir/subdir") == M.dict( added=[join("dir", "")], - fetched=3, + fetched=2, ) assert (tmp_dir / "dir").read_text() == {"subdir": {"lorem": "lorem"}}