Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for tracking remote dataset #10287

Merged
merged 4 commits into the base branch from the pull-request branch
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dvc/cli/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
dag,
data,
data_sync,
dataset,
destroy,
diff,
du,
Expand Down Expand Up @@ -67,6 +68,7 @@
dag,
data,
data_sync,
dataset,
destroy,
diff,
du,
Expand Down
206 changes: 206 additions & 0 deletions dvc/commands/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
from typing import TYPE_CHECKING, Optional

from dvc.cli import formatter
from dvc.cli.command import CmdBase
from dvc.cli.utils import append_doc_link
from dvc.exceptions import DvcException
from dvc.log import logger

if TYPE_CHECKING:
from rich.text import Text

from dvc.repo.datasets import Dataset, FileInfo

logger = logger.getChild(__name__)


def diff_files(old: list["FileInfo"], new: list["FileInfo"]) -> dict[str, list[str]]:
    """Classify file changes between two dataset snapshots.

    Files are keyed by their relative path. A path present only in *new* is
    "added", only in *old* is "deleted", and present in both but unequal is
    "modified". Returns a dict with those three keys mapping to path lists.
    """
    before = {info.relpath: info for info in old}
    after = {info.relpath: info for info in new}
    common = before.keys() & after.keys()
    added = list(after.keys() - before.keys())
    deleted = list(before.keys() - after.keys())
    # Equality of FileInfo objects decides whether a shared path changed.
    modified = [path for path in common if before[path] != after[path]]
    return {"added": added, "deleted": deleted, "modified": modified}


class CmdDatasetAdd(CmdBase):
    @classmethod
    def display(cls, dataset: "Dataset", action: str = "Adding"):
        """Print a one-line styled summary: ``<action> <name> (<url>[ @ <ver>])``."""
        from dvc.repo.datasets import DVCDatasetLock, DVCXDatasetLock
        from dvc.ui import ui

        assert dataset.lock

        lock = dataset.lock
        shown_url = dataset.url
        version_label: str = ""
        if isinstance(lock, DVCXDatasetLock):
            version_label = f"v{lock.version}"
        if isinstance(lock, DVCDatasetLock):
            if lock.path:
                shown_url = f"{dataset.url}:/{lock.path.lstrip('/')}"
            rev = lock.rev
            if rev:
                version_label = rev

        suffix: Optional["Text"] = None
        if version_label:
            suffix = ui.rich_text.assemble(" @ ", (version_label, "repr.number"))
        rendered = ui.rich_text.assemble("(", (shown_url, "repr.url"), suffix or "", ")")
        ui.write(action, ui.rich_text(dataset.name, "cyan"), rendered, styled=True)

    def run(self):
        """Parse the URL scheme into a dataset ``type`` and register the dataset."""
        from urllib.parse import urlsplit

        args = vars(self.args)
        parts = urlsplit(self.args.url)
        scheme = parts.scheme
        if scheme == "dvcx":
            args["type"] = "dvcx"
        elif scheme.startswith("dvc"):
            args["type"] = "dvc"
            protos = tuple(scheme.split("+"))
            # Plain dvc:// (or dvc+ssh://) keeps host+path; other dvc+<proto>
            # URLs are rewritten to use the secondary protocol as the scheme.
            if protos in ((), ("dvc",), ("dvc", "ssh")):
                args["url"] = parts.netloc + parts.path
            else:
                args["url"] = parts._replace(scheme=protos[1]).geturl()
        else:
            args["type"] = "url"

        existing = self.repo.datasets.get(self.args.name)
        with self.repo.scm_context:
            if existing and not self.args.force:
                path = self.repo.fs.relpath(existing.manifest_path)
                raise DvcException(
                    f"{self.args.name} already exists in {path}, "
                    "use the --force to overwrite"
                )
            self.display(self.repo.datasets.add(**args))


class CmdDatasetUpdate(CmdBase):
    def display(self, dataset: "Dataset", new: "Dataset"):
        """Print what changed between the old and the freshly updated dataset."""
        from dvc.commands.checkout import log_changes
        from dvc.repo.datasets import DVCDatasetLock, DVCXDatasetLock, URLDatasetLock
        from dvc.ui import ui

        # A dataset that was never locked is effectively being added now.
        if not dataset.lock:
            return CmdDatasetAdd.display(new, "Updating")
        if dataset == new:
            ui.write("[yellow]Nothing to update[/]", styled=True)
            return

        versions: Optional[tuple[str, str]] = None
        if isinstance(dataset.lock, DVCXDatasetLock):
            assert isinstance(new.lock, DVCXDatasetLock)
            versions = (f"v{dataset.lock.version}", f"v{new.lock.version}")
        if isinstance(dataset.lock, DVCDatasetLock):
            assert isinstance(new.lock, DVCDatasetLock)
            versions = (f"{dataset.lock.rev_lock[:9]}", f"{new.lock.rev_lock[:9]}")

        if versions is None:
            summary = ui.rich_text(dataset.url, "repr.url")
        else:
            before, after = versions
            summary = ui.rich_text.assemble(
                (before, "repr.number"), " -> ", (after, "repr.number")
            )
        changes = ui.rich_text.assemble("(", summary, ")")
        ui.write("Updating", ui.rich_text(dataset.name, "cyan"), changes, styled=True)
        # URL datasets track individual files, so show a per-file diff too.
        if isinstance(dataset.lock, URLDatasetLock):
            assert isinstance(new.lock, URLDatasetLock)
            log_changes(diff_files(dataset.lock.files, new.lock.files))

    def run(self):
        """Update the named dataset; suggest a close match if it is unknown."""
        from difflib import get_close_matches

        from dvc.repo.datasets import DatasetNotFoundError
        from dvc.ui import ui

        with self.repo.scm_context:
            try:
                dataset, new = self.repo.datasets.update(**vars(self.args))
            except DatasetNotFoundError:
                logger.exception("")
                suggestions = get_close_matches(self.args.name, self.repo.datasets)
                if suggestions:
                    ui.write(
                        "did you mean?",
                        ui.rich_text(suggestions[0], "cyan"),
                        stderr=True,
                        styled=True,
                    )
                return 1
            self.display(dataset, new)


def add_parser(subparsers, parent_parser):
    """Register the `dvc dataset` (alias `ds`) command group and its
    `add`/`update` subcommands on *subparsers*."""
    ds_parser = subparsers.add_parser(
        "dataset",
        aliases=["ds"],
        parents=[parent_parser],
        formatter_class=formatter.RawDescriptionHelpFormatter,
    )
    ds_subparsers = ds_parser.add_subparsers(
        dest="cmd",
        help="Use `dvc dataset CMD --help` to display command-specific help.",
        required=True,
    )

    dataset_add_help = "Add a dataset."
    ds_add_parser = ds_subparsers.add_parser(
        "add",
        parents=[parent_parser],
        description=append_doc_link(dataset_add_help, "dataset/add"),
        # RawText keeps the multi-line URL examples below formatted verbatim.
        formatter_class=formatter.RawTextHelpFormatter,
        help=dataset_add_help,
    )
    ds_add_parser.add_argument(
        "--url",
        required=True,
        help="""\
Location of the data to download. Supported URLs:

https://example.com/path/to/file
s3://bucket/key/path
gs://bucket/path/to/file/or/dir
hdfs://example.com/path/to/file
ssh://example.com/absolute/path/to/file/or/dir
remote://remote_name/path/to/file/or/dir (see `dvc remote`)
dvcx://dataset_name

To import data from dvc/git repositories, \
add dvc:// schema to the repo url, e.g:
dvc://[email protected]/iterative/example-get-started.git
dvc+https://github.com/iterative/example-get-started.git""",
    )
    ds_add_parser.add_argument(
        "--name", help="Name of the dataset to add", required=True
    )
    ds_add_parser.add_argument(
        "--rev",
        help="Git revision, e.g. SHA, branch, tag "
        "(only applicable for dvc/git repository)",
    )
    ds_add_parser.add_argument(
        "--path", help="Path to a file or directory within the git repository"
    )
    ds_add_parser.add_argument(
        "-f",
        "--force",
        action="store_true",
        default=False,
        help="Overwrite existing dataset",
    )
    ds_add_parser.set_defaults(func=CmdDatasetAdd)

    dataset_update_help = "Update a dataset."
    ds_update_parser = ds_subparsers.add_parser(
        "update",
        parents=[parent_parser],
        # Fix: doc link previously pointed at "dataset/add" (copy-paste);
        # the update subcommand must link to its own documentation page.
        description=append_doc_link(dataset_update_help, "dataset/update"),
        formatter_class=formatter.RawDescriptionHelpFormatter,
        help=dataset_update_help,
    )
    ds_update_parser.add_argument("name", help="Name of the dataset to update")
    ds_update_parser.set_defaults(func=CmdDatasetUpdate)
8 changes: 4 additions & 4 deletions dvc/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

from sqlalchemy import create_engine
from sqlalchemy.engine import make_url as _make_url
from sqlalchemy.exc import NoSuchModuleError
from sqlalchemy import create_engine # type: ignore[import]
from sqlalchemy.engine import make_url as _make_url # type: ignore[import]
from sqlalchemy.exc import NoSuchModuleError # type: ignore[import]

from dvc import env
from dvc.exceptions import DvcException
Expand All @@ -17,7 +17,7 @@

if TYPE_CHECKING:
from sqlalchemy.engine import URL, Connectable, Engine
from sqlalchemy.sql.expression import Selectable
from sqlalchemy.sql.expression import Selectable # type: ignore[import]


logger = logger.getChild(__name__)
Expand Down
14 changes: 11 additions & 3 deletions dvc/dependency/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from funcy import compact, merge

from dvc.exceptions import DvcException
from dvc_data.hashfile.hash_info import HashInfo

from .db import AbstractDependency
Expand Down Expand Up @@ -47,17 +48,24 @@ def fill_values(self, values=None):
)

def workspace_status(self):
    """Return ``{str(self): "modified"}`` when the recorded hash diverges
    from the dataset's current lock (or the lock is missing); ``{}`` otherwise.
    """
    # Fix: removed unused local `registered` (stale leftover from the old
    # repo.index.datasets-based implementation).
    ds = self.repo.datasets[self.name]
    # State recorded in dvc.lock via this dependency's hash_info, if any.
    info: dict[str, Any] = self.hash_info.value if self.hash_info else {}  # type: ignore[assignment]

    # TODO: what to do if dvc.lock and dvc.yaml are different
    if not ds.lock or info != ds.lock.to_dict():
        return {str(self): "modified"}
    return {}

def status(self):
    # Dataset dependencies have no remote/cache state of their own to
    # compare, so status is entirely the workspace comparison.
    return self.workspace_status()

def get_hash(self):
    """Return a ``HashInfo`` carrying the dataset's resolved lock data.

    Raises:
        DvcException: if the dataset has no lock information recorded
            (i.e. nothing to hash in dvc.lock).
    """
    # Fix: removed the stale `return HashInfo(..., self.repo.index.datasets
    # .get(...))` line left over from the old implementation — it made the
    # lock-based logic below unreachable.
    ds = self.repo.datasets[self.name]
    if not ds.lock:
        raise DvcException(
            f"Information missing for {self.name!r} dataset in dvc.lock"
        )
    return HashInfo(self.PARAM_DATASET, ds.lock.to_dict())  # type: ignore[arg-type]

def save(self):
    # Snapshot the dataset's resolved lock data into hash_info so it can be
    # serialized into dvc.lock; raises if the dataset has no lock (see get_hash).
    self.hash_info = self.get_hash()
Expand Down
3 changes: 2 additions & 1 deletion dvc/dependency/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

if TYPE_CHECKING:
from dvc.output import Output
from dvc.repo import Repo
from dvc.stage import Stage

logger = logger.getChild(__name__)
Expand All @@ -33,7 +34,7 @@ class AbstractDependency(Dependency):
"""Dependency without workspace/fs/fs_path"""

def __init__(self, stage: "Stage", info: dict[str, Any], *args, **kwargs):
self.repo = stage.repo
self.repo: "Repo" = stage.repo
self.stage = stage
self.fs = None
self.fs_path = None
Expand Down
37 changes: 37 additions & 0 deletions dvc/dvcfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ class SingleStageFile(FileMixin):
from dvc.stage.loader import SingleStageLoader as LOADER # noqa: N814

datasets: ClassVar[list[dict[str, Any]]] = []
datasets_lock: ClassVar[list[dict[str, Any]]] = []
metrics: ClassVar[list[str]] = []
plots: ClassVar[Any] = {}
params: ClassVar[list[str]] = []
Expand Down Expand Up @@ -240,6 +241,20 @@ def dump(self, stage, update_pipeline=True, update_lock=True, **kwargs):
if update_lock:
self._dump_lockfile(stage, **kwargs)

def dump_dataset(self, dataset):
    """Insert or replace *dataset* in this file's ``datasets`` section
    (matched by ``name``) and track the file with SCM."""
    with modify_yaml(self.path, fs=self.repo.fs) as data:
        entries: list[dict] = data.setdefault("datasets", [])
        index = None
        for i, entry in enumerate(entries):
            if entry["name"] == dataset["name"]:
                index = i
                break
        if index is None:
            entries.append(dataset)
        else:
            # apply_diff preserves the existing entry's YAML formatting
            # before the new payload is swapped in.
            apply_diff(dataset, entries[index])
            entries[index] = dataset
    self.repo.scm_context.track_file(self.relpath)

def _dump_lockfile(self, stage, **kwargs):
self._lockfile.dump(stage, **kwargs)

Expand Down Expand Up @@ -308,6 +323,10 @@ def params(self) -> list[str]:
def datasets(self) -> list[dict[str, Any]]:
return self.contents.get("datasets", [])

@property
def datasets_lock(self) -> list[dict[str, Any]]:
    """Dataset entries recorded in the lockfile (empty list if none)."""
    return self.lockfile_contents.get("datasets", [])

@property
def artifacts(self) -> dict[str, Optional[dict[str, Any]]]:
return self.contents.get("artifacts", {})
Expand Down Expand Up @@ -357,6 +376,24 @@ def _load(self, **kwargs: Any):
self._check_gitignored()
return {}, ""

def dump_dataset(self, dataset: dict):
    """Insert or replace *dataset* (matched by ``name``) in the lockfile's
    ``datasets`` list, stamping the schema and tracking the file with SCM."""
    with modify_yaml(self.path, fs=self.repo.fs) as data:
        # Fix: test emptiness BEFORE stamping the schema — the original
        # called data.update(...) first, which made `not data` always
        # false and the "Generating lock file" message unreachable.
        if not data:
            logger.info("Generating lock file '%s'", self.relpath)
        data.update({"schema": "2.0"})

        datasets: list[dict] = data.setdefault("datasets", [])
        loc = next(
            (i for i, ds in enumerate(datasets) if ds["name"] == dataset["name"]),
            None,
        )
        if loc is not None:
            datasets[loc] = dataset
        else:
            datasets.append(dataset)
        data.setdefault("stages", {})
    self.repo.scm_context.track_file(self.relpath)

def dump(self, stage, **kwargs):
stage_data = serialize.to_lockfile(stage, **kwargs)

Expand Down
2 changes: 1 addition & 1 deletion dvc/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,7 +814,7 @@ def dumpd(self, **kwargs): # noqa: C901, PLR0912

ret: dict[str, Any] = {}
with_files = (
(not self.IS_DEPENDENCY or self.stage.is_import)
(not self.IS_DEPENDENCY or kwargs.get("datasets") or self.stage.is_import)
and self.hash_info.isdir
and (kwargs.get("with_files") or self.files is not None)
)
Expand Down
2 changes: 2 additions & 0 deletions dvc/repo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def __init__( # noqa: PLR0915, PLR0913
from dvc.fs import GitFileSystem, LocalFileSystem, localfs
from dvc.lock import LockNoop, make_lock
from dvc.repo.artifacts import Artifacts
from dvc.repo.datasets import Datasets
from dvc.repo.metrics import Metrics
from dvc.repo.params import Params
from dvc.repo.plots import Plots
Expand Down Expand Up @@ -220,6 +221,7 @@ def __init__( # noqa: PLR0915, PLR0913
self.plots: Plots = Plots(self)
self.params: Params = Params(self)
self.artifacts: Artifacts = Artifacts(self)
self.datasets: Datasets = Datasets(self)

self.stage_collection_error_handler: Optional[
Callable[[str, Exception], None]
Expand Down
Loading
Loading