Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement a FSspec based ArtifactLoader #100

Merged
merged 8 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ dev = [
"psutil",
"pyarrow",
"pytest>=7.4.0",
"fsspec"
]
docs = [
"black",
Expand Down Expand Up @@ -85,7 +86,7 @@ strict_optional = false
warn_unreachable = true

[[tool.mypy.overrides]]
module = ["tabulate", "yaml"]
module = ["tabulate", "yaml", "fsspec"]
ignore_missing_imports = true

[tool.ruff]
Expand Down
90 changes: 82 additions & 8 deletions src/nnbench/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,31 @@
import copy
import inspect
import os
import shutil
import weakref
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Callable, Generic, Iterable, Iterator, Literal, TypeVar
from pathlib import Path
from tempfile import mkdtemp
from typing import (
Any,
Callable,
Generic,
Iterable,
Iterator,
Literal,
TypeVar,
)

from nnbench.context import Context

try:
import fsspec

HAS_FSSPEC = True
except ImportError:
HAS_FSSPEC = False

T = TypeVar("T")
Variable = tuple[str, type, Any]

Expand Down Expand Up @@ -110,22 +129,77 @@ class LocalArtifactLoader(ArtifactLoader):
Parameters
----------
path : str | os.PathLike[str]
The file system pathto the artifact.
The file system path to the artifact.
"""

def __init__(self, path: str | os.PathLike[str]):
def __init__(self, path: str | os.PathLike[str]) -> None:
self._path = path

def load(self):
def load(self) -> Path:
"""
Returns the path to the artifact on the local file system.
"""
return self._path
return Path(self._path).resolve()


class S3ArtifactLoader(ArtifactLoader):
# TODO: Implement this and other common ArtifactLoaders here or in a util
pass
class FilePathArtifactLoader(ArtifactLoader):
    """
    ArtifactLoader for loading artifacts using fsspec-supported file systems.

    This allows for loading from various file systems like local, S3, GCS, etc.,
    by using a unified API provided by fsspec.

    Parameters
    ----------
    path : str | os.PathLike[str]
        The path to the artifact, which can include a protocol specifier
        (like 's3://') for remote access.
    destination : str | os.PathLike[str] | None
        The local directory to which the artifact will be downloaded.
        If provided, the downloaded data is persisted. Otherwise, a temporary
        directory is created, and removed again once the loader is garbage
        collected (or at interpreter exit).
    storage_options : dict[str, Any] | None
        Extra keyword arguments forwarded to the fsspec file system
        constructor (e.g. credentials or endpoint URLs for remote storage).
    """

    def __init__(
        self,
        path: str | os.PathLike[str],
        destination: str | os.PathLike[str] | None = None,
        storage_options: dict[str, Any] | None = None,
    ) -> None:
        self.source_path = str(path)
        self.storage_options = storage_options or {}
        if destination:
            self.target_path = str(Path(destination).resolve())
            delete = False
        else:
            self.target_path = str(Path(mkdtemp()).resolve())
            delete = True
        # The finalizer callback must not hold a reference to `self` (a bound
        # method would), otherwise the instance can never be garbage collected
        # and the temporary directory would only be removed at interpreter exit.
        self._finalizer = weakref.finalize(
            self, FilePathArtifactLoader._remove_target, self.target_path, delete
        )

    def load(self) -> Path:
        """
        Loads the artifact and returns the local path.

        Returns
        -------
        Path
            The path to the artifact on the local filesystem.

        Raises
        ------
        ImportError
            When fsspec is not installed.
        """
        if not HAS_FSSPEC:
            raise ImportError(
                f"class {self.__class__.__name__} requires `fsspec` to be installed. "
                "You can install it by running `python -m pip install --upgrade fsspec`"
            )
        protocol = fsspec.utils.get_protocol(self.source_path)
        # Forward the user-supplied options so that credentials, endpoints etc.
        # actually reach the file system implementation.
        fs = fsspec.filesystem(protocol, **self.storage_options)
        fs.get(self.source_path, self.target_path, recursive=True)
        return Path(self.target_path).resolve()

    @staticmethod
    def _remove_target(target_path: str, delete: bool) -> None:
        """Finalizer callback: remove ``target_path`` if ``delete`` is set."""
        if delete:
            # Best-effort: the directory may already have been removed.
            shutil.rmtree(target_path, ignore_errors=True)

    def _cleanup(self, delete: bool) -> None:
        """Remove the local target directory if ``delete`` is true."""
        if delete:
            shutil.rmtree(self.target_path)


class Artifact(Generic[T], metaclass=ABCMeta):
Expand Down
7 changes: 7 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,10 @@
def testfolder() -> str:
    """Return the path of the directory used for benchmark collection tests."""
    benchmark_dir = HERE / "benchmarks"
    return str(benchmark_dir)


@pytest.fixture
def local_file(tmp_path: Path) -> Path:
    """Create a small text file inside ``tmp_path`` and return its path."""
    target = tmp_path.joinpath("test_file.txt")
    target.write_text("Test content")
    return target
11 changes: 11 additions & 0 deletions tests/test_artifacts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from pathlib import Path

from nnbench.types import FilePathArtifactLoader


def test_load_local_file(local_file: Path, tmp_path: Path) -> None:
    """A local file fetched through the loader is readable at the returned path."""
    destination = tmp_path.joinpath("test_load_dir")
    artifact_loader = FilePathArtifactLoader(local_file, destination)
    fetched: Path = artifact_loader.load()
    assert fetched.exists()
    assert fetched.read_text() == "Test content"