fetch: use index fetch
efiop committed Jun 10, 2023
1 parent 1cb371d commit 01095e1
Showing 10 changed files with 62 additions and 185 deletions.
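In short, fetch no longer collects used objects per-ODB and pulls them via repo.cloud.pull(); it builds one data-index view per revision and hands the views to dvc_data's index-based fetch. Below is a minimal sketch of the new control flow, distilled from the dvc/repo/fetch.py diff further down (fetch_sketch is a hypothetical wrapper, not DVC's actual signature; single pass, no remote-config override, no error handling):

    from dvc_data.index.fetch import fetch as ifetch

    def fetch_sketch(repo, targets=None, jobs=None, with_deps=False, recursive=False):
        # One storage-aware index view per revision visited by the brancher.
        def indexes():
            for _ in repo.brancher(revs=None):
                yield repo.index.targets_view(
                    targets, with_deps=with_deps, recursive=recursive
                ).data["repo"]

        transferred, failed = ifetch(indexes(), jobs=jobs)
        if transferred:
            # Drop the cached data index so the next access rebuilds it
            # from the newly saved cache entries (see drop_data_index below).
            repo.drop_data_index()
        return transferred, failed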
8 changes: 8 additions & 0 deletions dvc/repo/__init__.py
@@ -367,6 +367,14 @@ def data_index(self) -> "DataIndex":
 
         return self._data_index
 
+    def drop_data_index(self) -> None:
+        try:
+            self.data_index.delete_node(("tree",))
+        except KeyError:
+            pass
+        self.data_index.commit()
+        self._reset()
+
     def __repr__(self):
         return f"{self.__class__.__name__}: '{self.root_dir}'"
 
170 changes: 28 additions & 142 deletions dvc/repo/fetch.py
@@ -1,20 +1,9 @@
 import logging
-from contextlib import suppress
-from typing import TYPE_CHECKING, Optional, Sequence
 
-from dvc.config import NoRemoteError
 from dvc.exceptions import DownloadError
-from dvc.fs import Schemes
 
 from . import locked
 
-if TYPE_CHECKING:
-    from dvc.data_cloud import Remote
-    from dvc.repo import Repo
-    from dvc.types import TargetType
-    from dvc_data.hashfile.db import HashFileDB
-    from dvc_data.hashfile.transfer import TransferResult
-
 logger = logging.getLogger(__name__)
 
 
@@ -31,7 +20,6 @@ def fetch( # noqa: C901, PLR0913
     all_commits=False,
     run_cache=False,
     revs=None,
-    odb: Optional["HashFileDB"] = None,
 ) -> int:
     """Download data items from a cloud and imported repositories
@@ -45,18 +33,11 @@ def fetch( # noqa: C901, PLR0913
         config.NoRemoteError: thrown when downloading only local files and no
             remote is configured
     """
-    from dvc.repo.imports import save_imports
-    from dvc_data.hashfile.transfer import TransferResult
+    from dvc_data.index.fetch import fetch as ifetch
 
     if isinstance(targets, str):
         targets = [targets]
 
-    worktree_remote: Optional["Remote"] = None
-    with suppress(NoRemoteError):
-        _remote = self.cloud.get_remote(name=remote)
-        if _remote.worktree or _remote.fs.version_aware:
-            worktree_remote = _remote
-
     failed_count = 0
     transferred_count = 0
 
@@ -66,133 +47,38 @@ def fetch( # noqa: C901, PLR0913
     except DownloadError as exc:
         failed_count += exc.amount
 
-    no_remote_msg: Optional[str] = None
-    result = TransferResult(set(), set())
-    try:
-        if worktree_remote is not None:
-            transferred_count += _fetch_worktree(
-                self,
-                worktree_remote,
-                revs=revs,
-                all_branches=all_branches,
-                all_tags=all_tags,
-                all_commits=all_commits,
-                targets=targets,
-                jobs=jobs,
-                with_deps=with_deps,
-                recursive=recursive,
-            )
-        else:
-            d, f = _fetch(
-                self,
-                targets,
-                all_branches=all_branches,
-                all_tags=all_tags,
-                all_commits=all_commits,
-                with_deps=with_deps,
-                force=True,
-                remote=remote,
-                jobs=jobs,
-                recursive=recursive,
-                revs=revs,
-                odb=odb,
-            )
-            result.transferred.update(d)
-            result.failed.update(f)
-    except NoRemoteError as exc:
-        no_remote_msg = str(exc)
-
-    for rev in self.brancher(
-        revs=revs,
-        all_branches=all_branches,
-        all_tags=all_tags,
-        all_commits=all_commits,
-    ):
-        imported = save_imports(
-            self,
-            targets,
-            unpartial=not rev or rev == "workspace",
-            recursive=recursive,
-        )
-        result.transferred.update(imported)
-        result.failed.difference_update(imported)
-
-    failed_count += len(result.failed)
-
+    def _indexes():
+        for _ in self.brancher(
+            revs=revs,
+            all_branches=all_branches,
+            all_tags=all_tags,
+            all_commits=all_commits,
+        ):
+            yield self.index.targets_view(
+                targets,
+                with_deps=with_deps,
+                recursive=recursive,
+            ).data["repo"]
+
+    try:
+        saved_remote = self.config["core"].get("remote")
+        if remote:
+            self.config["core"]["remote"] = remote
+
+        fetch_transferred, fetch_failed = ifetch(
+            _indexes(), jobs=jobs
+        )  # pylint: disable=assignment-from-no-return
+    finally:
+        if remote:
+            self.config["core"]["remote"] = saved_remote
+
+    if fetch_transferred:
+        # NOTE: dropping cached index to force reloading from newly saved cache
+        self.drop_data_index()
+
+    transferred_count += fetch_transferred
+    failed_count += fetch_failed
     if failed_count:
-        if no_remote_msg:
-            logger.error(no_remote_msg)
         raise DownloadError(failed_count)
 
-    transferred_count += len(result.transferred)
     return transferred_count
-
-
-def _fetch(
-    repo: "Repo",
-    targets: "TargetType",
-    remote: Optional[str] = None,
-    jobs: Optional[int] = None,
-    odb: Optional["HashFileDB"] = None,
-    **kwargs,
-) -> "TransferResult":
-    from dvc_data.hashfile.transfer import TransferResult
-
-    result = TransferResult(set(), set())
-    used = repo.used_objs(
-        targets,
-        remote=remote,
-        jobs=jobs,
-        **kwargs,
-    )
-    if odb:
-        all_ids = set()
-        for _odb, obj_ids in used.items():
-            all_ids.update(obj_ids)
-        d, f = repo.cloud.pull(
-            all_ids,
-            jobs=jobs,
-            remote=remote,
-            odb=odb,
-        )
-        result.transferred.update(d)
-        result.failed.update(f)
-    else:
-        for src_odb, obj_ids in sorted(
-            used.items(),
-            key=lambda item: item[0] is not None
-            and item[0].fs.protocol == Schemes.MEMORY,
-        ):
-            d, f = repo.cloud.pull(
-                obj_ids,
-                jobs=jobs,
-                remote=remote,
-                odb=src_odb,
-            )
-            result.transferred.update(d)
-            result.failed.update(f)
-    return result
-
-
-def _fetch_worktree(
-    repo: "Repo",
-    remote: "Remote",
-    revs: Optional[Sequence[str]] = None,
-    all_branches: bool = False,
-    all_tags: bool = False,
-    all_commits: bool = False,
-    targets: Optional["TargetType"] = None,
-    jobs: Optional[int] = None,
-    **kwargs,
-) -> int:
-    from dvc.repo.worktree import fetch_worktree
-
-    downloaded = 0
-    for _ in repo.brancher(
-        revs=revs,
-        all_branches=all_branches,
-        all_tags=all_tags,
-        all_commits=all_commits,
-    ):
-        downloaded += fetch_worktree(repo, remote, targets=targets, jobs=jobs, **kwargs)
-    return downloaded
13 changes: 13 additions & 0 deletions dvc/repo/index.py
@@ -178,6 +178,19 @@ def _load_storage_from_out(storage_map, key, out):
 
     if out.stage.is_import:
        dep = out.stage.deps[0]
+        if not out.hash_info:
+            from fsspec.utils import tokenize
+
+            # partial import
+            storage_map.add_cache(
+                FileStorage(
+                    key,
+                    out.cache.fs,
+                    out.cache.fs.path.join(
+                        out.cache.path, "fs", dep.fs.protocol, tokenize(dep.fs_path)
+                    ),
+                )
+            )
         storage_map.add_remote(FileStorage(key, dep.fs, dep.fs_path))
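The partial-import branch above caches downloads for outputs that have no hash yet under a per-source path: <cache>/fs/<protocol>/<tokenize(dep.fs_path)>. fsspec's tokenize() hashes its arguments into a stable hex digest, so each source location maps to a fixed cache subdirectory. A small illustration (the URL is hypothetical):

    from fsspec.utils import tokenize

    # Deterministic: the same source path always yields the same token,
    # so repeated partial-import fetches reuse one cache subdirectory.
    token = tokenize("s3://bucket/path/file.csv")
    assert token == tokenize("s3://bucket/path/file.csv")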


6 changes: 0 additions & 6 deletions dvc/repo/pull.py
@@ -1,12 +1,8 @@
 import logging
-from typing import TYPE_CHECKING, Optional
 
 from dvc.repo import locked
 from dvc.utils import glob_targets
 
-if TYPE_CHECKING:
-    from dvc_objects.db import ObjectDB
-
 logger = logging.getLogger(__name__)

@@ -24,7 +20,6 @@ def pull( # noqa: PLR0913
     all_commits=False,
     run_cache=False,
     glob=False,
-    odb: Optional["ObjectDB"] = None,
     allow_missing=False,
 ):
     if isinstance(targets, str):
@@ -42,7 +37,6 @@ def pull( # noqa: PLR0913
         with_deps=with_deps,
         recursive=recursive,
         run_cache=run_cache,
-        odb=odb,
     )
     stats = self.checkout(
         targets=expanded_targets,
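With the odb plumbing removed, pull reduces to fetch-then-checkout. A simplified sketch of what remains (pull_sketch is illustrative; glob expansion and the remaining keyword arguments are elided):

    def pull_sketch(repo, targets=None, **kwargs):
        fetched = repo.fetch(targets, **kwargs)   # download into the cache
        stats = repo.checkout(targets=targets)    # link/copy into the workspace
        stats["fetched"] = fetched
        return stats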
26 changes: 0 additions & 26 deletions dvc/repo/worktree.py
@@ -104,32 +104,6 @@ def _get_remote(
     return repo.cloud.get_remote(name, command)
 
 
-def fetch_worktree(
-    repo: "Repo",
-    remote: "Remote",
-    targets: Optional["TargetType"] = None,
-    jobs: Optional[int] = None,
-    **kwargs: Any,
-) -> int:
-    from dvc_data.index import save
-
-    transferred = 0
-    for remote_name, view in worktree_view_by_remotes(
-        repo.index, push=True, targets=targets, **kwargs
-    ):
-        remote_obj = _get_remote(repo, remote_name, remote, "fetch")
-        index = view.data["repo"]
-        total = len(index)
-        with Callback.as_tqdm_callback(
-            unit="file",
-            desc=f"Fetching from remote {remote_obj.name!r}",
-            disable=total == 0,
-        ) as cb:
-            cb.set_size(total)
-            transferred += save(index, callback=cb, jobs=jobs, storage="remote")
-    return transferred
-
-
 def push_worktree(
     repo: "Repo",
     remote: "Remote",
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -34,7 +34,7 @@ dependencies = [
     "configobj>=5.0.6",
     "distro>=1.3",
     "dpath<3,>=2.1.0",
-    "dvc-data>=1.0.3,<1.1.0",
+    "dvc-data>=1.1.0,<1.2.0",
     "dvc-http>=2.29.0",
     "dvc-render>=0.3.1,<1",
     "dvc-studio-client>=0.9.2,<1",
12 changes: 6 additions & 6 deletions tests/func/test_data_cloud.py
@@ -159,8 +159,8 @@ def test_missing_cache(tmp_dir, dvc, local_remote, caplog):
         "Some of the cache files do not exist "
         "neither locally nor on remote. Missing cache files:\n"
     )
-    foo = "name: bar, md5: 37b51d194a7513e45b56f6524f2d51f2\n"
-    bar = "name: foo, md5: acbd18db4cc2f85cedef654fccc4a4d8\n"
+    foo = "md5: 37b51d194a7513e45b56f6524f2d51f2\n"
+    bar = "md5: acbd18db4cc2f85cedef654fccc4a4d8\n"
 
     caplog.clear()
     dvc.push()
@@ -198,15 +198,15 @@ def test_verify_hashes(tmp_dir, scm, dvc, mocker, tmp_path_factory, local_remote
 
     dvc.pull()
     # NOTE: 1 is for index.data_tree building
-    assert hash_spy.call_count == 1
+    assert hash_spy.call_count == 2
 
     # Removing cache will invalidate existing state entries
     dvc.cache.local.clear()
 
     dvc.config["remote"]["upstream"]["verify"] = True
 
     dvc.pull()
-    assert hash_spy.call_count == 6
+    assert hash_spy.call_count == 8
 
 
 @flaky(max_runs=3, min_passes=1)
@@ -268,7 +268,7 @@ def test_pull_partial_import(tmp_dir, dvc, local_workspace):
     stage = dvc.imp_url("remote://workspace/file", os.fspath(dst), no_download=True)
 
     result = dvc.pull("file")
-    assert result["fetched"] == 1
+    assert result["fetched"] == 0
     assert dst.exists()
 
     assert stage.outs[0].get_hash().value == "d10b4c3ff123b26dc068d43a8bef2d23"
@@ -483,7 +483,7 @@ def test_pull_partial(tmp_dir, dvc, local_remote):
     clean(["foo"], dvc)
 
     stats = dvc.pull(os.path.join("foo", "bar"))
-    assert stats["fetched"] == 1
+    assert stats["fetched"] == 3
     assert (tmp_dir / "foo").read_text() == {"bar": {"baz": "baz"}}

5 changes: 3 additions & 2 deletions tests/func/test_import.py
@@ -241,12 +241,13 @@ def test_pull_import_no_download(tmp_dir, scm, dvc, erepo_dir):
     dvc.imp(os.fspath(erepo_dir), "foo", "foo_imported", no_download=True)
 
     dvc.pull(["foo_imported.dvc"])
-    assert (tmp_dir / "foo_imported").exists
+    assert (tmp_dir / "foo_imported").exists()
     assert (tmp_dir / "foo_imported" / "bar").read_bytes() == b"bar"
     assert (tmp_dir / "foo_imported" / "baz").read_bytes() == b"baz contents"
 
-    stage = load_file(dvc, "foo_imported.dvc").stage
     dvc.commit(force=True)
+
+    stage = load_file(dvc, "foo_imported.dvc").stage
     if os.name == "nt":
         expected_hash = "2e798234df5f782340ac3ce046f8dfae.dir"
     else:
3 changes: 2 additions & 1 deletion tests/func/test_import_url.py
@@ -232,8 +232,9 @@ def test_partial_import_pull(tmp_dir, scm, dvc, local_workspace):
 
     assert dst.exists()
 
-    stage = load_file(dvc, "file.dvc").stage
     dvc.commit(force=True)
+
+    stage = load_file(dvc, "file.dvc").stage
     assert stage.outs[0].hash_info.value == "d10b4c3ff123b26dc068d43a8bef2d23"
     assert stage.outs[0].meta.size == 12

2 changes: 1 addition & 1 deletion tests/func/test_virtual_directory.py
@@ -182,7 +182,7 @@ def test_partial_checkout_and_update(M, tmp_dir, dvc, remote):
 
     assert dvc.pull("dir/subdir") == M.dict(
         added=[join("dir", "")],
-        fetched=1,
+        fetched=3,
     )
     assert (tmp_dir / "dir").read_text() == {"subdir": {"lorem": "lorem"}}

