From e376c72e5344173c4ffb9df48eb9ed737eabda96 Mon Sep 17 00:00:00 2001
From: Tom van der Weide
Date: Wed, 4 Sep 2024 22:39:11 -0700
Subject: [PATCH] Add a read_only option to seqio.TfdsDataSource.

This option prevents the dataset from being generated if it does not
already exist. This is useful when there is a central process that
generates datasets and clients should not.

PiperOrigin-RevId: 671234210
---
 seqio/dataset_providers.py |  3 +++
 seqio/utils.py             | 33 ++++++++++++++++++++++-----------
 seqio/utils_test.py        | 17 +++++++++++++++++
 3 files changed, 42 insertions(+), 11 deletions(-)

diff --git a/seqio/dataset_providers.py b/seqio/dataset_providers.py
index dab7ef84..a3f8d246 100644
--- a/seqio/dataset_providers.py
+++ b/seqio/dataset_providers.py
@@ -495,6 +495,7 @@ def __init__(
       caching_permitted: bool = True,
       decoders: Optional[tfds.typing.TreeDict[tfds.decode.Decoder]] = None,
       tfds_builder_kwargs: Optional[dict[str, Any]] = None,
+      read_only: bool = False,
   ):
     """TfdsTask constructor.
 
@@ -519,6 +520,7 @@ def __init__(
       tfds_builder_kwargs: `dict` (optional), keyword arguments to be passed
         to the `tfds.core.DatasetBuilder` constructor through `tfds.load()`
        and `tfds.builder()`.
+      read_only: if True, `get_dataset` will not trigger generation of the dataset.
     """
     if splits and not isinstance(splits, dict):
       splits = {k: k for k in splits}
@@ -529,6 +531,7 @@ def __init__(
         split_map=splits if isinstance(splits, dict) else None,
         decoders=decoders,
         builder_kwargs=tfds_builder_kwargs,
+        read_only=read_only,
     )
 
     # If splits are not provided, we pass an empty tuple and use the lazy
diff --git a/seqio/utils.py b/seqio/utils.py
index 5768c029..700d1b77 100644
--- a/seqio/utils.py
+++ b/seqio/utils.py
@@ -137,6 +137,7 @@ def __init__(
       split_map: Union[Mapping[str, str], Mapping[str, TfdsSplit], None] = None,
       decoders=None,
       builder_kwargs: Optional[dict[str, Any]] = None,
+      read_only: bool = False,
   ):
     """LazyTfdsLoader constructor.
 
@@ -153,6 +154,7 @@ def __init__(
       builder_kwargs: `dict` (optional), keyword arguments to be passed to the
         `tfds.core.DatasetBuilder` constructor through `tfds.load()` and
         `tfds.builder()`.
+      read_only: if True, `get_dataset` will not trigger generation of the dataset.
""" _validate_tfds_name(name) self._name = name @@ -161,6 +163,7 @@ def __init__( self._split_map = split_map self._decoders = decoders self._builder_kwargs = builder_kwargs + self._read_only = read_only self._is_custom_split_map = False if split_map: @@ -389,17 +392,25 @@ def load( ) read_config.shuffle_seed = seed read_config.skip_prefetch = True - return tfds.load( - dataset, - split=dataset_split, - data_dir=data_dir, - shuffle_files=shuffle_files, - download=True, - try_gcs=True, - read_config=read_config, - decoders=self._decoders, - builder_kwargs=self._builder_kwargs, - ) + if self._read_only: + return self.builder.as_dataset( + split=dataset_split, + shuffle_files=shuffle_files, + read_config=read_config, + decoders=self._decoders, + ) + else: + return tfds.load( + dataset, + split=dataset_split, + data_dir=data_dir, + shuffle_files=shuffle_files, + download=True, + try_gcs=True, + read_config=read_config, + decoders=self._decoders, + builder_kwargs=self._builder_kwargs, + ) def load_shard( self, diff --git a/seqio/utils_test.py b/seqio/utils_test.py index 5261177b..b10593ae 100644 --- a/seqio/utils_test.py +++ b/seqio/utils_test.py @@ -145,6 +145,23 @@ def test_builder_cls_non_existing(self): actual = ds.builder_cls() self.assertIsNone(actual) + @mock.patch("tensorflow_datasets.load") + @mock.patch("tensorflow_datasets.builder") + def test_read_only(self, mock_tfds_builder, mock_tfds_load): + mock_builder = mock.create_autospec(tfds.core.DatasetBuilder) + mock_tfds_builder.return_value = mock_builder + loader = utils.LazyTfdsLoader( + "ds/cfg:1.2.3", data_dir="/data", read_only=True + ) + _ = loader.load(split="train", shuffle_files=False) + mock_tfds_load.assert_not_called() + mock_builder.as_dataset.assert_called_once_with( + split="train", + shuffle_files=False, + read_config=AnyArg(), + decoders=None, + ) + @mock.patch("tensorflow_datasets.load") def test_split_map(self, mock_tfds_load): seed = 0