Add a read_only option to seqio.TfdsDataSource.
This option prevents the dataset from being generated when it does not already exist. This is useful when a central process is responsible for generating datasets and clients should not trigger generation themselves.

PiperOrigin-RevId: 671234210
tomvdw authored and SeqIO committed Sep 10, 2024
1 parent a4183b6 commit e376c72
Showing 3 changed files with 42 additions and 11 deletions.
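To illustrate the intended use: a minimal sketch of a client-side data source that must never generate data itself. The dataset name, version, and path are placeholders; apart from the new read_only option, the parameters follow seqio's existing TfdsDataSource constructor.

import seqio

# With read_only=True, get_dataset only reads already-prepared data;
# it fails instead of generating the dataset if nothing exists under
# tfds_data_dir. Names and paths below are hypothetical.
source = seqio.TfdsDataSource(
    tfds_name="my_dataset/my_config:1.0.0",
    tfds_data_dir="/path/to/centrally/prepared/data",
    read_only=True,
)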
3 changes: 3 additions & 0 deletions seqio/dataset_providers.py
@@ -495,6 +495,7 @@ def __init__(
       caching_permitted: bool = True,
       decoders: Optional[tfds.typing.TreeDict[tfds.decode.Decoder]] = None,
       tfds_builder_kwargs: Optional[dict[str, Any]] = None,
+      read_only: bool = False,
   ):
     """TfdsTask constructor.
@@ -519,6 +520,7 @@ def __init__(
       tfds_builder_kwargs: `dict` (optional), keyword arguments to be passed to
         the `tfds.core.DatasetBuilder` constructor through `tfds.load()` and
         `tfds.builder()`.
+      read_only: if True, `get_dataset` will not trigger generation of the dataset.
     """
     if splits and not isinstance(splits, dict):
       splits = {k: k for k in splits}
@@ -529,6 +531,7 @@
         split_map=splits if isinstance(splits, dict) else None,
         decoders=decoders,
         builder_kwargs=tfds_builder_kwargs,
+        read_only=read_only,
     )
 
     # If splits are not provided, we pass an empty tuple and use the lazy
33 changes: 22 additions & 11 deletions seqio/utils.py
@@ -137,6 +137,7 @@ def __init__(
       split_map: Union[Mapping[str, str], Mapping[str, TfdsSplit], None] = None,
       decoders=None,
       builder_kwargs: Optional[dict[str, Any]] = None,
+      read_only: bool = False,
   ):
     """LazyTfdsLoader constructor.
@@ -153,6 +154,7 @@ def __init__(
       builder_kwargs: `dict` (optional), keyword arguments to be passed to the
         `tfds.core.DatasetBuilder` constructor through `tfds.load()` and
         `tfds.builder()`.
+      read_only: if True, `get_dataset` will not trigger generation of the dataset.
     """
     _validate_tfds_name(name)
     self._name = name
@@ -161,6 +163,7 @@
     self._split_map = split_map
     self._decoders = decoders
     self._builder_kwargs = builder_kwargs
+    self._read_only = read_only
 
     self._is_custom_split_map = False
     if split_map:
@@ -389,17 +392,25 @@ def load(
     )
     read_config.shuffle_seed = seed
     read_config.skip_prefetch = True
-    return tfds.load(
-        dataset,
-        split=dataset_split,
-        data_dir=data_dir,
-        shuffle_files=shuffle_files,
-        download=True,
-        try_gcs=True,
-        read_config=read_config,
-        decoders=self._decoders,
-        builder_kwargs=self._builder_kwargs,
-    )
+    if self._read_only:
+      return self.builder.as_dataset(
+          split=dataset_split,
+          shuffle_files=shuffle_files,
+          read_config=read_config,
+          decoders=self._decoders,
+      )
+    else:
+      return tfds.load(
+          dataset,
+          split=dataset_split,
+          data_dir=data_dir,
+          shuffle_files=shuffle_files,
+          download=True,
+          try_gcs=True,
+          read_config=read_config,
+          decoders=self._decoders,
+          builder_kwargs=self._builder_kwargs,
+      )
 
   def load_shard(
       self,
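The branch above is the heart of the change: `builder.as_dataset` only reads data that already exists under the data directory and raises if the dataset was never generated, whereas `tfds.load(..., download=True)` first runs `download_and_prepare`, generating the dataset on demand. A rough sketch of the two paths in plain TFDS terms, with placeholder names:

import tensorflow_datasets as tfds

builder = tfds.builder("my_dataset:1.0.0", data_dir="/path/to/data")

# read_only path: reads existing files only; raises if the dataset
# was never prepared under data_dir.
train_ds = builder.as_dataset(split="train")

# Default path: download=True has tfds.load call
# builder.download_and_prepare() first, generating data if missing.
train_ds = tfds.load(
    "my_dataset:1.0.0",
    split="train",
    data_dir="/path/to/data",
    download=True,
)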
17 changes: 17 additions & 0 deletions seqio/utils_test.py
@@ -145,6 +145,23 @@ def test_builder_cls_non_existing(self):
     actual = ds.builder_cls()
     self.assertIsNone(actual)
 
+  @mock.patch("tensorflow_datasets.load")
+  @mock.patch("tensorflow_datasets.builder")
+  def test_read_only(self, mock_tfds_builder, mock_tfds_load):
+    mock_builder = mock.create_autospec(tfds.core.DatasetBuilder)
+    mock_tfds_builder.return_value = mock_builder
+    loader = utils.LazyTfdsLoader(
+        "ds/cfg:1.2.3", data_dir="/data", read_only=True
+    )
+    _ = loader.load(split="train", shuffle_files=False)
+    mock_tfds_load.assert_not_called()
+    mock_builder.as_dataset.assert_called_once_with(
+        split="train",
+        shuffle_files=False,
+        read_config=AnyArg(),
+        decoders=None,
+    )
+
   @mock.patch("tensorflow_datasets.load")
   def test_split_map(self, mock_tfds_load):
     seed = 0
