Added tools to stream from S3 (#57)
* added compression lib

* renamed

* style

* added extra function to decompress

* added option to compress a stream

* docs
soldni authored May 21, 2023
1 parent d2b214a commit 5fd542c
Showing 14 changed files with 748 additions and 252 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "smashed"
version = "0.19.6"
version = "0.20.0"
description = """\
SMASHED is a toolkit designed to apply transformations to samples in \
datasets, such as fields extraction, tokenization, prompting, batching, \
@@ -97,6 +97,7 @@ dev = [
"ipdb>=0.13.0",
"flake8-pyi>=22.8.1",
"Flake8-pyproject>=1.1.0",
"moto[ec2,s3,all] >= 4.0.0",
]
remote = [
"smart-open>=5.2.1",
10 changes: 10 additions & 0 deletions src/smashed/utils/__init__.py
@@ -1,5 +1,11 @@
from .caching import get_cache_dir
from .convert import bytes_from_int, int_from_bytes
from .io_utils import (
MultiPath,
open_file_for_read,
open_file_for_write,
recursively_list_files,
)
from .version import get_name, get_name_and_version, get_version
from .warnings import SmashedWarnings
from .wordsplitter import BlingFireSplitter, WhitespaceSplitter
@@ -12,6 +18,10 @@
"get_name",
"get_version",
"int_from_bytes",
"MultiPath",
"open_file_for_read",
"open_file_for_write",
"recursively_list_files",
"SmashedWarnings",
"WhitespaceSplitter",
]
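
With these re-exports, the S3-aware file helpers become importable directly from `smashed.utils`. A minimal round-trip sketch, assuming the helpers mirror the builtin `open()` mode/encoding arguments and dispatch to `smart-open` for `s3://` URLs (the bucket name is hypothetical):

```python
from smashed.utils import open_file_for_read, open_file_for_write

# Hypothetical S3 location; substitute a bucket you can write to.
path = "s3://my-bucket/examples/notes.txt"

# Assumed to behave like context-manager versions of open(): write text...
with open_file_for_write(path, mode="w") as f:
    f.write("hello from smashed\n")

# ...then read it back over the same path.
with open_file_for_read(path, mode="r") as f:
    print(f.read())
```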
15 changes: 14 additions & 1 deletion src/smashed/utils/install_blingfire_macos.py
@@ -1,6 +1,9 @@
#! /usr/bin/env python3


import platform
from subprocess import call
from warnings import warn

BASH_SCRIPT = """
#! /usr/bin/env bash
@@ -39,7 +42,17 @@


def main():
call(BASH_SCRIPT.strip(), shell=True)
# check if we are on MacOS
if platform.system() != "Darwin":
warn("This script is only meant to be run on MacOS; skipping...")
return

# check that architecture is arm64
if platform.machine() != "arm64":
warn("This script is only meant to be run on arm64; skipping...")
return

return call(BASH_SCRIPT.strip(), shell=True)


if __name__ == "__main__":
32 changes: 32 additions & 0 deletions src/smashed/utils/io_utils/__init__.py
@@ -0,0 +1,32 @@
from .closures import upload_on_success
from .compression import compress_stream, decompress_stream
from .multipath import MultiPath
from .operations import (
copy_directory,
exists,
is_dir,
is_file,
open_file_for_read,
open_file_for_write,
recursively_list_files,
remove_directory,
remove_file,
stream_file_for_read,
)

__all__ = [
"compress_stream",
"copy_directory",
"decompress_stream",
"exists",
"is_dir",
"is_file",
"MultiPath",
"open_file_for_read",
"open_file_for_write",
"recursively_list_files",
"remove_directory",
"remove_file",
"stream_file_for_read",
"upload_on_success",
]
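
Beyond reading and writing single files, the new `operations` module exposes listing and streaming helpers. A hedged sketch of how `recursively_list_files` and `stream_file_for_read` might be combined to scan a prefix without downloading whole objects first (the prefix, modes, and return types here are assumptions, not confirmed signatures):

```python
from smashed.utils.io_utils import recursively_list_files, stream_file_for_read

# Hypothetical prefix; recursively_list_files is assumed to yield one
# path per object found under it.
for path in recursively_list_files("s3://my-bucket/corpus/"):
    # stream_file_for_read is assumed to yield a readable file object
    # that pulls bytes lazily rather than fetching the whole file.
    with stream_file_for_read(path, mode="rb") as stream:
        head = stream.read(1024)
        print(path, len(head))
```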
107 changes: 107 additions & 0 deletions src/smashed/utils/io_utils/closures.py
@@ -0,0 +1,107 @@
from contextlib import AbstractContextManager, ExitStack
from functools import partial
from tempfile import TemporaryDirectory
from typing import Callable, Optional, TypeVar

from typing_extensions import Concatenate, ParamSpec

from .multipath import MultiPath
from .operations import PathType, copy_directory, remove_directory

T = TypeVar("T")
P = ParamSpec("P")


class upload_on_success(AbstractContextManager):
"""Context manager to upload a directory of results to a remote
location if the execution in the context manager is successful.
You can use this class in two ways:
1. As a context manager
```python
with upload_on_success('s3://my-bucket/my-results') as path:
# run training, save temporary results in `path`
...
```
2. As a function decorator
```python
@upload_on_success('s3://my-bucket/my-results')
def my_function(path: str, ...):
# run training, save temporary results in `path`
```
You can specify a local destination by passing `local_path` to
`upload_on_success`. Otherwise, a temporary directory is created for you.
"""

def __init__(
self,
remote_path: PathType,
local_path: Optional[PathType] = None,
keep_local: bool = False,
) -> None:
"""Constructor for upload_on_success context manager
Args:
remote_path (str or urllib.parse.ParseResult): The remote location
to upload to (e.g., an S3 prefix for a bucket you have
access to).
local_path (str or Path): The local path where to temporarily
store files before upload. If not provided, a temporary
directory is created for you and returned by the context
manager. It will be deleted at the end of the context
(unless keep_local is set to True). Defaults to None.
keep_local (bool, optional): Whether to keep the local results
as well as uploading to the remote path. Only available
if `local_path` is provided.
"""

self._ctx = ExitStack()
self.remote_path = remote_path
self.local_path = MultiPath.parse(
local_path or self._ctx.enter_context(TemporaryDirectory())
)
if local_path is None and keep_local:
raise ValueError(
"Cannot keep local destination if `local_path` is None"
)
self.keep_local = keep_local

super().__init__()

def _decorated(
self,
func: Callable[Concatenate[str, P], T],
*args: P.args,
**kwargs: P.kwargs,
) -> T:
with type(self)(
local_path=self.local_path,
remote_path=self.remote_path,
keep_local=self.keep_local,
) as path:
output = func(path.as_str, *args, **kwargs)
return output

def __call__(
self, func: Callable[Concatenate[str, P], T]
) -> Callable[P, T]:
return partial(self._decorated, func=func) # type: ignore

def __enter__(self):
return self.local_path

def __exit__(self, exc_type, exc_value, traceback):
if exc_type is None:
# all went well, so we copy the local directory to the remote
copy_directory(src=self.local_path, dst=self.remote_path)

if not self.keep_local:
remove_directory(self.local_path)

self._ctx.close()
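
Putting the class together: on a clean exit the local directory is copied to `remote_path`, and on an exception nothing is uploaded. A usage sketch in both forms, with a hypothetical bucket:

```python
import os
from smashed.utils.io_utils import upload_on_success

# 1. Context-manager form: `path` is a MultiPath; results written under
#    it are uploaded only if the block exits without raising.
with upload_on_success("s3://my-bucket/run-001") as path:
    with open(os.path.join(path.as_str, "metrics.json"), "w") as f:
        f.write('{"loss": 0.1}')

# 2. Decorator form: the local scratch path is passed as the first argument.
@upload_on_success("s3://my-bucket/run-002")
def train(path: str) -> None:
    with open(os.path.join(path, "model.bin"), "wb") as f:
        f.write(b"\x00" * 16)

train()
```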
66 changes: 66 additions & 0 deletions src/smashed/utils/io_utils/compression.py
@@ -0,0 +1,66 @@
import gzip as gz
import io
from contextlib import contextmanager
from typing import IO, Iterator, Literal, Optional, cast

from .io_wrappers import BytesZLibDecompressorIO, TextZLibDecompressorIO


@contextmanager
def decompress_stream(
stream: IO,
mode: Literal["r", "rt", "rb"] = "rt",
encoding: Optional[str] = "utf-8",
errors: str = "strict",
chunk_size: int = io.DEFAULT_BUFFER_SIZE,
gzip: bool = True,
) -> Iterator[IO]:
out: io.IOBase

if mode == "rb" or mode == "r":
out = BytesZLibDecompressorIO(
stream=stream, chunk_size=chunk_size, gzip=gzip
)
elif mode == "rt":
assert encoding is not None, "encoding must be provided for text mode"
out = TextZLibDecompressorIO(
stream=stream,
chunk_size=chunk_size,
gzip=gzip,
encoding=encoding,
errors=errors,
)
else:
raise ValueError(f"Unsupported mode: {mode}")

# cast to IO to satisfy mypy, then yield
yield cast(IO, out)

# Flush and close the stream
out.close()


@contextmanager
def compress_stream(
stream: IO,
mode: Literal["w", "wt", "wb"] = "wt",
encoding: Optional[str] = "utf-8",
errors: str = "strict",
gzip: bool = True,
) -> Iterator[IO]:

assert gzip, "Only gzip compression is supported at this time"

if mode == "wb" or mode == "w":
out = gz.open(stream, mode=mode)
elif mode == "wt":
assert encoding is not None, "encoding must be provided for text mode"
out = gz.open(stream, mode=mode, encoding=encoding, errors=errors)
else:
raise ValueError(f"Unsupported mode: {mode}")

# cast to IO to satisfy mypy, then yield
yield cast(IO, out)

# Flush and close the stream
out.close()
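
A quick round trip through the two helpers, using an in-memory buffer so it runs without any remote storage (a sketch; only the gzip path is exercised, matching the assert above):

```python
import io
from smashed.utils.io_utils import compress_stream, decompress_stream

buffer = io.BytesIO()

# Write text through the gzip compressor into the buffer; the stream is
# flushed and closed when the context exits.
with compress_stream(buffer, mode="wt") as f:
    f.write("hello world\n")

# Rewind and read the same bytes back through the streaming decompressor.
buffer.seek(0)
with decompress_stream(buffer, mode="rt") as f:
    assert f.read() == "hello world\n"
```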