diff --git a/pyproject.toml b/pyproject.toml index 60a8628..5ab1381 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "smashed" -version = "0.19.6" +version = "0.20.0" description = """\ SMASHED is a toolkit designed to apply transformations to samples in \ datasets, such as fields extraction, tokenization, prompting, batching, \ @@ -97,6 +97,7 @@ dev = [ "ipdb>=0.13.0", "flake8-pyi>=22.8.1", "Flake8-pyproject>=1.1.0", + "moto[ec2,s3,all] >= 4.0.0", ] remote = [ "smart-open>=5.2.1", diff --git a/src/smashed/utils/__init__.py b/src/smashed/utils/__init__.py index ef49ea1..c0bec18 100644 --- a/src/smashed/utils/__init__.py +++ b/src/smashed/utils/__init__.py @@ -1,5 +1,11 @@ from .caching import get_cache_dir from .convert import bytes_from_int, int_from_bytes +from .io_utils import ( + MultiPath, + open_file_for_read, + open_file_for_write, + recursively_list_files, +) from .version import get_name, get_name_and_version, get_version from .warnings import SmashedWarnings from .wordsplitter import BlingFireSplitter, WhitespaceSplitter @@ -12,6 +18,10 @@ "get_name", "get_version", "int_from_bytes", + "MultiPath", + "open_file_for_read", + "open_file_for_write", + "recursively_list_files", "SmashedWarnings", "WhitespaceSplitter", ] diff --git a/src/smashed/utils/install_blingfire_macos.py b/src/smashed/utils/install_blingfire_macos.py index a986644..1dba93f 100644 --- a/src/smashed/utils/install_blingfire_macos.py +++ b/src/smashed/utils/install_blingfire_macos.py @@ -1,6 +1,9 @@ #! /usr/bin/env python3 + +import platform from subprocess import call +from warnings import warn BASH_SCRIPT = """ #! /usr/bin/env bash @@ -39,7 +42,17 @@ def main(): - call(BASH_SCRIPT.strip(), shell=True) + # check if we are on MacOS + if platform.system() != "Darwin": + warn("This script is only meant to be run on MacOS; skipping...") + return + + # check that architecture is arm64 + if platform.machine() != "arm64": + warn("This script is only meant to be run on arm64; skipping...") + return + + return call(BASH_SCRIPT.strip(), shell=True) if __name__ == "__main__": diff --git a/src/smashed/utils/io_utils/__init__.py b/src/smashed/utils/io_utils/__init__.py new file mode 100644 index 0000000..da32b07 --- /dev/null +++ b/src/smashed/utils/io_utils/__init__.py @@ -0,0 +1,32 @@ +from .closures import upload_on_success +from .compression import compress_stream, decompress_stream +from .multipath import MultiPath +from .operations import ( + copy_directory, + exists, + is_dir, + is_file, + open_file_for_read, + open_file_for_write, + recursively_list_files, + remove_directory, + remove_file, + stream_file_for_read, +) + +__all__ = [ + "compress_stream", + "copy_directory", + "decompress_stream", + "exists", + "is_dir", + "is_file", + "MultiPath", + "open_file_for_read", + "open_file_for_write", + "recursively_list_files", + "remove_directory", + "remove_file", + "stream_file_for_read", + "upload_on_success", +] diff --git a/src/smashed/utils/io_utils/closures.py b/src/smashed/utils/io_utils/closures.py new file mode 100644 index 0000000..0d4d40d --- /dev/null +++ b/src/smashed/utils/io_utils/closures.py @@ -0,0 +1,107 @@ +from contextlib import AbstractContextManager, ExitStack +from functools import partial +from tempfile import TemporaryDirectory +from typing import Callable, Optional, TypeVar + +from typing_extensions import Concatenate, ParamSpec + +from .multipath import MultiPath +from .operations import PathType, copy_directory, remove_directory + +T = TypeVar("T") +P = 
ParamSpec("P") + + +class upload_on_success(AbstractContextManager): + """Context manager to upload a directory of results to a remote + location if the execution in the context manager is successful. + + You can use this class in two ways: + + 1. As a context manager + + ```python + + with upload_on_success('s3://my-bucket/my-results') as path: + # run training, save temporary results in `path` + ... + ``` + + 2. As a function decorator + + ```python + @upload_on_success('s3://my-bucket/my-results') + def my_function(path: str, ...) + # run training, save temporary results in `path` + ``` + + You can specify a local destination by passing `local_path` to + `upload_on_success`. Otherwise, a temporary directory is created for you. + """ + + def __init__( + self, + remote_path: PathType, + local_path: Optional[PathType] = None, + keep_local: bool = False, + ) -> None: + """Constructor for upload_on_success context manager + + Args: + remote_path (str or urllib.parse.ParseResult): The remote location + to upload to (e.g., an S3 prefix for a bucket you have + access to). + local_path (str or Path): The local path where to temporarily + store files before upload. If not provided, a temporary + directory is created for you and returned by the context + manager. It will be deleted at the end of the context + (unless keep_local is set to True). Defaults to None + keep_local (bool, optional): Whether to keep the local results + as well as uploading to the remote path. Only available + if `local_path` is provided. + """ + + self._ctx = ExitStack() + self.remote_path = remote_path + self.local_path = MultiPath.parse( + local_path or self._ctx.enter_context(TemporaryDirectory()) + ) + if local_path is None and keep_local: + raise ValueError( + "Cannot keep local destination if `local_path` is None" + ) + self.keep_local = keep_local + + super().__init__() + + def _decorated( + self, + func: Callable[Concatenate[str, P], T], + *args: P.args, + **kwargs: P.kwargs, + ) -> T: + with type(self)( + local_path=self.local_path, + remote_path=self.remote_path, + keep_local=self.keep_local, + ) as path: + output = func(path.as_str, *args, **kwargs) + return output + + def __call__( + self, func: Callable[Concatenate[str, P], T] + ) -> Callable[P, T]: + return partial(self._decorated, func=func) # type: ignore + + def __enter__(self): + return self.local_path + + def __exit__(self, exc_type, exc_value, traceback): + if exc_type is None: + # all went well, so we copy the local directory to the remote + copy_directory(src=self.local_path, dst=self.remote_path) + + if not self.keep_local: + remove_directory(self.local_path) + + self._ctx.close() diff --git a/src/smashed/utils/io_utils/compression.py b/src/smashed/utils/io_utils/compression.py new file mode 100644 index 0000000..345cc34 --- /dev/null +++ b/src/smashed/utils/io_utils/compression.py @@ -0,0 +1,66 @@ +import gzip as gz +import io +from contextlib import contextmanager +from typing import IO, Iterator, Literal, Optional, cast + +from .io_wrappers import BytesZLibDecompressorIO, TextZLibDecompressorIO + + +@contextmanager +def decompress_stream( + stream: IO, + mode: Literal["r", "rt", "rb"] = "rt", + encoding: Optional[str] = "utf-8", + errors: str = "strict", + chunk_size: int = io.DEFAULT_BUFFER_SIZE, + gzip: bool = True, +) -> Iterator[IO]: + out: io.IOBase + + if mode == "rb" or mode == "r": + out = BytesZLibDecompressorIO( + stream=stream, chunk_size=chunk_size, gzip=gzip + ) + elif mode == "rt": + assert encoding is not None, "encoding must be 
provided for text mode" + out = TextZLibDecompressorIO( + stream=stream, + chunk_size=chunk_size, + gzip=gzip, + encoding=encoding, + errors=errors, + ) + else: + raise ValueError(f"Unsupported mode: {mode}") + + # cast to IO to satisfy mypy, then yield + yield cast(IO, out) + + # Flush and close the stream + out.close() + + +@contextmanager +def compress_stream( + stream: IO, + mode: Literal["w", "wt", "wb"] = "wt", + encoding: Optional[str] = "utf-8", + errors: str = "strict", + gzip: bool = True, +) -> Iterator[IO]: + + assert gzip, "Only gzip compression is supported at this time" + + if mode == "wb" or mode == "w": + out = gz.open(stream, mode=mode) + elif mode == "wt": + assert encoding is not None, "encoding must be provided for text mode" + out = gz.open(stream, mode=mode, encoding=encoding, errors=errors) + else: + raise ValueError(f"Unsupported mode: {mode}") + + # cast to IO to satisfy mypy, then yield + yield cast(IO, out) + + # Flush and close the stream + out.close() diff --git a/src/smashed/utils/io_utils/io_wrappers.py b/src/smashed/utils/io_utils/io_wrappers.py new file mode 100644 index 0000000..1ae23e5 --- /dev/null +++ b/src/smashed/utils/io_utils/io_wrappers.py @@ -0,0 +1,173 @@ +import io +import zlib +from typing import IO, Any, Generic, Iterator, Optional, TypeVar + +T = TypeVar("T", bound=Any) + + +class ReadIO(io.IOBase, Generic[T]): + def __init__( + self, + stream: IO, + chunk_size: int = io.DEFAULT_BUFFER_SIZE, + ): + self.stream = stream + self.ready_buffer = bytearray() + self.chunk_size = chunk_size + self._closed = False + + def readable(self) -> bool: + return True + + def seekable(self) -> bool: + return False + + def read(self, __size: Optional[int] = None) -> T: + raise NotImplementedError + + def readline(self, __size: Optional[int] = None) -> T: + raise NotImplementedError + + def _process_data(self, data: bytes) -> bytes: + return data + + def __next__(self) -> T: + try: + return super().__next__() # type: ignore + except StopIteration as stop: + self._closed = True + raise stop + + def __iter__(self) -> Iterator[T]: + return self + + def _readline(self, size: int = -1) -> bytes: + while size < 0 or len(self.ready_buffer) < size: + if b"\n" in self.ready_buffer: + break + + read_data = self.stream.read(self.chunk_size) + if not read_data: + break + + processed_data = self._process_data(read_data) + if read_data and not processed_data: + raise RuntimeError(f"{self.__class__.__name__} failed") + + self.ready_buffer.extend(processed_data) + + loc = self.ready_buffer.find(b"\n") + if loc >= 0: + return_value = self.ready_buffer[: loc + 1] + self.ready_buffer = self.ready_buffer[loc + 1 :] + else: + return_value = self.ready_buffer + self.ready_buffer = bytearray() + + if not (return_value or self.ready_buffer) and size != 0: + # user has requested more than 0 bytes but there is nothing + # left in the buffer to read + raise StopIteration() + + return bytes(return_value) + + def _read(self, size: int = -1) -> bytes: + while size < 0 or len(self.ready_buffer) < size: + read_data = self.stream.read(self.chunk_size) + if not read_data: + break + + processed_data = self._process_data(read_data) + if read_data and not processed_data: + raise RuntimeError(f"{self.__class__.__name__} failed") + + self.ready_buffer.extend(processed_data) + + # If size equals -1, return all available data + if size < 0: + return_value = self.ready_buffer + self.ready_buffer = bytearray() + else: + return_value = self.ready_buffer[:size] + self.ready_buffer = 
self.ready_buffer[size:] + + if not (return_value or self.ready_buffer) and size != 0: + # user has requested more than 0 bytes but there is nothing + # left in the buffer to read + raise StopIteration() + + return bytes(return_value) + + +class ReadBytesIO(ReadIO[bytes], io.RawIOBase): + def read(self, __size: Optional[int] = None) -> bytes: + return self._read(__size or -1) + + def readline(self, __size: Optional[int] = None) -> bytes: + return self._readline(__size or -1) + + +class ReadTextIO(ReadIO[str], io.TextIOBase): + def __init__( + self, + stream: IO, + chunk_size: int = io.DEFAULT_BUFFER_SIZE, + encoding: str = "utf-8", + errors: str = "strict", + ): + super().__init__(stream=stream, chunk_size=chunk_size) + self._encoding = encoding + self._errors = errors + + def read(self, __size: Optional[int] = None) -> str: + out = self._read(__size or -1) + return out.decode(encoding=self._encoding, errors=self._errors) + + def readline(self, __size: Optional[int] = None) -> str: # type: ignore + out = self._readline(__size or -1) + return out.decode(encoding=self._encoding, errors=self._errors) + + +class BaseZlibDecompressorIO(ReadIO[T], Generic[T]): + def __init__( + self, + stream: IO, + chunk_size: int = io.DEFAULT_BUFFER_SIZE, + gzip: bool = True, + ): + gzip_offset = 16 if gzip else 0 + self.decoder = zlib.decompressobj(gzip_offset + zlib.MAX_WBITS) + super().__init__(stream=stream, chunk_size=chunk_size) + + def _process_data(self, data: bytes) -> bytes: + return self.decoder.decompress(data) + + +class BytesZLibDecompressorIO(BaseZlibDecompressorIO[bytes], ReadBytesIO): + """Wraps a zlib decompressor so that it can be used as a file-like + object. Returns bytes.""" + + ... + + +class TextZLibDecompressorIO(BaseZlibDecompressorIO[str], ReadTextIO): + def __init__( + self, + stream: IO, + chunk_size: int = io.DEFAULT_BUFFER_SIZE, + gzip: bool = True, + encoding: str = "utf-8", + errors: str = "strict", + ): + BaseZlibDecompressorIO.__init__( + self, stream=stream, chunk_size=chunk_size, gzip=gzip + ) + ReadTextIO.__init__( + self, + stream=stream, + chunk_size=chunk_size, + encoding=encoding, + errors=errors, + ) + + ... diff --git a/src/smashed/utils/io_utils/multipath.py b/src/smashed/utils/io_utils/multipath.py new file mode 100644 index 0000000..6487a20 --- /dev/null +++ b/src/smashed/utils/io_utils/multipath.py @@ -0,0 +1,148 @@ +import re +from dataclasses import dataclass +from logging import getLogger +from pathlib import Path +from typing import TYPE_CHECKING, Any, Union +from urllib.parse import urlparse + +from necessary import necessary + +with necessary("boto3", soft=True) as BOTO_AVAILABLE: + if TYPE_CHECKING or BOTO_AVAILABLE: + from botocore.client import BaseClient + + +PathType = Union[str, Path, "MultiPath"] +ClientType = Union["BaseClient", None] + +LOGGER = getLogger(__file__) + + +@dataclass +class MultiPath: + """A path object that can handle both local and remote paths.""" + + prot: str + root: str + path: str + + def __post_init__(self): + SUPPORTED_PROTOCOLS = {"s3", "file"} + if self.prot and self.prot not in SUPPORTED_PROTOCOLS: + raise ValueError( + f"Unsupported protocol: {self.prot}; " + f"supported protocols are {SUPPORTED_PROTOCOLS}" + ) + + @classmethod + def parse(cls, path: PathType) -> "MultiPath": + """Parse a path into a PathParser object. + + Args: + path (str): The path to parse. 
+ """ + if isinstance(path, cls): + return path + elif isinstance(path, Path): + path = str(path) + elif not isinstance(path, str): + raise ValueError(f"Cannot parse path of type {type(path)}") + + p = urlparse(str(path)) + return cls(prot=p.scheme, root=p.netloc, path=p.path) + + @property + def is_s3(self) -> bool: + """Is true if the path is an S3 path.""" + return self.prot == "s3" + + @property + def is_local(self) -> bool: + """Is true if the path is a local path.""" + return self.prot == "file" or self.prot == "" + + def _remove_extra_slashes(self, path: str) -> str: + return re.sub(r"//+", "/", path) + + def __str__(self) -> str: + if self.prot: + loc = self._remove_extra_slashes(f"{self.root}/{self.path}") + return f"{self.prot}://{loc}" + elif self.root: + return self._remove_extra_slashes(f"/{self.root}/{self.path}") + else: + return self._remove_extra_slashes(self.path) + + @property + def bucket(self) -> str: + """If the path is an S3 path, return the bucket name. + Otherwise, raise a ValueError.""" + if not self.is_s3: + raise ValueError(f"Not an S3 path: {self}") + return self.root + + @property + def key(self) -> str: + """If the path is an S3 path, return the prefix. + Otherwise, raise a ValueError.""" + if not self.is_s3: + raise ValueError(f"Not an S3 path: {self}") + return self.path.lstrip("/") + + @property + def as_path(self) -> Path: + """Return the path as a pathlib.Path object.""" + if not self.is_local: + raise ValueError(f"Not a local path: {self}") + return Path(self.as_str) + + def __hash__(self) -> int: + return hash(self.as_str) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, (MultiPath, str, Path)): + return False + + other = MultiPath.parse(other) + return self.as_str == other.as_str + + @property + def as_str(self) -> str: + """Return the path as a string.""" + return str(self) + + def __truediv__(self, other: PathType) -> "MultiPath": + """Join two paths together using the / operator.""" + other = MultiPath.parse(other) + + if isinstance(other, MultiPath) and other.prot: + raise ValueError(f"Cannot combine fully formed path {other}") + + return MultiPath( + prot=self.prot, + root=self.root, + path=f"{self.path.rstrip('/')}/{str(other).lstrip('/')}", + ) + + def __len__(self) -> int: + return len(self.as_str) + + def __sub__(self, other: PathType) -> "MultiPath": + _o_str = MultiPath.parse(other).as_str + _s_str = self.as_str + loc = _s_str.find(_o_str) + return MultiPath.parse(_s_str[:loc] + _s_str[loc + len(_o_str) :]) + + @classmethod + def join(cls, *others: PathType) -> "MultiPath": + """Join multiple paths together; each path can be a string, + pathlib.Path, or MultiPath object.""" + if not others: + raise ValueError("No paths provided") + + first, *rest = others + first = cls.parse(first) + for part in rest: + # explicitly call __div__ to avoid mypy errors + first = first / part + return first diff --git a/src/smashed/utils/io_utils.py b/src/smashed/utils/io_utils/operations.py similarity index 64% rename from src/smashed/utils/io_utils.py rename to src/smashed/utils/io_utils/operations.py index 7d17474..35d09b0 100644 --- a/src/smashed/utils/io_utils.py +++ b/src/smashed/utils/io_utils/operations.py @@ -1,14 +1,12 @@ -import re +import io import shutil -from contextlib import AbstractContextManager, ExitStack, contextmanager -from dataclasses import dataclass -from functools import partial +from contextlib import ExitStack, contextmanager from logging import Logger, getLogger from os import remove as remove_local_file from os 
import stat as stat_local_file from os import walk as local_walk from pathlib import Path -from tempfile import NamedTemporaryFile, TemporaryDirectory, gettempdir +from tempfile import NamedTemporaryFile, gettempdir from typing import ( IO, TYPE_CHECKING, @@ -18,13 +16,14 @@ Generator, Iterable, Optional, - TypeVar, Union, + cast, ) -from urllib.parse import urlparse from necessary import necessary -from typing_extensions import Concatenate, ParamSpec + +from .io_wrappers import ReadBytesIO, ReadTextIO +from .multipath import MultiPath with necessary("boto3", soft=True) as BOTO_AVAILABLE: if TYPE_CHECKING or BOTO_AVAILABLE: @@ -32,155 +31,12 @@ from botocore.client import BaseClient -__all__ = [ - "copy_directory", - "exists", - "is_dir", - "is_file", - "open_file_for_read", - "open_file_for_write", - "recursively_list_files", - "remove_directory", - "remove_file", - "upload_on_success", -] - PathType = Union[str, Path, "MultiPath"] ClientType = Union["BaseClient", None] LOGGER = getLogger(__file__) -@dataclass -class MultiPath: - """A path object that can handle both local and remote paths.""" - - prot: str - root: str - path: str - - def __post_init__(self): - SUPPORTED_PROTOCOLS = {"s3", "file"} - if self.prot and self.prot not in SUPPORTED_PROTOCOLS: - raise ValueError( - f"Unsupported protocol: {self.prot}; " - f"supported protocols are {SUPPORTED_PROTOCOLS}" - ) - - @classmethod - def parse(cls, path: PathType) -> "MultiPath": - """Parse a path into a PathParser object. - - Args: - path (str): The path to parse. - """ - if isinstance(path, cls): - return path - elif isinstance(path, Path): - path = str(path) - elif not isinstance(path, str): - raise ValueError(f"Cannot parse path of type {type(path)}") - - p = urlparse(str(path)) - return cls(prot=p.scheme, root=p.netloc, path=p.path) - - @property - def is_s3(self) -> bool: - """Is true if the path is an S3 path.""" - return self.prot == "s3" - - @property - def is_local(self) -> bool: - """Is true if the path is a local path.""" - return self.prot == "file" or self.prot == "" - - def _remove_extra_slashes(self, path: str) -> str: - return re.sub(r"//+", "/", path) - - def __str__(self) -> str: - if self.prot: - loc = self._remove_extra_slashes(f"{self.root}/{self.path}") - return f"{self.prot}://{loc}" - elif self.root: - return self._remove_extra_slashes(f"/{self.root}/{self.path}") - else: - return self._remove_extra_slashes(self.path) - - @property - def bucket(self) -> str: - """If the path is an S3 path, return the bucket name. - Otherwise, raise a ValueError.""" - if not self.is_s3: - raise ValueError(f"Not an S3 path: {self}") - return self.root - - @property - def key(self) -> str: - """If the path is an S3 path, return the prefix. 
- Otherwise, raise a ValueError.""" - if not self.is_s3: - raise ValueError(f"Not an S3 path: {self}") - return self.path.lstrip("/") - - @property - def as_path(self) -> Path: - """Return the path as a pathlib.Path object.""" - if not self.is_local: - raise ValueError(f"Not a local path: {self}") - return Path(self.as_str) - - def __hash__(self) -> int: - return hash(self.as_str) - - def __eq__(self, other: Any) -> bool: - if not isinstance(other, (MultiPath, str, Path)): - return False - - other = MultiPath.parse(other) - return self.as_str == other.as_str - - @property - def as_str(self) -> str: - """Return the path as a string.""" - return str(self) - - def __truediv__(self, other: PathType) -> "MultiPath": - """Join two paths together using the / operator.""" - other = MultiPath.parse(other) - - if isinstance(other, MultiPath) and other.prot: - raise ValueError(f"Cannot combine fully formed path {other}") - - return MultiPath( - prot=self.prot, - root=self.root, - path=f"{self.path.rstrip('/')}/{str(other).lstrip('/')}", - ) - - def __len__(self) -> int: - return len(self.as_str) - - def __sub__(self, other: PathType) -> "MultiPath": - _o_str = MultiPath.parse(other).as_str - _s_str = self.as_str - loc = _s_str.find(_o_str) - return MultiPath.parse(_s_str[:loc] + _s_str[loc + len(_o_str) :]) - - @classmethod - def join(cls, *others: PathType) -> "MultiPath": - """Join multiple paths together; each path can be a string, - pathlib.Path, or MultiPath object.""" - if not others: - raise ValueError("No paths provided") - - first, *rest = others - first = cls.parse(first) - for part in rest: - # explicitly call __div__ to avoid mypy errors - first = first / part - return first - - def get_client_if_needed(path: PathType, **boto3_kwargs: Any) -> ClientType: """Return the appropriate client given the protocol of the path.""" @@ -219,6 +75,60 @@ def get_temp_dir(path: Optional[PathType]) -> Path: return path +@contextmanager +def stream_file_for_read( + path: PathType, + mode: str = "r", + open_fn: Optional[Callable] = None, + logger: Optional[Logger] = None, + open_kwargs: Optional[Dict[str, Any]] = None, + client: Optional[ClientType] = None, +) -> Generator[IO, None, None]: + """Just like open_file_for_read, but returns a file-like object that + streams content from remote files instead of saving it locally first. + + Args: + path (Union[str, Path, MultiPath]): The path to the file to read. + mode (str, optional): The mode to open the file in. Defaults to "r". + open_fn (Callable, optional): The function to use to open the file. + Defaults to the built-in open function. + logger (Logger, optional): The logger to use. Defaults to the built-in + logger at INFO level. + open_kwargs (Dict[str, Any], optional): Any additional keyword to pass + to the open function. Defaults to None. + client (ClientType, optional): The client to use to download the file. + If not provided, one will be created using the default boto3 + if necessary. Defaults to None. 
+ """ + + open_kwargs = open_kwargs or {} + logger = logger or LOGGER + open_fn = open_fn or open + + assert "r" in mode, "Only read mode is supported" + + path = MultiPath.parse(path) + + if path.is_s3: + client = client or get_client_if_needed(path) + assert client is not None, "Could not get S3 client" + + obj = client.get_object(Bucket=path.bucket, Key=path.key.lstrip("/")) + + stream: io.IOBase + if "b" in mode: + stream = ReadBytesIO(obj["Body"]) + else: + stream = ReadTextIO(obj["Body"]) + + yield cast(IO, stream) + elif path.is_local: + with open_fn(file=path.as_str, mode=mode, **open_kwargs) as f: + yield f + else: + raise ValueError(f"Unsupported protocol: {path.prot}") + + @contextmanager def open_file_for_read( path: PathType, @@ -357,6 +267,8 @@ def open_file_for_write( path (Union[str, Path, MultiPath]): The path to the file to write. mode (str, optional): The mode to open the file in. Defaults to "w". Only read modes are supported (e.g. 'wb', 'w', ...). + skip_if_empty (bool, optional): If True, the file will not be + written if the content is empty. Defaults to False. open_fn (Callable, optional): The function to use to open the file. Defaults to the built-in open function. logger (Logger, optional): The logger to use. Defaults to the built-in @@ -562,102 +474,3 @@ def remove_directory(path: PathType, client: Optional[ClientType] = None): if path.is_local: shutil.rmtree(path.as_str, ignore_errors=True) - - -T = TypeVar("T") -P = ParamSpec("P") - - -class upload_on_success(AbstractContextManager): - """Context manager to upload a directory of results to a remote - location if the execution in the context manager is successful. - - You can use this class in two ways: - - 1. As a context manager - - ```python - - with upload_on_success('s3://my-bucket/my-results') as path: - # run training, save temporary results in `path` - ... - ``` - - 2. As a function decorator - - ```python - @upload_on_success('s3://my-bucket/my-results') - def my_function(path: str, ...) - # run training, save temporary results in `path` - ``` - - You can specify a local destination by passing `local_path` to - `upload_on_success`. Otherwise, a temporary directory is created for you. - """ - - def __init__( - self, - remote_path: PathType, - local_path: Optional[PathType] = None, - keep_local: bool = False, - ) -> None: - """Constructor for upload_on_success context manager - - Args: - remote_path (str or urllib.parse.ParseResult): The remote location - to upload to (e.g., an S3 prefix for a bucket you have - access to). - local_path (str or Path): The local path where to temporarily - store files before upload. If not provided, a temporary - directory is created for you and returned by the context - manager. It will be deleted at the end of the context - (unless keep_local is set to True). Defaults to None - keep_local (bool, optional): Whether to keep the local results - as well as uploading to the remote path. Only available - if `local_path` is provided. 
- """ - - self._ctx = ExitStack() - self.remote_path = remote_path - self.local_path = MultiPath.parse( - local_path or self._ctx.enter_context(TemporaryDirectory()) - ) - if local_path is None and keep_local: - raise ValueError( - "Cannot keep local destination if `local_path` is None" - ) - self.keep_local = keep_local - - super().__init__() - - def _decorated( - self, - func: Callable[Concatenate[str, P], T], - *args: P.args, - **kwargs: P.kwargs, - ) -> T: - with type(self)( - local_path=self.local_path, - remote_path=self.remote_path, - keep_local=self.keep_local, - ) as path: - output = func(path.as_str, *args, **kwargs) - return output - - def __call__( - self, func: Callable[Concatenate[str, P], T] - ) -> Callable[P, T]: - return partial(self._decorated, func=func) # type: ignore - - def __enter__(self): - return self.local_path - - def __exit__(self, exc_type, exc_value, traceback): - if exc_type is None: - # all went well, so we copy the local directory to the remote - copy_directory(src=self.local_path, dst=self.remote_path) - - if not self.keep_local: - remove_directory(self.local_path) - - self._ctx.close() diff --git a/tests/__init__.py b/tests/__init__.py index 6e94092..819122f 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,3 +1,5 @@ +import os + from necessary import necessary with necessary("datasets", soft=True): @@ -5,3 +7,10 @@ # disable huggingface progress bar when running tests disable_progress_bar() + + +os.environ["AWS_ACCESS_KEY_ID"] = "testing" +os.environ["AWS_SECRET_ACCESS_KEY"] = "testing" +os.environ["AWS_SECURITY_TOKEN"] = "testing" +os.environ["AWS_SESSION_TOKEN"] = "testing" +os.environ["AWS_DEFAULT_REGION"] = "us-east-1" diff --git a/tests/fixtures/compressed_jsonl/arxiv.gz b/tests/fixtures/compressed_jsonl/arxiv.gz new file mode 100644 index 0000000..29cdb55 Binary files /dev/null and b/tests/fixtures/compressed_jsonl/arxiv.gz differ diff --git a/tests/fixtures/compressed_jsonl/c4-train.gz b/tests/fixtures/compressed_jsonl/c4-train.gz new file mode 100644 index 0000000..53aac2e Binary files /dev/null and b/tests/fixtures/compressed_jsonl/c4-train.gz differ diff --git a/tests/test_decompression.py b/tests/test_decompression.py new file mode 100644 index 0000000..0b31737 --- /dev/null +++ b/tests/test_decompression.py @@ -0,0 +1,59 @@ +import io +import json +import unittest +from pathlib import Path + +from smashed.utils.io_utils import compress_stream, decompress_stream + +FIXTURES_PATH = Path(__file__).parent / "fixtures" + + +class TestDeCompression(unittest.TestCase): + def setUp(self) -> None: + self.arxiv_path = FIXTURES_PATH / "compressed_jsonl" / "arxiv.gz" + self.c4_train_path = FIXTURES_PATH / "compressed_jsonl" / "c4-train.gz" + + def test_bytes_compression(self): + cnt = 0 + with open(self.arxiv_path, "rb") as f: + with decompress_stream(f, "rb", gzip=True) as g: + for ln in g: + json.loads(ln) + cnt += 1 + self.assertEqual(cnt, 9) + + cnt = 0 + with open(self.c4_train_path, "rb") as f: + with decompress_stream(f, "rb", gzip=True) as g: + for ln in g: + json.loads(ln) + cnt += 1 + self.assertEqual(cnt, 185) + + def test_text_compression(self): + cnt = 0 + with open(self.arxiv_path, "rb") as f: + with decompress_stream(f, gzip=True) as g: + for ln in g: + json.loads(ln) + cnt += 1 + self.assertEqual(cnt, 9) + + cnt = 0 + with open(self.c4_train_path, "rb") as f: + with decompress_stream(f, "rt", gzip=True) as g: + for ln in g: + json.loads(ln) + cnt += 1 + self.assertEqual(cnt, 185) + + def test_compression(self): + text = "This 
is a test\nWith multiple lines\nBye!" + stream = io.BytesIO() + + with compress_stream(stream, "wt", gzip=True) as f: + f.write(text) + + stream.seek(0) + with decompress_stream(stream, "rt", gzip=True) as g: + self.assertEqual(g.read(), text) diff --git a/tests/test_s3.py b/tests/test_s3.py new file mode 100644 index 0000000..2f38b5e --- /dev/null +++ b/tests/test_s3.py @@ -0,0 +1,65 @@ +import unittest +from logging import getLogger + +import boto3 +import moto + +from smashed.utils.io_utils import ( + open_file_for_read, + open_file_for_write, + stream_file_for_read, +) + + +class TestIo(unittest.TestCase): + mock_s3 = moto.mock_s3() + BUCKET_NAME = "mytestbucket" + FILE_KEY = "test.jsonl" + CONTENT = "This is a test\nWith multiple lines\nBye!" + REGION = "us-east-1" + + def setUp(self): + self.mock_s3.start() + self.conn = boto3.resource("s3", region_name=self.REGION) + self.conn.create_bucket(Bucket=self.BUCKET_NAME) + self.client = boto3.client("s3", region_name=self.REGION) + getLogger("botocore").setLevel("INFO") + + def tearDown(self): + self.mock_s3.stop() + + @property + def PREFIX(self): + return f"s3://{self.BUCKET_NAME}/{self.FILE_KEY}" + + def _write_file(self): + self.client.put_object( + Bucket=self.BUCKET_NAME, Key=self.FILE_KEY, Body=self.CONTENT + ) + + def _read_file(self): + r = self.client.get_object(Bucket=self.BUCKET_NAME, Key=self.FILE_KEY) + return r["Body"].read().decode("utf-8") + + def test_read_from_s3(self): + self._write_file() + with open_file_for_read(self.PREFIX) as f: + self.assertEqual(f.read(), self.CONTENT) + + def test_write_to_s3(self): + with open_file_for_write(self.PREFIX) as f: + f.write(self.CONTENT) + + content = self._read_file() + self.assertEqual(content, self.CONTENT) + + def test_stream_from_s3(self): + self._write_file() + with stream_file_for_read(self.PREFIX) as f: + self.assertEqual(f.read(), self.CONTENT) + + def test_stream_lines_from_s3(self): + self._write_file() + with stream_file_for_read(self.PREFIX) as f: + for la, lb in zip(f, self.CONTENT.split("\n")): + self.assertEqual(la.strip(), lb)
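Usage sketch: how the new path and file helpers in `smashed.utils.io_utils` introduced by this diff fit together. The bucket name, prefix, and payload below are hypothetical, and S3 credentials (or a `moto` mock, as in `tests/test_s3.py`) are assumed to be available.

```python
from smashed.utils.io_utils import (
    MultiPath,
    open_file_for_read,
    open_file_for_write,
    stream_file_for_read,
)

# Parse a remote prefix and join onto it with the / operator.
prefix = MultiPath.parse("s3://my-bucket/experiments")  # hypothetical bucket
target = prefix / "results.jsonl"
print(target.bucket)  # "my-bucket"
print(target.key)     # "experiments/results.jsonl"

# Write, then read back; the same calls work for local paths.
with open_file_for_write(target, mode="w") as f:
    f.write('{"accuracy": 0.9}\n')

with open_file_for_read(target, mode="r") as f:
    print(f.read())

# stream_file_for_read wraps the S3 response body directly instead of
# downloading the object to a local temporary file first.
with stream_file_for_read(target, mode="r") as f:
    for line in f:
        print(line.strip())
```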
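A second sketch covering the gzip stream wrappers and the `upload_on_success` closure. The remote prefix is again hypothetical; `upload_on_success` copies the local directory to it only if the block exits without raising.

```python
import io

from smashed.utils.io_utils import (
    compress_stream,
    decompress_stream,
    upload_on_success,
)

# Round-trip a gzip-compressed in-memory stream, mirroring
# tests/test_decompression.py.
buffer = io.BytesIO()
with compress_stream(buffer, "wt", gzip=True) as f:
    f.write("line 1\nline 2\n")

buffer.seek(0)
with decompress_stream(buffer, "rt", gzip=True) as g:
    for line in g:
        print(line.strip())

# Save results under a temporary local directory; they are uploaded to
# the remote prefix only when the block completes successfully.
with upload_on_success("s3://my-bucket/my-results") as path:
    with open((path / "metrics.json").as_str, "w") as f:
        f.write("{}")
```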