Skip to content

Commit

Permalink
Better handling of imports errors when blingfire is not available (#61)
Browse files Browse the repository at this point in the history
* adding extra errors

* style
  • Loading branch information
soldni authored Jun 27, 2023
1 parent 1bf52a7 commit ee3dc44
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 5 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "smashed"
version = "0.21.3"
version = "0.21.4"
description = """\
SMASHED is a toolkit designed to apply transformations to samples in \
datasets, such as fields extraction, tokenization, prompting, batching, \
Expand All @@ -11,7 +11,7 @@ license = {text = "Apache-2.0"}
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
"necessary>=0.4.1",
"necessary>=0.4.3",
"trouting>=0.3.3",
"ftfy>=6.1.1",
"platformdirs>=2.5.0",
Expand Down
5 changes: 4 additions & 1 deletion src/smashed/utils/wordsplitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from necessary import Necessary, necessary

with necessary("blingfire", soft=True) as BLINGFIRE_AVAILABLE:
with necessary(
"blingfire", soft=True, errors=(ModuleNotFoundError, OSError)
) as BLINGFIRE_AVAILABLE:
if BLINGFIRE_AVAILABLE or TYPE_CHECKING:
from blingfire import text_to_words

Expand Down Expand Up @@ -39,6 +41,7 @@ def __call__(

@Necessary(
"blingfire",
errors=(ModuleNotFoundError, OSError),
message=(
"{module_name} missing. Fix with 'pip install smashed[prompting]'."
"If you are on a Mac with Apple Silicon chip, also run "
Expand Down
2 changes: 1 addition & 1 deletion tests/test_batch_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class TestBatchInterface(unittest.TestCase):
def test_batch(self, remove_columns: bool = False):
mapper = MockMapper(1, output_fields=["a"])

data = Dataset.from_list([{"a": i, "b": i**2} for i in range(100)])
data = Dataset.from_list([{"a": i, "b": i ** 2} for i in range(100)])

def _batch_fn(data: LazyBatch, mapper: MockMapper) -> LazyBatch:
return mapper.map(deepcopy(data), remove_columns=remove_columns)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_list_cache(self):

def test_datasets_cache(self):
dt = Dataset.from_dict(
{"a": [i for i in range(5)], "b": [i**2 for i in range(5)]}
{"a": [i for i in range(5)], "b": [i ** 2 for i in range(5)]}
)

with tempfile.TemporaryDirectory() as tmpdir:
Expand Down

0 comments on commit ee3dc44

Please sign in to comment.