Add FileBasedRestartableBatchDataProcessor. (#154)
* fix torch dep. to work with intel Mac

* add v1 impl. of FileBasedRestartableBatchDataProcessor

* add tests and fixes

* update copyright year

* update copyright year

* update lint config

* add doc

* fix lint

* fix typo

* improve doc

* add load_data

* fix lint

* improve typing
PhilipMay authored May 13, 2024
1 parent 7a99dff commit dfd992e
Showing 9 changed files with 308 additions and 48 deletions.
4 changes: 2 additions & 2 deletions Makefile
@@ -4,13 +4,13 @@ other-src := tests docs
check:
	poetry run black $(src) $(other-src) --check --diff
	poetry run mypy --install-types --non-interactive $(src) $(other-src)
	poetry run ruff $(src) $(other-src)
	poetry run ruff check $(src) $(other-src)
	poetry run mdformat --check --number .
	poetry run make -C docs clean doctest

format:
	poetry run black $(src) $(other-src)
	poetry run ruff $(src) $(other-src) --fix
	poetry run ruff check $(src) $(other-src) --fix
	poetry run mdformat --number .

test:
129 changes: 127 additions & 2 deletions mltb2/files.py
@@ -1,4 +1,5 @@
# Copyright (c) 2023 Philip May
# Copyright (c) 2023-2024 Philip May
# Copyright (c) 2023-2024 Philip May, Deutsche Telekom AG
# This software is distributed under the terms of the MIT license
# which is available at https://opensource.org/licenses/MIT

@@ -13,8 +14,14 @@


import contextlib
import gzip
import json
import os
from typing import Optional
import random
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Set
from uuid import uuid4

from platformdirs import user_data_dir
from sklearn.datasets._base import RemoteFileMetadata, _fetch_remote
@@ -64,3 +71,121 @@ def fetch_remote_file(dirname, filename, url: str, sha256_checksum: str) -> str:
        os.remove(os.path.join(dirname, filename))
        raise
    return fetch_remote_file_path


@dataclass
class FileBasedRestartableBatchDataProcessor:
    """Batch data processor which supports restartability and is backed by files.

    Args:
        data: The data to process.
        batch_size: The batch size.
        uuid_name: The name of the uuid field in the data.
        result_dir: The directory where the results are stored.
    """

    data: List[Dict[str, Any]]
    batch_size: int
    uuid_name: str
    result_dir: str
    _result_dir_path: Path = field(init=False, repr=False)
    _own_lock_uuids: Set[str] = field(init=False, repr=False, default_factory=set)

    def __post_init__(self) -> None:
        """Do post init."""
        # check that batch size is > 0
        if self.batch_size <= 0:
            raise ValueError("batch_size must be > 0!")

        if not len(self.data) > 0:
            raise ValueError("data must not be empty!")

        uuids: Set[str] = set()

        # check uuid_name
        for idx, d in enumerate(self.data):
            if self.uuid_name not in d:
                raise ValueError(f"uuid_name '{self.uuid_name}' not available in data at index {idx}!")
            uuid = d[self.uuid_name]
            if not isinstance(uuid, str):
                raise TypeError(f"uuid '{uuid}' at index {idx} is not a string!")
            if len(uuid) == 0:
                raise ValueError(f"uuid '{uuid}' at index {idx} is empty!")
            uuids.add(uuid)

        if len(uuids) != len(self.data):
            raise ValueError("uuids are not unique!")

        # create and check _result_dir_path
        self._result_dir_path = Path(self.result_dir)
        self._result_dir_path.mkdir(parents=True, exist_ok=True)  # create directory if not available
        if not self._result_dir_path.is_dir():
raise ValueError(f"Faild to create or find result_dir '{self.result_dir}'!")

    def _get_locked_or_done_uuids(self) -> Set[str]:
        locked_or_done_uuids: Set[str] = set()
        for child_path in self._result_dir_path.iterdir():
            if child_path.is_file():
                filename = child_path.name
                if filename.endswith(".lock"):
                    uuid = filename[: filename.rindex(".lock")]
                elif filename.endswith(".json.gz") and "_" in filename:
                    uuid = filename[: filename.rindex("_")]
                locked_or_done_uuids.add(uuid)
        return locked_or_done_uuids

    def _write_lock_files(self, batch: Sequence[Dict[str, Any]]) -> None:
        for d in batch:
            uuid = d[self.uuid_name]
            (self._result_dir_path / f"{uuid}.lock").touch()
            self._own_lock_uuids.add(uuid)

    def read_batch(self) -> Sequence[Dict[str, Any]]:
        """Read the next batch of data."""
        locked_or_done_uuids: Set[str] = self._get_locked_or_done_uuids()
        remaining_data = [d for d in self.data if d[self.uuid_name] not in locked_or_done_uuids]
        random.shuffle(remaining_data)
        next_batch_size = min(self.batch_size, len(remaining_data))
        next_batch = remaining_data[:next_batch_size]
        self._write_lock_files(next_batch)
        return next_batch

    def _save_batch_data(self, batch: Sequence[Dict[str, Any]]) -> None:
        for d in batch:
            uuid = d[self.uuid_name]
            if uuid not in self._own_lock_uuids:
                raise ValueError(f"uuid '{uuid}' not locked by me!")
            filename = self._result_dir_path / f"{uuid}_{str(uuid4())}.json.gz"  # noqa: RUF010
            with gzip.GzipFile(filename, "w") as outfile:
                outfile.write(json.dumps(d).encode("utf-8"))

    def _remove_lock_files(self, batch: Sequence[Dict[str, Any]]) -> None:
        for d in batch:
            uuid = d[self.uuid_name]
            (self._result_dir_path / f"{uuid}.lock").unlink(missing_ok=True)
            self._own_lock_uuids.discard(uuid)

    def save_batch(self, batch: Sequence[Dict[str, Any]]) -> None:
        """Save the batch of data."""
        self._save_batch_data(batch)
        self._remove_lock_files(batch)

    @staticmethod
    def load_data(result_dir: str) -> List[Dict[str, Any]]:
        """Load all data.

        After all data is processed, this method can be used to load all data.

        Args:
            result_dir: The directory where the results are stored.
        """
        _result_dir_path = Path(result_dir)
        if not _result_dir_path.is_dir():
            raise ValueError(f"Did not find result_dir '{result_dir}'!")

        data = []
        for child_path in _result_dir_path.iterdir():
            if child_path.is_file() and child_path.name.endswith(".json.gz"):
                with gzip.GzipFile(child_path, "r") as infile:
                    data.append(json.loads(infile.read().decode("utf-8")))
        return data
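
A minimal usage sketch of the new FileBasedRestartableBatchDataProcessor, modeled on the class above and its tests below; the example records, the squaring step, and the "results" directory are illustrative placeholders, not part of this commit:

from uuid import uuid4

from mltb2.files import FileBasedRestartableBatchDataProcessor

# made-up input records; every record needs a unique string uuid field
data = [{"uuid": str(uuid4()), "x": i} for i in range(100)]

processor = FileBasedRestartableBatchDataProcessor(
    data=data, batch_size=10, uuid_name="uuid", result_dir="results"
)

# process everything batch by batch; on a restart after a crash,
# records that are already saved (or still locked) are skipped
while True:
    batch = processor.read_batch()
    if len(batch) == 0:
        break
    for record in batch:
        record["x_squared"] = record["x"] ** 2  # placeholder processing step
    processor.save_batch(batch)

# once everything is processed, collect the results from the result files
processed_data = FileBasedRestartableBatchDataProcessor.load_data("results")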
20 changes: 14 additions & 6 deletions pyproject.toml
@@ -65,7 +65,11 @@ fasttext-wheel = {version = "*", optional = true}
optuna = {version = "*", optional = true}
matplotlib = {version = "*", optional = true}
SoMaJo = {version = ">=2.4.1", optional = true}
torch = {version = "!=2.0.1,!=2.1.0", optional = true} # some versions have poetry issues

# some versions have poetry issues
# 2.3.0 does not work with Intel Mac
torch = {version = "!=2.0.1,!=2.1.0,!=2.3.0", optional = true}

transformers = {version = "*", optional = true}
tiktoken = {version = "*", optional = true}
safetensors = {version = "!=0.3.2", optional = true} # version 0.3.2 has poetry issues
@@ -121,10 +125,13 @@ line-length = 119
target-version = ["py38", "py39", "py310", "py311"]

[tool.ruff]
select = ["ALL"]
line-length = 119
fixable = ["I"]
target-version = "py38"


[tool.ruff.lint]
select = ["ALL"]
fixable = ["I"]
ignore = [
"DJ", # flake8-django - https://docs.astral.sh/ruff/rules/#flake8-django-dj
"ERA", # eradicate - https://docs.astral.sh/ruff/rules/#eradicate-era
@@ -157,18 +164,19 @@ ignore = [
"RUF015", # Prefer `next(iter(sentences))` over single element slice
]

[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
"tests/**/test_*.py" = [
"D100", # Missing docstring in public module
"D103", # Missing docstring in public function
"PLR2004", # Magic value used in comparison, consider replacing {value} with a constant variable
"S101", # Use of assert detected
"N802", # Function name should be lowercase
]

[tool.ruff.pydocstyle]
[tool.ruff.lint.pydocstyle]
convention = "google"

[tool.ruff.flake8-copyright]
[tool.ruff.lint.flake8-copyright]
notice-rgx = "(# Copyright \\(c\\) \\d{4}.*\\n)+# This software is distributed under the terms of the MIT license\\n# which is available at https://opensource.org/licenses/MIT\\n\\n"

[tool.mypy]
135 changes: 133 additions & 2 deletions tests/test_files.py
@@ -1,12 +1,14 @@
# Copyright (c) 2023 Philip May
# Copyright (c) 2023-2024 Philip May
# Copyright (c) 2023-2024 Philip May, Deutsche Telekom AG
# This software is distributed under the terms of the MIT license
# which is available at https://opensource.org/licenses/MIT

import os
from uuid import uuid4

import pytest

from mltb2.files import fetch_remote_file, get_and_create_mltb2_data_dir
from mltb2.files import FileBasedRestartableBatchDataProcessor, fetch_remote_file, get_and_create_mltb2_data_dir


def test_fetch_remote_file(tmpdir):
@@ -37,3 +39,132 @@ def test_get_and_create_mltb2_data_dir(tmpdir):
    mltb2_data_dir = get_and_create_mltb2_data_dir(tmpdir)

    assert mltb2_data_dir == os.path.join(tmpdir, "mltb2")


def test_FileBasedRestartableBatchDataProcessor_batch_size(tmp_path):
    result_dir = tmp_path.absolute()
    with pytest.raises(ValueError):
        _ = FileBasedRestartableBatchDataProcessor(data=[], batch_size=0, uuid_name="uuid", result_dir=result_dir)


def test_FileBasedRestartableBatchDataProcessor_empty_data(tmp_path):
    result_dir = tmp_path.absolute()
    with pytest.raises(ValueError):
        _ = FileBasedRestartableBatchDataProcessor(data=[], batch_size=10, uuid_name="uuid", result_dir=result_dir)


def test_FileBasedRestartableBatchDataProcessor_uuid_in_data(tmp_path):
    result_dir = tmp_path.absolute()
    with pytest.raises(ValueError):
        _ = FileBasedRestartableBatchDataProcessor(
            data=[{"x": 10}], batch_size=10, uuid_name="uuid", result_dir=result_dir
        )


def test_FileBasedRestartableBatchDataProcessor_uuid_type(tmp_path):
    result_dir = tmp_path.absolute()
    with pytest.raises(TypeError):
        _ = FileBasedRestartableBatchDataProcessor(
            data=[{"uuid": 6, "x": 10}], batch_size=10, uuid_name="uuid", result_dir=result_dir
        )


def test_FileBasedRestartableBatchDataProcessor_uuid_empty(tmp_path):
    result_dir = tmp_path.absolute()
    with pytest.raises(ValueError):
        _ = FileBasedRestartableBatchDataProcessor(
            data=[{"uuid": "", "x": 10}], batch_size=10, uuid_name="uuid", result_dir=result_dir
        )


def test_FileBasedRestartableBatchDataProcessor_uuid_unique(tmp_path):
    result_dir = tmp_path.absolute()
    data = [{"uuid": "a", "x": 10}, {"uuid": "a", "x": 10}, {"uuid": "c", "x": 10}]
    with pytest.raises(ValueError):
        _ = FileBasedRestartableBatchDataProcessor(data=data, batch_size=10, uuid_name="uuid", result_dir=result_dir)


def test_FileBasedRestartableBatchDataProcessor_write_lock_files(tmp_path):
    result_dir = tmp_path.absolute()
    batch_size = 10
    data = [{"uuid": str(uuid4()), "x": i} for i in range(100)]
    data_processor = FileBasedRestartableBatchDataProcessor(
        data=data, batch_size=batch_size, uuid_name="uuid", result_dir=result_dir
    )
    data = data_processor.read_batch()

    assert len(data) == batch_size

    # check lock files
    lock_files = list(tmp_path.glob("*.lock"))
    assert len(lock_files) == batch_size


def test_FileBasedRestartableBatchDataProcessor_save_batch_data(tmp_path):
    result_dir = tmp_path.absolute()
    batch_size = 10
    data = [{"uuid": str(uuid4()), "x": i} for i in range(100)]
    data_processor = FileBasedRestartableBatchDataProcessor(
        data=data, batch_size=batch_size, uuid_name="uuid", result_dir=result_dir
    )
    data = data_processor.read_batch()
    data_processor.save_batch(data)

    # check result files
    result_files = list(tmp_path.glob("*.json.gz"))
    assert len(result_files) == batch_size


def test_FileBasedRestartableBatchDataProcessor_remove_lock_files(tmp_path):
    result_dir = tmp_path.absolute()
    batch_size = 10
    data = [{"uuid": str(uuid4()), "x": i} for i in range(100)]
    data_processor = FileBasedRestartableBatchDataProcessor(
        data=data, batch_size=batch_size, uuid_name="uuid", result_dir=result_dir
    )
    data = data_processor.read_batch()
    data_processor.save_batch(data)

    # check lock files
    lock_files = list(tmp_path.glob("*.lock"))
    assert len(lock_files) == 0


def test_FileBasedRestartableBatchDataProcessor_save_unlocked(tmp_path):
    result_dir = tmp_path.absolute()
    batch_size = 10
    data = [{"uuid": str(uuid4()), "x": i} for i in range(100)]
    data_processor = FileBasedRestartableBatchDataProcessor(
        data=data, batch_size=batch_size, uuid_name="uuid", result_dir=result_dir
    )
    data = data_processor.read_batch()
    data[0]["uuid"] = "something_else"
    with pytest.raises(ValueError):
        data_processor.save_batch(data)


def test_FileBasedRestartableBatchDataProcessor_load_data(tmp_path):
    result_dir = tmp_path.absolute()
    batch_size = 10
    data = [{"uuid": str(uuid4()), "x": i} for i in range(100)]
    data_processor = FileBasedRestartableBatchDataProcessor(
        data=data, batch_size=batch_size, uuid_name="uuid", result_dir=result_dir
    )

    # process all data
    while True:
        _data = data_processor.read_batch()
        if len(_data) == 0:
            break
        data_processor.save_batch(_data)

    del data_processor
    processed_data = FileBasedRestartableBatchDataProcessor.load_data(result_dir)

    assert len(processed_data) == len(data)
    for d in processed_data:
        assert "uuid" in d
        assert "x" in d
        assert isinstance(d["uuid"], str)
        assert isinstance(d["x"], int)
        assert d["x"] < 100
6 changes: 3 additions & 3 deletions tests/test_md.py
@@ -49,7 +49,7 @@ def test_chunk_md():
    assert result[2] == "### Headline 3 / 1\n\n#### Headline 4 / 1\n\nContent."


def test_MdTextSplitter_call():  # noqa: N802
def test_MdTextSplitter_call():
    transformers_token_counter = TransformersTokenCounter("deepset/gbert-base")
    text_merger = MdTextSplitter(
        max_token=15,
@@ -63,7 +63,7 @@ def test_MdTextSplitter_call(): # noqa: N802
    assert merged_md[1] == "### Headline 3 / 1\n\n#### Headline 4 / 1\n\nContent."


def test_MdTextSplitter_call_no_merge():  # noqa: N802
def test_MdTextSplitter_call_no_merge():
    transformers_token_counter = TransformersTokenCounter("deepset/gbert-base")
    text_merger = MdTextSplitter(
        max_token=1,
@@ -78,7 +78,7 @@ def test_MdTextSplitter_call_all_merge(): # noqa: N802
    assert merged_md[2] == "### Headline 3 / 1\n\n#### Headline 4 / 1\n\nContent."


def test_MdTextSplitter_call_all_merge():  # noqa: N802
def test_MdTextSplitter_call_all_merge():
    transformers_token_counter = TransformersTokenCounter("deepset/gbert-base")
    text_merger = MdTextSplitter(
        max_token=1000,
