Skip to content

Commit

Permalink
dvc.api for dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry committed Feb 23, 2024
1 parent 81c7c3b commit 3309e15
Showing 1 changed file with 61 additions and 0 deletions.
61 changes: 61 additions & 0 deletions dvc/api/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from typing import Literal, TypedDict, Union


class DVCXDataset(TypedDict):
    """Dictionary describing a dataset whose ``type`` is ``"dvcx"``.

    ``get`` returns this shape when the dataset's lock entry is a
    ``datasets.DVCXDataset``.
    """

    # Discriminator; always the literal string "dvcx".
    type: Literal["dvcx"]
    # Dataset name (``get`` fills it from ``lock.name_version[0]``).
    name: str
    # Dataset version (``get`` fills it from ``lock.version``).
    version: int


class DVCDataset(TypedDict):
    """Dictionary describing a dataset whose ``type`` is ``"dvc"``.

    ``get`` returns this shape when the dataset's lock entry is a
    ``datasets.DVCDataset``.
    """

    # Discriminator; always the literal string "dvc".
    type: Literal["dvc"]
    # URL of the dataset (``get`` fills it from ``dataset.url``).
    url: str
    # Path recorded in the lock entry (``lock.path``).
    path: str
    # Revision recorded in the lock entry (``lock.rev_lock``).
    sha: str


class URLDataset(TypedDict):
    """Dictionary describing a dataset whose ``type`` is ``"url"``.

    ``get`` returns this shape when the dataset's lock entry is a
    ``datasets.URLDataset``.
    """

    # Discriminator; always the literal string "url".
    type: Literal["url"]
    # One entry per file in the lock: the file's relative path joined onto
    # ``path``, with the file's version id applied by ``get``.
    files: list[str]
    # Dataset location prefixed with "<protocol>://", with the dataset's
    # version id applied by ``get``.
    path: str


def get(name: str) -> Union[DVCXDataset, DVCDataset, URLDataset]:
    """Return the locked information of the dataset called *name*.

    The dataset is looked up in the repository opened by ``Repo()`` (no
    arguments are passed, so locating the repository is left to ``Repo``).
    The repository is closed again before this function returns or raises.

    Args:
        name: Key of the dataset in ``repo.datasets``.

    Returns:
        A ``DVCXDataset``, ``DVCDataset`` or ``URLDataset`` dictionary,
        chosen by the type of the dataset's lock entry.

    Raises:
        datasets.DatasetNotFoundError: Re-raised from the ``repo.datasets``
            lookup. If a close match for *name* exists, a "Did you mean"
            note is attached when the exception supports ``add_note``.
        ValueError: If the dataset has no lock information, or if its lock
            entry is not one of the three supported dataset types.
    """
    from difflib import get_close_matches

    from dvc.fs import get_cloud_fs
    from dvc.repo import Repo, datasets

    # Open the repo as a context manager so that it is closed on every exit
    # path (normal return as well as any of the raises below).
    with Repo() as repo:
        try:
            dataset = repo.datasets[name]
        except datasets.DatasetNotFoundError as e:
            # ``add_note`` is only available on Python 3.11+; fall back to a
            # no-op so older interpreters simply re-raise without the hint.
            add_note = getattr(e, "add_note", lambda _: None)
            if matches := get_close_matches(name, repo.datasets):
                add_note(f"Did you mean: {matches[0]!r}?")
            raise

        lock = dataset.lock
        if not lock:
            raise ValueError("missing lock information")

        if isinstance(lock, datasets.DVCXDataset):
            return DVCXDataset(
                type="dvcx", name=lock.name_version[0], version=lock.version
            )
        if isinstance(lock, datasets.DVCDataset):
            return DVCDataset(
                type="dvc", url=dataset.url, path=lock.path, sha=lock.rev_lock
            )
        if isinstance(lock, datasets.URLDataset):
            fs_cls, _, path = get_cloud_fs(repo.config, url=dataset.url)
            # NOTE(review): presumably an internal invariant / type-narrowing
            # aid rather than input validation -- confirm against get_cloud_fs.
            assert fs_cls
            # Filesystem classes without ``join_version`` get a fallback that
            # returns the path unchanged and ignores the version id.
            join_version = getattr(fs_cls, "join_version", lambda path, _: path)
            protocol = fs_cls.protocol
            versioned_path = join_version(path, lock.meta.version_id)
            versioned_path = f"{protocol}://{versioned_path}"
            files = [
                join_version(
                    fs_cls.join(versioned_path, file.relpath), file.meta.version_id
                )
                for file in lock.files
            ]
            return URLDataset(type="url", files=files, path=versioned_path)

        # Previously the function fell off the end here and returned ``None``,
        # contradicting the declared return type; fail loudly instead.
        raise ValueError(f"unsupported dataset type: {type(lock).__name__!r}")

0 comments on commit 3309e15

Please sign in to comment.