From dfd992e8d3e9fef145149a68186a707e872f5363 Mon Sep 17 00:00:00 2001
From: Philip May
Date: Mon, 13 May 2024 12:20:18 +0200
Subject: [PATCH] Add FileBasedRestartableBatchDataProcessor. (#154)

* fix torch dep. to work with intel Mac

* add v1 impl. of FileBasedRestartableBatchDataProcessor

* add tests and fixes

* update copyright year

* update copyright year

* update lint config

* add doc

* fix lint

* fix typo

* improve doc

* add load_data

* fix lint

* improve typing
---
 Makefile                          |   4 +-
 mltb2/files.py                    | 129 +++++++++++++++++++++++++++-
 pyproject.toml                    |  20 +++--
 tests/test_files.py               | 135 +++++++++++++++++++++++++++++-
 tests/test_md.py                  |   6 +-
 tests/test_openai.py              |  20 ++---
 tests/test_somajo.py              |  28 +++----
 tests/test_somajo_transformers.py |   6 +-
 tests/test_transformers.py        |   8 +-
 9 files changed, 308 insertions(+), 48 deletions(-)

diff --git a/Makefile b/Makefile
index b764be7..f5eba8f 100644
--- a/Makefile
+++ b/Makefile
@@ -4,13 +4,13 @@ other-src := tests docs
 check:
 	poetry run black $(src) $(other-src) --check --diff
 	poetry run mypy --install-types --non-interactive $(src) $(other-src)
-	poetry run ruff $(src) $(other-src)
+	poetry run ruff check $(src) $(other-src)
 	poetry run mdformat --check --number .
 	poetry run make -C docs clean doctest
 
 format:
 	poetry run black $(src) $(other-src)
-	poetry run ruff $(src) $(other-src) --fix
+	poetry run ruff check $(src) $(other-src) --fix
 	poetry run mdformat --number .
 
 test:
diff --git a/mltb2/files.py b/mltb2/files.py
index 1eb0e57..f8c2a03 100644
--- a/mltb2/files.py
+++ b/mltb2/files.py
@@ -1,4 +1,5 @@
-# Copyright (c) 2023 Philip May
+# Copyright (c) 2023-2024 Philip May
+# Copyright (c) 2023-2024 Philip May, Deutsche Telekom AG
 # This software is distributed under the terms of the MIT license
 # which is available at https://opensource.org/licenses/MIT
 
@@ -13,8 +14,14 @@
 import contextlib
+import gzip
+import json
 import os
-from typing import Optional
+import random
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Sequence, Set
+from uuid import uuid4
 
 from platformdirs import user_data_dir
 from sklearn.datasets._base import RemoteFileMetadata, _fetch_remote
@@ -64,3 +71,121 @@ def fetch_remote_file(dirname, filename, url: str, sha256_checksum: str) -> str:
         os.remove(os.path.join(dirname, filename))
         raise
     return fetch_remote_file_path
+
+
+@dataclass
+class FileBasedRestartableBatchDataProcessor:
+    """Batch data processor which supports restartability and is backed by files.
+
+    Args:
+        data: The data to process.
+        batch_size: The batch size.
+        uuid_name: The name of the uuid field in the data.
+        result_dir: The directory where the results are stored.
+    """
+
+    data: List[Dict[str, Any]]
+    batch_size: int
+    uuid_name: str
+    result_dir: str
+    _result_dir_path: Path = field(init=False, repr=False)
+    _own_lock_uuids: Set[str] = field(init=False, repr=False, default_factory=set)
+
+    def __post_init__(self) -> None:
+        """Do post init."""
+        # check that batch size is > 0
+        if self.batch_size <= 0:
+            raise ValueError("batch_size must be > 0!")
+
+        if not len(self.data) > 0:
+            raise ValueError("data must not be empty!")
+
+        uuids: Set[str] = set()
+
+        # check uuid_name
+        for idx, d in enumerate(self.data):
+            if self.uuid_name not in d:
+                raise ValueError(f"uuid_name '{self.uuid_name}' not available in data at index {idx}!")
+            uuid = d[self.uuid_name]
+            if not isinstance(uuid, str):
+                raise TypeError(f"uuid '{uuid}' at index {idx} is not a string!")
+            if len(uuid) == 0:
+                raise ValueError(f"uuid '{uuid}' at index {idx} is empty!")
+            uuids.add(uuid)
+
+        if len(uuids) != len(self.data):
+            raise ValueError("uuids are not unique!")
+
+        # create and check _result_dir_path
+        self._result_dir_path = Path(self.result_dir)
+        self._result_dir_path.mkdir(parents=True, exist_ok=True)  # create directory if not available
+        if not self._result_dir_path.is_dir():
+            raise ValueError(f"Failed to create or find result_dir '{self.result_dir}'!")
+
+    def _get_locked_or_done_uuids(self) -> Set[str]:
+        locked_or_done_uuids: Set[str] = set()
+        for child_path in self._result_dir_path.iterdir():
+            if child_path.is_file():
+                filename = child_path.name
+                if filename.endswith(".lock"):
+                    uuid = filename[: filename.rindex(".lock")]
+                elif filename.endswith(".json.gz") and "_" in filename:
+                    uuid = filename[: filename.rindex("_")]
+                locked_or_done_uuids.add(uuid)
+        return locked_or_done_uuids
+
+    def _write_lock_files(self, batch: Sequence[Dict[str, Any]]) -> None:
+        for d in batch:
+            uuid = d[self.uuid_name]
+            (self._result_dir_path / f"{uuid}.lock").touch()
+            self._own_lock_uuids.add(uuid)
+
+    def read_batch(self) -> Sequence[Dict[str, Any]]:
+        """Read the next batch of data."""
+        locked_or_done_uuids: Set[str] = self._get_locked_or_done_uuids()
+        remaining_data = [d for d in self.data if d[self.uuid_name] not in locked_or_done_uuids]
+        random.shuffle(remaining_data)
+        next_batch_size = min(self.batch_size, len(remaining_data))
+        next_batch = remaining_data[:next_batch_size]
+        self._write_lock_files(next_batch)
+        return next_batch
+
+    def _save_batch_data(self, batch: Sequence[Dict[str, Any]]) -> None:
+        for d in batch:
+            uuid = d[self.uuid_name]
+            if uuid not in self._own_lock_uuids:
+                raise ValueError(f"uuid '{uuid}' not locked by me!")
+            filename = self._result_dir_path / f"{uuid}_{str(uuid4())}.json.gz"  # noqa: RUF010
+            with gzip.GzipFile(filename, "w") as outfile:
+                outfile.write(json.dumps(d).encode("utf-8"))
+
+    def _remove_lock_files(self, batch: Sequence[Dict[str, Any]]) -> None:
+        for d in batch:
+            uuid = d[self.uuid_name]
+            (self._result_dir_path / f"{uuid}.lock").unlink(missing_ok=True)
+            self._own_lock_uuids.discard(uuid)
+
+    def save_batch(self, batch: Sequence[Dict[str, Any]]) -> None:
+        """Save the batch of data."""
+        self._save_batch_data(batch)
+        self._remove_lock_files(batch)
+
+    @staticmethod
+    def load_data(result_dir: str) -> List[Dict[str, Any]]:
+        """Load all data.
+
+        After all data is processed, this method can be used to load all data.
+
+        Args:
+            result_dir: The directory where the results are stored.
+        """
+        _result_dir_path = Path(result_dir)
+        if not _result_dir_path.is_dir():
+            raise ValueError(f"Did not find result_dir '{result_dir}'!")
+
+        data = []
+        for child_path in _result_dir_path.iterdir():
+            if child_path.is_file() and child_path.name.endswith(".json.gz"):
+                with gzip.GzipFile(child_path, "r") as infile:
+                    data.append(json.loads(infile.read().decode("utf-8")))
+        return data
diff --git a/pyproject.toml b/pyproject.toml
index 533f2c9..4892ac4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -65,7 +65,11 @@ fasttext-wheel = {version = "*", optional = true}
 optuna = {version = "*", optional = true}
 matplotlib = {version = "*", optional = true}
 SoMaJo = {version = ">=2.4.1", optional = true}
-torch = {version = "!=2.0.1,!=2.1.0", optional = true}  # some versions have poetry issues
+
+# some versions have poetry issues
+# 2.3.0 does not work with Intel Mac
+torch = {version = "!=2.0.1,!=2.1.0,!=2.3.0", optional = true}
+
 transformers = {version = "*", optional = true}
 tiktoken = {version = "*", optional = true}
 safetensors = {version = "!=0.3.2", optional = true}  # version 0.3.2 has poetry issues
@@ -121,10 +125,13 @@ line-length = 119
 target-version = ["py38", "py39", "py310", "py311"]
 
 [tool.ruff]
-select = ["ALL"]
 line-length = 119
-fixable = ["I"]
 target-version = "py38"
+
+
+[tool.ruff.lint]
+select = ["ALL"]
+fixable = ["I"]
 ignore = [
     "DJ",  # flake8-django - https://docs.astral.sh/ruff/rules/#flake8-django-dj
     "ERA",  # eradicate - https://docs.astral.sh/ruff/rules/#eradicate-era
@@ -157,18 +164,19 @@ ignore = [
     "RUF015",  # Prefer `next(iter(sentences))` over single element slice
 ]
 
-[tool.ruff.per-file-ignores]
+[tool.ruff.lint.per-file-ignores]
 "tests/**/test_*.py" = [
     "D100",  # Missing docstring in public module
     "D103",  # Missing docstring in public function
     "PLR2004",  # Magic value used in comparison, consider replacing {value} with a constant variable
     "S101",  # Use of assert detected
+    "N802",  # Function name should be lowercase
 ]
 
-[tool.ruff.pydocstyle]
+[tool.ruff.lint.pydocstyle]
 convention = "google"
 
-[tool.ruff.flake8-copyright]
+[tool.ruff.lint.flake8-copyright]
 notice-rgx = "(# Copyright \\(c\\) \\d{4}.*\\n)+# This software is distributed under the terms of the MIT license\\n# which is available at https://opensource.org/licenses/MIT\\n\\n"
 
 [tool.mypy]
diff --git a/tests/test_files.py b/tests/test_files.py
index 1f799a6..721965d 100644
--- a/tests/test_files.py
+++ b/tests/test_files.py
@@ -1,12 +1,14 @@
-# Copyright (c) 2023 Philip May
+# Copyright (c) 2023-2024 Philip May
+# Copyright (c) 2023-2024 Philip May, Deutsche Telekom AG
 # This software is distributed under the terms of the MIT license
 # which is available at https://opensource.org/licenses/MIT
 
 import os
+from uuid import uuid4
 
 import pytest
 
-from mltb2.files import fetch_remote_file, get_and_create_mltb2_data_dir
+from mltb2.files import FileBasedRestartableBatchDataProcessor, fetch_remote_file, get_and_create_mltb2_data_dir
 
 
 def test_fetch_remote_file(tmpdir):
@@ -37,3 +39,132 @@ def test_get_and_create_mltb2_data_dir(tmpdir):
     mltb2_data_dir = get_and_create_mltb2_data_dir(tmpdir)
 
     assert mltb2_data_dir == os.path.join(tmpdir, "mltb2")
+
+
+def test_FileBasedRestartableBatchDataProcessor_batch_size(tmp_path):
+    result_dir = tmp_path.absolute()
+    with pytest.raises(ValueError):
+        _ = FileBasedRestartableBatchDataProcessor(data=[], batch_size=0, uuid_name="uuid", result_dir=result_dir)
+
+
+def test_FileBasedRestartableBatchDataProcessor_empty_data(tmp_path):
+    result_dir = tmp_path.absolute()
+    with pytest.raises(ValueError):
+        _ = FileBasedRestartableBatchDataProcessor(data=[], batch_size=10, uuid_name="uuid", result_dir=result_dir)
+
+
+def test_FileBasedRestartableBatchDataProcessor_uuid_in_data(tmp_path):
+    result_dir = tmp_path.absolute()
+    with pytest.raises(ValueError):
+        _ = FileBasedRestartableBatchDataProcessor(
+            data=[{"x": 10}], batch_size=10, uuid_name="uuid", result_dir=result_dir
+        )
+
+
+def test_FileBasedRestartableBatchDataProcessor_uuid_type(tmp_path):
+    result_dir = tmp_path.absolute()
+    with pytest.raises(TypeError):
+        _ = FileBasedRestartableBatchDataProcessor(
+            data=[{"uuid": 6, "x": 10}], batch_size=10, uuid_name="uuid", result_dir=result_dir
+        )
+
+
+def test_FileBasedRestartableBatchDataProcessor_uuid_empty(tmp_path):
+    result_dir = tmp_path.absolute()
+    with pytest.raises(ValueError):
+        _ = FileBasedRestartableBatchDataProcessor(
+            data=[{"uuid": "", "x": 10}], batch_size=10, uuid_name="uuid", result_dir=result_dir
+        )
+
+
+def test_FileBasedRestartableBatchDataProcessor_uuid_unique(tmp_path):
+    result_dir = tmp_path.absolute()
+    data = [{"uuid": "a", "x": 10}, {"uuid": "a", "x": 10}, {"uuid": "c", "x": 10}]
+    with pytest.raises(ValueError):
+        _ = FileBasedRestartableBatchDataProcessor(data=data, batch_size=10, uuid_name="uuid", result_dir=result_dir)
+
+
+def test_FileBasedRestartableBatchDataProcessor_write_lock_files(tmp_path):
+    result_dir = tmp_path.absolute()
+    batch_size = 10
+    data = [{"uuid": str(uuid4()), "x": i} for i in range(100)]
+    data_processor = FileBasedRestartableBatchDataProcessor(
+        data=data, batch_size=batch_size, uuid_name="uuid", result_dir=result_dir
+    )
+    data = data_processor.read_batch()
+
+    assert len(data) == batch_size
+
+    # check lock files
+    lock_files = list(tmp_path.glob("*.lock"))
+    assert len(lock_files) == batch_size
+
+
+def test_FileBasedRestartableBatchDataProcessor_save_batch_data(tmp_path):
+    result_dir = tmp_path.absolute()
+    batch_size = 10
+    data = [{"uuid": str(uuid4()), "x": i} for i in range(100)]
+    data_processor = FileBasedRestartableBatchDataProcessor(
+        data=data, batch_size=batch_size, uuid_name="uuid", result_dir=result_dir
+    )
+    data = data_processor.read_batch()
+    data_processor.save_batch(data)
+
+    # check result files
+    lock_files = list(tmp_path.glob("*.json.gz"))
+    assert len(lock_files) == batch_size
+
+
+def test_FileBasedRestartableBatchDataProcessor_remove_lock_files(tmp_path):
+    result_dir = tmp_path.absolute()
+    batch_size = 10
+    data = [{"uuid": str(uuid4()), "x": i} for i in range(100)]
+    data_processor = FileBasedRestartableBatchDataProcessor(
+        data=data, batch_size=batch_size, uuid_name="uuid", result_dir=result_dir
+    )
+    data = data_processor.read_batch()
+    data_processor.save_batch(data)
+
+    # check lock files
+    lock_files = list(tmp_path.glob("*.lock"))
+    assert len(lock_files) == 0
+
+
+def test_FileBasedRestartableBatchDataProcessor_save_unlocked(tmp_path):
+    result_dir = tmp_path.absolute()
+    batch_size = 10
+    data = [{"uuid": str(uuid4()), "x": i} for i in range(100)]
+    data_processor = FileBasedRestartableBatchDataProcessor(
+        data=data, batch_size=batch_size, uuid_name="uuid", result_dir=result_dir
+    )
+    data = data_processor.read_batch()
+    data[0]["uuid"] = "something_else"
+    with pytest.raises(ValueError):
+        data_processor.save_batch(data)
+
+
+def test_FileBasedRestartableBatchDataProcessor_load_data(tmp_path):
+    result_dir = tmp_path.absolute()
+    batch_size = 10
+    data = [{"uuid": str(uuid4()), "x": i} for i in range(100)]
+    data_processor = FileBasedRestartableBatchDataProcessor(
+        data=data, batch_size=batch_size, uuid_name="uuid", result_dir=result_dir
+    )
+
+    # process all data
+    while True:
+        _data = data_processor.read_batch()
+        if len(_data) == 0:
+            break
+        data_processor.save_batch(_data)
+
+    del data_processor
+    processed_data = FileBasedRestartableBatchDataProcessor.load_data(result_dir)
+
+    assert len(processed_data) == len(data)
+    for d in processed_data:
+        assert "uuid" in d
+        assert "x" in d
+        assert isinstance(d["uuid"], str)
+        assert isinstance(d["x"], int)
+        assert d["x"] < 100
diff --git a/tests/test_md.py b/tests/test_md.py
index 6f04f6d..4432d7c 100644
--- a/tests/test_md.py
+++ b/tests/test_md.py
@@ -49,7 +49,7 @@ def test_chunk_md():
     assert result[2] == "### Headline 3 / 1\n\n#### Headline 4 / 1\n\nContent."
 
 
-def test_MdTextSplitter_call():  # noqa: N802
+def test_MdTextSplitter_call():
     transformers_token_counter = TransformersTokenCounter("deepset/gbert-base")
     text_merger = MdTextSplitter(
         max_token=15,
@@ -63,7 +63,7 @@ def test_MdTextSplitter_call():  # noqa: N802
     assert merged_md[1] == "### Headline 3 / 1\n\n#### Headline 4 / 1\n\nContent."
 
 
-def test_MdTextSplitter_call_no_merge():  # noqa: N802
+def test_MdTextSplitter_call_no_merge():
     transformers_token_counter = TransformersTokenCounter("deepset/gbert-base")
     text_merger = MdTextSplitter(
         max_token=1,
@@ -78,7 +78,7 @@ def test_MdTextSplitter_call_no_merge():  # noqa: N802
     assert merged_md[2] == "### Headline 3 / 1\n\n#### Headline 4 / 1\n\nContent."
 
 
-def test_MdTextSplitter_call_all_merge():  # noqa: N802
+def test_MdTextSplitter_call_all_merge():
     transformers_token_counter = TransformersTokenCounter("deepset/gbert-base")
     text_merger = MdTextSplitter(
         max_token=1000,
diff --git a/tests/test_openai.py b/tests/test_openai.py
index 2e59d8c..61c74e0 100644
--- a/tests/test_openai.py
+++ b/tests/test_openai.py
@@ -18,12 +18,12 @@ def gpt_4_open_ai_token_counter() -> OpenAiTokenCounter:
 
 @settings(max_examples=1000)
 @given(text=text())
-def test_OpenAiTokenCounter_str_hypothesis(text: str, gpt_4_open_ai_token_counter: OpenAiTokenCounter):  # noqa: N802
+def test_OpenAiTokenCounter_str_hypothesis(text: str, gpt_4_open_ai_token_counter: OpenAiTokenCounter):
     token_count = gpt_4_open_ai_token_counter(text)
 
     assert token_count >= 0  # type: ignore[operator]
 
 
-def test_OpenAiTokenCounter_call_string():  # noqa: N802
+def test_OpenAiTokenCounter_call_string():
     token_counter = OpenAiTokenCounter("gpt-4")
 
     token_count = token_counter("Das ist ein Text.")
@@ -32,15 +32,13 @@ def test_OpenAiTokenCounter_call_string():  # noqa: N802
 
 @settings(max_examples=1000)
 @given(texts=lists(text()))
-def test_OpenAiTokenCounter_list_hypothesis(  # noqa: N802
-    texts: List[str], gpt_4_open_ai_token_counter: OpenAiTokenCounter
-):
+def test_OpenAiTokenCounter_list_hypothesis(texts: List[str], gpt_4_open_ai_token_counter: OpenAiTokenCounter):
     token_count = gpt_4_open_ai_token_counter(texts)
 
     assert len(token_count) == len(texts)  # type: ignore[arg-type]
     assert all(count >= 0 for count in token_count)  # type: ignore[union-attr]
 
 
-def test_OpenAiTokenCounter_call_list():  # noqa: N802
+def test_OpenAiTokenCounter_call_list():
     token_counter = OpenAiTokenCounter("gpt-4")
 
     token_count = token_counter(["Das ist ein Text.", "Das ist ein anderer Text."])
@@ -50,34 +48,34 @@ def test_OpenAiTokenCounter_call_list():  # noqa: N802
     assert token_count[1] == 7
 
 
-def test_OpenAiChat__missing_role_message_key():  # noqa: N802
+def test_OpenAiChat__missing_role_message_key():
     open_ai_chat = OpenAiChat(api_key="secret", model="apt-4")
     invalid_prompt_as_list = [{"x": "user", "content": "prompt"}]
     with pytest.raises(ValueError):
         open_ai_chat(invalid_prompt_as_list)
 
 
-def test_OpenAiChat__missing_content_message_key():  # noqa: N802
+def test_OpenAiChat__missing_content_message_key():
     open_ai_chat = OpenAiChat(api_key="secret", model="apt-4")
     invalid_prompt_as_list = [{"role": "user", "x": "prompt"}]
     with pytest.raises(ValueError):
         open_ai_chat(invalid_prompt_as_list)
 
 
-def test_OpenAiChat__invalid_role_in_message_key():  # noqa: N802
+def test_OpenAiChat__invalid_role_in_message_key():
     open_ai_chat = OpenAiChat(api_key="secret", model="apt-4")
     invalid_prompt_as_list = [{"role": "x", "content": "prompt"}]
     with pytest.raises(ValueError):
         open_ai_chat(invalid_prompt_as_list)
 
 
-def test_OpenAiChat__model_in_completion_kwargs():  # noqa: N802
+def test_OpenAiChat__model_in_completion_kwargs():
     open_ai_chat = OpenAiChat(api_key="secret", model="apt-4")
     with pytest.raises(ValueError):
         open_ai_chat("Hello!", completion_kwargs={"model": "gpt-4"})
 
 
-def test_OpenAiChat__messages_in_completion_kwargs():  # noqa: N802
+def test_OpenAiChat__messages_in_completion_kwargs():
     open_ai_chat = OpenAiChat(api_key="secret", model="apt-4")
     with pytest.raises(ValueError):
         open_ai_chat("Hello!", completion_kwargs={"messages": "World!"})
diff --git a/tests/test_somajo.py b/tests/test_somajo.py
index fe7ecc8..17e285d 100644
--- a/tests/test_somajo.py
+++ b/tests/test_somajo.py
@@ -17,7 +17,7 @@
 )
 
 
-def test_SoMaJoSentenceSplitter_call() -> None:  # noqa: N802
+def test_SoMaJoSentenceSplitter_call() -> None:
     """Test ``SoMaJoSentenceSplitter.call`` happy case."""
     splitter = SoMaJoSentenceSplitter("de_CMC")
     text = "Das ist der erste Satz. Das ist der 2. Satz."
@@ -28,7 +28,7 @@ def test_SoMaJoSentenceSplitter_call() -> None:  # noqa: N802
     assert sentences[1] == "Das ist der 2. Satz."
 
 
-def test_SoMaJoSentenceSplitter_call_space_and_linebreak() -> None:  # noqa: N802
+def test_SoMaJoSentenceSplitter_call_space_and_linebreak() -> None:
     """Test ``SoMaJoSentenceSplitter.call`` with space an line break."""
     splitter = SoMaJoSentenceSplitter("de_CMC")
     text = " Das ist der erste Satz. \n Das ist der 2. \n Satz. "
@@ -39,7 +39,7 @@ def test_SoMaJoSentenceSplitter_call_space_and_linebreak() -> None:  # noqa: N80
     assert sentences[1] == "Das ist der 2. Satz."
 
 
-def test_JaccardSimilarity_call():  # noqa: N802
+def test_JaccardSimilarity_call():
     text1 = "Das ist ein deutscher Text."
     text2 = "Das ist ein anderer Text."
     jaccard_similarity = JaccardSimilarity("de_CMC")
@@ -52,7 +52,7 @@ def test_JaccardSimilarity_call():  # noqa: N802
     assert result2 > 0.0
 
 
-def test_JaccardSimilarity_call_same():  # noqa: N802
+def test_JaccardSimilarity_call_same():
     text = "Das ist ein deutscher Text."
     jaccard_similarity = JaccardSimilarity("de_CMC")
     result = jaccard_similarity(text, text)
@@ -60,7 +60,7 @@ def test_JaccardSimilarity_call_same():  # noqa: N802
     assert isclose(result, 1.0)
 
 
-def test_JaccardSimilarity_call_no_overlap():  # noqa: N802
+def test_JaccardSimilarity_call_no_overlap():
     text1 = "Das ist ein deutscher Text."
     text2 = "Vollkommen anders!"
     jaccard_similarity = JaccardSimilarity("de_CMC")
@@ -69,7 +69,7 @@ def test_JaccardSimilarity_call_no_overlap():  # noqa: N802
     assert isclose(result, 0.0)
 
 
-def test_TokenExtractor_extract_token_set():  # noqa: N802
+def test_TokenExtractor_extract_token_set():
     text = "Das ist ein Text. Er enthält keine URL."
     token_extractor = TokenExtractor("de_CMC")
     result = token_extractor.extract_token_set(text)
@@ -78,7 +78,7 @@ def test_TokenExtractor_extract_token_set():  # noqa: N802
     assert "." in result
 
 
-def test_TokenExtractor_extract_url_set_with_str():  # noqa: N802
+def test_TokenExtractor_extract_url_set_with_str():
     url1 = "http://may.la"
     url2 = "github.com"
     text_with_url = f"{url1} Das ist ein Text. {url2} Er enthält eine URL."
@@ -89,7 +89,7 @@ def test_TokenExtractor_extract_url_set_with_str():  # noqa: N802
     assert url2 in result
 
 
-def test_TokenExtractor_extract_url_set_with_list():  # noqa: N802
+def test_TokenExtractor_extract_url_set_with_list():
     url1 = "http://may.la"
     url2 = "github.com"
     text_with_url = [f"{url1} Das ist ein Text.", f"{url2} Er enthält eine URL."]
@@ -100,7 +100,7 @@ def test_TokenExtractor_extract_url_set_with_list():  # noqa: N802
     assert url2 in result
 
 
-def test_TokenExtractor_extract_url_set_no_url():  # noqa: N802
+def test_TokenExtractor_extract_url_set_no_url():
     text_with_url = "Das ist ein Text. Er enthält keine URLs."
     token_extractor = TokenExtractor("de_CMC")
     result = token_extractor.extract_url_set(text_with_url)
@@ -146,7 +146,7 @@ def test_detokenize():
     assert result == "Das ist ein Satz."
 
 
-def test_UrlSwapper_swap_urls():  # noqa: N802
+def test_UrlSwapper_swap_urls():
     token_extractor = TokenExtractor("de_CMC")
     url_swapper = UrlSwapper(token_extractor)
     text_with_url = "This is a text with URL: http://may.la."
@@ -162,7 +162,7 @@ def test_UrlSwapper_swap_urls():  # noqa: N802
         "2 MD URL s: [Philip May](http://may.la). [other link](https://github.com/telekom/mltb2#installation)",
     ],
 )
-def test_UrlSwapper__is_reversible(text_with_url: str):  # noqa: N802
+def test_UrlSwapper__is_reversible(text_with_url: str):
     token_extractor = TokenExtractor("de_CMC")
     url_swapper = UrlSwapper(token_extractor)
     text_with_reverse_swapped_url, no_reverse_swap_urls = url_swapper.reverse_swap_urls(
@@ -172,7 +172,7 @@ def test_UrlSwapper__is_reversible(text_with_url: str):  # noqa: N802
     assert len(no_reverse_swap_urls) == 0
 
 
-def test_UrlSwapper__no_reverse_swap_urls():  # noqa: N802
+def test_UrlSwapper__no_reverse_swap_urls():
     token_extractor = TokenExtractor("de_CMC")
     url_swapper = UrlSwapper(token_extractor)
     text_with_url = "This is a text with URL: http://may.la."
@@ -185,7 +185,7 @@ def test_UrlSwapper__no_reverse_swap_urls():  # noqa: N802
 
 
 # see https://github.com/telekom/mltb2/issues/94
-def test_UrlSwapper__markdown_bug():  # noqa: N802
+def test_UrlSwapper__markdown_bug():
     token_extractor = TokenExtractor("de_CMC")
     url_swapper = UrlSwapper(token_extractor)
     text_with_url = "This is a MD link: [https://something-1.com](https://something-2.com)."
@@ -196,7 +196,7 @@ def test_UrlSwapper__markdown_bug():  # noqa: N802
 
 
 # regression for https://github.com/tsproisl/SoMaJo/issues/27
-def test_TokenExtractor_extract_url_set__markdown_bug_2():  # noqa: N802
+def test_TokenExtractor_extract_url_set__markdown_bug_2():
     text_with_url = "This is a MD link: ."
     token_extractor = TokenExtractor("de_CMC")
     url_set = token_extractor.extract_url_set(text_with_url)
diff --git a/tests/test_somajo_transformers.py b/tests/test_somajo_transformers.py
index 76807c4..a053851 100644
--- a/tests/test_somajo_transformers.py
+++ b/tests/test_somajo_transformers.py
@@ -11,7 +11,7 @@
 from mltb2.transformers import TransformersTokenCounter
 
 
-def test_TextSplitter_call():  # noqa: N802
+def test_TextSplitter_call():
     somajo_sentence_splitter = SoMaJoSentenceSplitter("de_CMC")
     transformers_token_counter = TransformersTokenCounter("deepset/gbert-base")
     text_splitter = TextSplitter(
@@ -29,7 +29,7 @@ def test_TextSplitter_call():  # noqa: N802
     assert split_text[2] == "Satz 4 ist das."
 
 
-def test_TextSplitter_call_sentence_too_long_exception():  # noqa: N802
+def test_TextSplitter_call_sentence_too_long_exception():
     somajo_sentence_splitter = SoMaJoSentenceSplitter("de_CMC")
     transformers_token_counter = TransformersTokenCounter("deepset/gbert-base")
     text_splitter = TextSplitter(
@@ -43,7 +43,7 @@ def test_TextSplitter_call_sentence_too_long_exception():  # noqa: N802
         text_splitter(text)
 
 
-def test_TextSplitter_call_sentence_too_long_no_exception():  # noqa: N802
+def test_TextSplitter_call_sentence_too_long_no_exception():
     somajo_sentence_splitter = SoMaJoSentenceSplitter("de_CMC")
     transformers_token_counter = TransformersTokenCounter("deepset/gbert-base")
     text_splitter = TextSplitter(
diff --git a/tests/test_transformers.py b/tests/test_transformers.py
index 6a7fc20..d91bbb9 100644
--- a/tests/test_transformers.py
+++ b/tests/test_transformers.py
@@ -17,23 +17,21 @@ def deepset_gbert_base_token_counter() -> TransformersTokenCounter:
 
 @settings(max_examples=1000, deadline=None)
 @given(text=text())
-def test_TransformersTokenCounter_hypothesis(  # noqa: N802
-    text: str, deepset_gbert_base_token_counter: TransformersTokenCounter
-):
+def test_TransformersTokenCounter_hypothesis(text: str, deepset_gbert_base_token_counter: TransformersTokenCounter):
     token_count = deepset_gbert_base_token_counter(text)
 
     assert isinstance(token_count, int)
     assert token_count >= 0
 
 
-def test_TransformersTokenCounter_call_string():  # noqa: N802
+def test_TransformersTokenCounter_call_string():
     transformers_token_counter = TransformersTokenCounter("deepset/gbert-base")
 
     token_count = transformers_token_counter("Das ist ein Text.")
 
     assert token_count == 5
 
 
-def test_TransformersTokenCounter_call_list():  # noqa: N802
+def test_TransformersTokenCounter_call_list():
     transformers_token_counter = TransformersTokenCounter("deepset/gbert-base")
 
     token_count = transformers_token_counter(["Das ist ein Text.", "Das ist ein anderer Text."])
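
For readers of the patch, here is a minimal usage sketch of the new FileBasedRestartableBatchDataProcessor from mltb2/files.py above. It only relies on the constructor arguments and the read_batch / save_batch methods shown in the diff; the "double x" workload and the ./results directory are illustrative assumptions, not part of the patch.

# Usage sketch (assumptions: a trivial "double x" workload and a local ./results directory).
from uuid import uuid4

from mltb2.files import FileBasedRestartableBatchDataProcessor

# Every record needs a unique, non-empty string uuid under the configured uuid_name.
data = [{"uuid": str(uuid4()), "x": x} for x in range(100)]
processor = FileBasedRestartableBatchDataProcessor(
    data=data, batch_size=10, uuid_name="uuid", result_dir="./results"
)

# Keep reading batches until everything is processed. On a restart, records that already
# have a result file (or are locked by another worker using the same result_dir) are
# skipped, which is what makes the run restartable.
while True:
    batch = processor.read_batch()  # claims the records by writing <uuid>.lock files
    if len(batch) == 0:
        break
    for record in batch:
        record["x"] = record["x"] * 2  # hypothetical processing step
    processor.save_batch(batch)  # writes <uuid>_<random uuid>.json.gz and removes the locks

Note that a batch that was locked but never saved (for example after a crash) keeps its .lock files, so those records are not handed out again until the lock files are removed from result_dir.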
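
And a sketch of the collection step with the static load_data method, mirroring test_FileBasedRestartableBatchDataProcessor_load_data above; the ./results path is again an assumed placeholder.

from mltb2.files import FileBasedRestartableBatchDataProcessor

# After all batches are saved, gather every <uuid>_<random uuid>.json.gz result file.
# The returned list contains one dict per saved record; the order is not guaranteed.
processed_data = FileBasedRestartableBatchDataProcessor.load_data("./results")
print(len(processed_data))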