diff --git a/examples/squad.py b/examples/squad.py index 44e3ad9..c976c9a 100644 --- a/examples/squad.py +++ b/examples/squad.py @@ -14,7 +14,7 @@ ) >> sm.TextToWordsMapper( fields=["question", "context", "answers"], - splitter="whitespace", + splitter="ws", ) >> sm.SingleSequenceStriderMapper( field_to_stride=["context"], diff --git a/pyproject.toml b/pyproject.toml index 23300b2..ae2e9df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "smashed" -version = "0.16.0" +version = "0.17.0" description = """\ SMASHED is a toolkit designed to apply transformations to samples in \ datasets, such as fields extraction, tokenization, prompting, batching, \ diff --git a/src/smashed/mappers/text.py b/src/smashed/mappers/text.py index d5de999..5aab433 100644 --- a/src/smashed/mappers/text.py +++ b/src/smashed/mappers/text.py @@ -9,6 +9,7 @@ BlingFireSplitter, WhitespacePlusSplitter, WhitespaceSplitter, + WhitespaceTrailSplitter, ) @@ -69,15 +70,15 @@ class TextToWordsMapper(SingleBaseMapper): def __init__( self, fields: Union[str, Sequence[str]], - splitter: Literal[ - "blingfire", "whitespace", "whitespace_plus" - ] = "whitespace_plus", + splitter: Literal["blingfire", "ws", "plus", "trail"] = "plus", ): if splitter == "blingfire": self.splitter = BlingFireSplitter() - elif splitter == "whitespace_plus": + elif splitter == "plus": self.splitter = WhitespacePlusSplitter() - elif splitter == "whitespace": + elif splitter == "trail": + self.splitter = WhitespaceTrailSplitter() + elif splitter == "ws": self.splitter = WhitespaceSplitter() else: raise ValueError(f"Unknown splitter: {splitter}") diff --git a/src/smashed/utils/io_utils.py b/src/smashed/utils/io_utils.py index c6b108b..85e4f4c 100644 --- a/src/smashed/utils/io_utils.py +++ b/src/smashed/utils/io_utils.py @@ -1,5 +1,7 @@ +import re import shutil from contextlib import AbstractContextManager, ExitStack, contextmanager +from dataclasses import dataclass from functools import partial from logging import INFO, Logger, getLogger from os import remove as remove_local_file @@ -19,7 +21,7 @@ TypeVar, Union, ) -from urllib.parse import ParseResult, urlparse +from urllib.parse import urlparse from necessary import necessary from typing_extensions import Concatenate, ParamSpec @@ -40,27 +42,134 @@ "upload_on_success", ] -PathType = Union[str, Path, ParseResult] +PathType = Union[str, Path, "MultiPath"] ClientType = Union["BaseClient", None] -def uri_stringify(uri: PathType) -> str: - """Convert a URI to a string.""" - if isinstance(uri, str): - return uri +@dataclass +class MultiPath: + """A path object that can handle both local and remote paths.""" - if isinstance(uri, Path): - return str(uri) + prot: str + root: str + path: str - if isinstance(uri, ParseResult): - return uri.geturl() + def __post_init__(self): + SUPPORTED_PROTOCOLS = {"s3", "file"} + if self.prot and self.prot not in SUPPORTED_PROTOCOLS: + raise ValueError( + f"Unsupported protocol: {self.prot}; " + f"supported protocols are {SUPPORTED_PROTOCOLS}" + ) + @classmethod + def parse(cls, path: PathType) -> "MultiPath": + """Parse a path into a PathParser object. -def join_uri(*uris: PathType) -> str: - """Join a URI.""" - first, *rest, last = map(uri_stringify, uris) - rest = [part.strip("/") for part in rest] - return "/".join([first.rstrip("/"), *rest, last.lstrip("/")]) + Args: + path (str): The path to parse. + """ + if isinstance(path, cls): + return path + + p = urlparse(str(path)) + return cls(prot=p.scheme, root=p.netloc, path=p.path) + + @property + def is_s3(self) -> bool: + """Is true if the path is an S3 path.""" + return self.prot == "s3" + + @property + def is_local(self) -> bool: + """Is true if the path is a local path.""" + return self.prot == "file" or self.prot == "" + + def _remove_extra_slashes(self, path: str) -> str: + return re.sub(r"//+", "/", path) + + def __str__(self) -> str: + if self.prot: + loc = self._remove_extra_slashes(f"{self.root}/{self.path}") + return f"{self.prot}://{loc}" + elif self.root: + return self._remove_extra_slashes(f"/{self.root}/{self.path}") + else: + return self._remove_extra_slashes(self.path) + + @property + def bucket(self) -> str: + """If the path is an S3 path, return the bucket name. + Otherwise, raise a ValueError.""" + if not self.is_s3: + raise ValueError(f"Not an S3 path: {self}") + return self.root + + @property + def key(self) -> str: + """If the path is an S3 path, return the prefix. + Otherwise, raise a ValueError.""" + if not self.is_s3: + raise ValueError(f"Not an S3 path: {self}") + return self.path.lstrip("/") + + @property + def as_path(self) -> Path: + """Return the path as a pathlib.Path object.""" + if not self.is_local: + raise ValueError(f"Not a local path: {self}") + return Path(self.as_str) + + def __hash__(self) -> int: + return hash(self.as_str) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, (MultiPath, str, Path)): + return False + + other = MultiPath.parse(other) + return self.as_str == other.as_str + + @property + def as_str(self) -> str: + """Return the path as a string.""" + return str(self) + + def __truediv__(self, other: PathType) -> "MultiPath": + """Join two paths together using the / operator.""" + other = MultiPath.parse(other) + + if isinstance(other, MultiPath) and other.prot: + raise ValueError(f"Cannot combine fully formed path {other}") + + return MultiPath( + prot=self.prot, + root=self.root, + path=f"{self.path.rstrip('/')}/{str(other).lstrip('/')}", + ) + + def __len__(self) -> int: + return len(self.as_str) + + def __sub__(self, other: PathType) -> "MultiPath": + _o_str = MultiPath.parse(other).as_str + _s_str = self.as_str + loc = _s_str.find(_o_str) + return MultiPath.parse(_s_str[:loc] + _s_str[loc + len(_o_str) :]) + + @classmethod + def join(cls, *others: PathType) -> "MultiPath": + """Join multiple paths together; each path can be a string, + pathlib.Path, or MultiPath object.""" + if not others: + raise ValueError("No paths provided") + + first, *rest = others + first = cls.parse(first) + for part in rest: + # explicitly call __div__ to avoid mypy errors + first = first / part + return first def get_logger() -> Logger: @@ -70,13 +179,11 @@ def get_logger() -> Logger: def get_client_if_needed(path: PathType) -> ClientType: - parse = ( - urlparse(uri_stringify(path)) - if not isinstance(path, ParseResult) - else path - ) + """Return the appropriate client given the protocol of the path.""" + + path = MultiPath.parse(path) - if parse.scheme == "s3": + if path.is_s3: # necessary here will raise an error if boto3 is not installed. with necessary( "boto3", @@ -86,27 +193,26 @@ def get_client_if_needed(path: PathType) -> ClientType: ), ): return boto3.client("s3") # pyright: ignore - elif parse.scheme == "file" or parse.scheme == "": - return None # pyright: ignore - else: - raise ValueError(f"Unsupported scheme {parse.scheme}") + + return None # pyright: ignore @contextmanager def open_file_for_read( - path: Union[str, Path], + path: PathType, mode: str = "r", open_fn: Optional[Callable] = None, logger: Optional[Logger] = None, open_kwargs: Optional[Dict[str, Any]] = None, client: Optional[ClientType] = None, ) -> Generator[IO, None, None]: - """Get a context manager to read in a file that is either on - S3 or local. + """Get a context manager to read in a file that is either in a local + or remote location. If the path is a remote path, the file will be + downloaded to a temporary location and then deleted after the context + manager exits. Args: - path (Union[str, Path]): The path to the file to read. Can be an S3 - or local path. + path (Union[str, Path, MultiPath]): The path to the file to read. mode (str, optional): The mode to open the file in. Defaults to "r". Only read modes are supported (e.g. 'rb', 'rt', 'r'). open_fn (Callable, optional): The function to use to open the file. @@ -119,36 +225,32 @@ def open_file_for_read( open_kwargs = open_kwargs or {} logger = logger or get_logger() open_fn = open_fn or open - parse = urlparse(str(path)) remove = False assert "r" in mode, "Only read mode is supported" - if parse.scheme == "s3": + path = MultiPath.parse(path) + + if path.is_s3: client = client or get_client_if_needed(path) assert client is not None, "Could not get S3 client" logger.info(f"Downloading {path} to a temporary file") with NamedTemporaryFile(delete=False) as f: - path = f.name - client.download_fileobj(parse.netloc, parse.path.lstrip("/"), f) + client.download_fileobj(path.bucket, path.key.lstrip("/"), f) + path = MultiPath.parse(f.name) remove = True - elif parse.scheme == "file" or parse.scheme == "": - pass - else: - raise ValueError(f"Unsupported scheme {parse.scheme}") - try: - with open_fn(file=path, mode=mode, **open_kwargs) as f: + with open_fn(file=str(path), mode=mode, **open_kwargs) as f: yield f finally: if remove: - remove_local_file(path) + remove_local_file(str(path)) @contextmanager def open_file_for_write( - path: Union[str, Path], + path: PathType, mode: str = "w", skip_if_empty: bool = False, open_fn: Optional[Callable] = None, @@ -156,12 +258,13 @@ def open_file_for_write( open_kwargs: Optional[Dict[str, Any]] = None, client: Optional[ClientType] = None, ) -> Generator[IO, None, None]: - """Get a context manager to write to a file that is either on - S3 or local. + """Get a context manager to write to a file. If the file is from a + remote location (e.g. S3), the file will be written to a temporary + file and then uploaded to the remote location; after the context + manager exits, the temporary file will be deleted. Args: - path (Union[str, Path]): The path to the file to write. Can be local - or an S3 path. + path (Union[str, Path, MultiPath]): The path to the file to write. mode (str, optional): The mode to open the file in. Defaults to "w". Only read modes are supported (e.g. 'wb', 'w', ...). open_fn (Callable, optional): The function to use to open the file. @@ -173,91 +276,90 @@ def open_file_for_write( """ path = str(path) - parse = urlparse(path) local = None logger = logger or get_logger() open_fn = open_fn or open open_kwargs = open_kwargs or {} + path = MultiPath.parse(path) + assert "w" in mode or "a" in mode, "Only write/append mode is supported" try: - if parse.scheme == "file" or parse.scheme == "": + if path.is_local: # make enclosing directory if it doesn't exist - Path(path).parent.mkdir(parents=True, exist_ok=True) + path.as_path.parent.mkdir(parents=True, exist_ok=True) - with open_fn(file=path, mode=mode, **open_kwargs) as f: + with open_fn(file=str(path), mode=mode, **open_kwargs) as f: yield f else: with NamedTemporaryFile(delete=False, mode=mode) as f: yield f - local = f.name + local = MultiPath.parse(f.name) finally: if local is None: - if skip_if_empty and stat_local_file(path).st_size == 0: + if skip_if_empty and stat_local_file(path.as_str).st_size == 0: logger.info(f"Skipping empty file {path}") - remove_local_file(path) - elif parse.scheme == "s3": - dst = f'{parse.netloc}{parse.path.lstrip("/")}' - if skip_if_empty and stat_local_file(local).st_size == 0: - logger.info(f"Skipping upload to {dst} since {local} is empty") + remove_local_file(path.as_path) + elif path.is_s3: + # dst = f'{path.bucket}{parse.path.lstrip("/")}' + if skip_if_empty and stat_local_file(local.as_str).st_size == 0: + logger.info(f"Skipping upload to {path}: {local} is empty") else: - logger.info(f"Uploading {local} to {dst}") + logger.info(f"Uploading {local} to {path}") client = client or get_client_if_needed(path) assert client is not None, "Could not get S3 client" - client.upload_file(local, parse.netloc, parse.path.lstrip("/")) - remove_local_file(local) - else: - raise ValueError(f"Unsupported scheme {parse.scheme}") + client.upload_file(local, path.bucket, path.key.lstrip("/")) + remove_local_file(local.as_path) def recursively_list_files( - path: Union[str, Path], + path: MultiPath, ignore_hidden_files: bool = True, client: Optional[ClientType] = None, -) -> Iterable[str]: - """Recursively list all files in the given directory on network prefix +) -> Iterable[MultiPath]: + """Recursively list all files in the given directory for a given + path, local or remote. Args: - path (Union[str, Path]): The path to list content at. Can be local - or an S3 path. + path (Union[str, Path, MultiPath]): The path to list content at. ignore_hidden_files (bool, optional): Whether to ignore hidden files (i.e. files that start with a dot) when listing. Defaults to True. """ - path = str(path) - parse = urlparse(path) + path = MultiPath.parse(path) - if parse.scheme == "s3": + if path.is_s3: client = client or get_client_if_needed(path) assert client is not None, "Could not get S3 client" - prefixes = [parse.path.lstrip("/")] + prefixes = [path.key.lstrip("/")] while len(prefixes) > 0: prefix = prefixes.pop() paginator = client.get_paginator("list_objects_v2") - pages = paginator.paginate(Bucket=parse.netloc, Prefix=prefix) + pages = paginator.paginate(Bucket=path.bucket, Prefix=prefix) for page in pages: for obj in page["Contents"]: if obj["Key"][-1] == "/": prefixes.append(obj["Key"]) else: - yield f's3://{parse.netloc}/{obj["Key"]}' + yield MultiPath( + prot="s3", root=path.root, path=obj["Key"] + ) - elif parse.scheme == "file" or parse.scheme == "": - for root, _, files in local_walk(parse.path): + if path.is_local: + for _root, _, files in local_walk(path.as_str): + root = Path(_root) for f in files: if ignore_hidden_files and f.startswith("."): continue - yield join_uri(root, f) - else: - raise NotImplementedError(f"Unknown scheme: {parse.scheme}") + yield MultiPath.parse(root / f) def copy_directory( - src: Union[str, Path], - dst: Union[str, Path], + src: PathType, + dst: PathType, ignore_hidden_files: bool = False, skip_if_empty: bool = False, logger: Optional[Logger] = None, @@ -267,9 +369,8 @@ def copy_directory( locations can be local, remote, or a mix of both. Args: - src (Union[str, Path]): The location to copy from. Can be local - or a location on S3. - dst (Union[str, Path]): The location to copy to. Can be local or S3. + src (Union[str, Path, MultiPath]): The location to copy from. + dst (Union[str, Path, MultiPath]): The location to copy to. ignore_hidden_files (bool, optional): Whether to ignore hidden files on copy. Defaults to True. logger (Logger, optional): The logger to use. Defaults to the built-in @@ -280,18 +381,18 @@ def copy_directory( # we convert to string because the Path library does not handle # well network locations. - src = str(src) - dst = str(dst) + src = MultiPath.parse(src) + dst = MultiPath.parse(dst) cnt = 0 client = client or get_client_if_needed(src) or get_client_if_needed(dst) for source_path in recursively_list_files( - str(src), ignore_hidden_files=ignore_hidden_files + path=src, ignore_hidden_files=ignore_hidden_files ): - # we strip the segment of source_path that is the common prefix in src, - # then join the remaining bit - destination = join_uri(dst, source_path[len(src) :]) + # we strip the segment of source_path that is the + # common prefix in src, then join the remaining bit + destination = dst / (source_path - src) logger.info(f"Copying {source_path} to {destination}; {cnt:,} so far") @@ -312,30 +413,26 @@ def copy_directory( cnt += 1 -def remove_file(path: Union[str, Path], client: Optional[ClientType] = None): +def remove_file(path: PathType, client: Optional[ClientType] = None): """Remove a file at the provided path.""" - path = str(path) - parse = urlparse(path) + path = MultiPath.parse(path) - if parse.scheme == "s3": + if path.is_s3: client = client or get_client_if_needed(path) assert client is not None, "Could not get S3 client" - client.delete_object(Bucket=parse.netloc, Key=parse.path.lstrip("/")) - elif parse.scheme == "file" or parse.scheme == "": - remove_local_file(path) - else: - raise NotImplementedError(f"Unknown scheme: {parse.scheme}") + client.delete_object(Bucket=path.bucket, Key=path.key.lstrip("/")) + if path.is_local: + remove_local_file(path.as_path) -def remove_directory( - path: Union[str, Path], client: Optional[ClientType] = None -): + +def remove_directory(path: PathType, client: Optional[ClientType] = None): """Completely remove a directory at the provided path.""" - parse = urlparse(str(path)) + path = MultiPath.parse(path) - if parse.scheme == "s3": + if path.is_s3: client = client or get_client_if_needed(path) assert client is not None, "Could not get S3 client" @@ -343,10 +440,9 @@ def remove_directory( path=path, ignore_hidden_files=False, client=client ): remove_file(fn, client=client) - elif parse.scheme == "file" or parse.scheme == "": - shutil.rmtree(path, ignore_errors=True) - else: - raise NotImplementedError(f"Unknown scheme: {parse.scheme}") + + if path.is_local: + shutil.rmtree(path.as_str, ignore_errors=True) T = TypeVar("T") @@ -404,10 +500,8 @@ def __init__( self._ctx = ExitStack() self.remote_path = remote_path - self.local_path = ( - uri_stringify(local_path) - if local_path is not None - else self._ctx.enter_context(TemporaryDirectory()) + self.local_path = MultiPath.parse( + local_path or self._ctx.enter_context(TemporaryDirectory()) ) if local_path is None and keep_local: raise ValueError( @@ -428,7 +522,7 @@ def _decorated( remote_path=self.remote_path, keep_local=self.keep_local, ) as path: - output = func(path, *args, **kwargs) + output = func(path.as_str, *args, **kwargs) return output def __call__( @@ -442,9 +536,7 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): if exc_type is None: # all went well, so we copy the local directory to the remote - copy_directory( - src=self.local_path, dst=self.remote_path # pyright: ignore - ) + copy_directory(src=self.local_path, dst=self.remote_path) if not self.keep_local: remove_directory(self.local_path) diff --git a/src/smashed/utils/wordsplitter.py b/src/smashed/utils/wordsplitter.py index d178065..21f8297 100644 --- a/src/smashed/utils/wordsplitter.py +++ b/src/smashed/utils/wordsplitter.py @@ -8,7 +8,12 @@ from blingfire import text_to_words -__all__ = ["WhitespaceSplitter", "BlingFireSplitter"] +__all__ = [ + "WhitespaceSplitter", + "BlingFireSplitter", + "WhitespacePlusSplitter", + "WhitespaceTrailSplitter", +] class BaseWordSplitter: @@ -53,3 +58,14 @@ class WhitespacePlusSplitter(WhitespaceSplitter): def __init__(self, language: str = "en"): super().__init__(language) self.tokenizer = Whitespace() + + +class WhitespaceTrailSplitter(WhitespacePlusSplitter): + def tokenize(self, text: str) -> List[str]: + # the start of each token + locs = [s for _, (s, _) in self.tokenizer.pre_tokenize_str(text)] + + # we include any trailing whitespace in the token + return [text[locs[i] : locs[i + 1]] for i in range(len(locs) - 1)] + [ + text[locs[-1] :] + ] diff --git a/tests/test_io_utils.py b/tests/test_io_utils.py new file mode 100644 index 0000000..64151ac --- /dev/null +++ b/tests/test_io_utils.py @@ -0,0 +1,129 @@ +import unittest +from pathlib import Path +from tempfile import TemporaryDirectory + +from smashed.utils.io_utils import ( + MultiPath, + copy_directory, + recursively_list_files, + remove_directory, + remove_file, +) + + +class TestMultiPath(unittest.TestCase): + def test_parse(self): + s3_path = "s3://bucket/path/to/file" + parse = MultiPath.parse(s3_path) + self.assertEqual(parse.prot, "s3") + self.assertEqual(parse.bucket, "bucket") + self.assertEqual(parse.key, "path/to/file") + self.assertEqual(str(parse), s3_path) + + local_path = "/path/to/file" + parse = MultiPath.parse(local_path) + self.assertEqual(parse.prot, "") + self.assertEqual(str(parse), local_path) + + local_path = "path/to/file" + parse = MultiPath.parse(local_path) + self.assertEqual(parse.prot, "") + self.assertEqual(str(parse), local_path) + + local_path = "file://path/to/file" + parse = MultiPath.parse(local_path) + self.assertEqual(parse.prot, "file") + self.assertEqual(str(parse), local_path) + + gs_path = "gs://bucket/path/to/file" + with self.assertRaises(ValueError): + MultiPath.parse(gs_path) + + def test_join(self): + self.assertEqual( + MultiPath.parse("s3://bucket/path/to") / "new_file", + MultiPath.parse("s3://bucket/path/to/new_file"), + ) + + self.assertEqual( + MultiPath.parse("s3://bucket/path/to/") / "/new_file", + MultiPath.parse("s3://bucket/path/to/new_file"), + ) + + self.assertEqual( + MultiPath.join("foo", Path("bar"), MultiPath.parse("/baz")), + MultiPath.parse("foo/bar/baz"), + ) + + with self.assertRaises(ValueError): + _ = MultiPath.parse("s3://bucket/path/to") / "s3://bucket/path/to" + + def test_types(self): + s3_path = MultiPath.parse("s3://bucket/path/to/file") + self.assertTrue(s3_path.is_s3) + self.assertFalse(s3_path.is_local) + + with self.assertRaises(ValueError): + s3_path.as_path + + local_path = MultiPath.parse("/path/to/file") + self.assertFalse(local_path.is_s3) + self.assertTrue(local_path.is_local) + + with self.assertRaises(ValueError): + local_path.bucket + local_path.key + + def test_subtraction(self): + path_a = MultiPath.parse("s3://bucket/path/to/file") + path_b = MultiPath.parse("s3://bucket/") + self.assertEqual((path_a - path_b).as_str, "path/to/file") + self.assertEqual((path_b - path_a).as_str, "s3://bucket/") + + def test_local_operations(self): + with TemporaryDirectory() as tmpdir: + root_path = MultiPath.parse(tmpdir) + + # make a directory + (root_path / "d1").as_path.mkdir() + + # make some files + for file_name in ["f1", "f2"]: + (root_path / "d1" / file_name).as_path.touch() + + # make some nested directories and files + (root_path / "d1" / "d11").as_path.mkdir() + (root_path / "d1" / "d11" / "f11").as_path.touch() + + # test listing functionality + all_files = {f"{tmpdir}/d1/{f}" for f in ("f1", "f2", "d11/f11")} + + for fn in recursively_list_files(root_path / "d1"): + self.assertIn(fn, all_files) + + # test copy + copy_directory(root_path / "d1", root_path / "d2") + + all_files = {f"{tmpdir}/d2/{f}" for f in ("f1", "f2", "d11/f11")} + for fn in recursively_list_files(root_path / "d2"): + self.assertIn(fn, all_files) + + # test copy in a non-empty directory + (root_path / "d3" / "d11").as_path.mkdir(parents=True) + (root_path / "d3" / "d11" / "f11").as_path.touch() + + all_files = {f"{tmpdir}/d3/{f}" for f in ("f1", "f2", "d11/f11")} + copy_directory(root_path / "d1", root_path / "d3") + for fn in recursively_list_files(root_path / "d3"): + self.assertIn(fn, all_files) + + # test remove + remove_directory(root_path / "d1") + self.assertFalse((root_path / "d1").as_path.exists()) + + # test remove file + remove_file(root_path / "d3" / "f1") + self.assertFalse((root_path / "d3" / "f1").as_path.exists()) + + with self.assertRaises(FileNotFoundError): + remove_file(root_path / "d3" / "f1") diff --git a/tests/test_text2words.py b/tests/test_text2words.py new file mode 100644 index 0000000..96f0463 --- /dev/null +++ b/tests/test_text2words.py @@ -0,0 +1,14 @@ +from unittest import TestCase + +from smashed.mappers.text import TextToWordsMapper, WordsToTextMapper + + +class TestText2Words(TestCase): + def test_trail(self): + mapper = TextToWordsMapper( + fields="text", splitter="trail" + ) >> WordsToTextMapper(fields="text", joiner="") + text = "Hello world! What a beautiful day...\nOR NOT?" + dataset = [{"text": text}] + mapped_dataset = mapper.map(dataset) + self.assertEqual(mapped_dataset[0]["text"], text)