Add a read_only option to seqio.TfdsDataSource. #760

Merged 1 commit on Sep 10, 2024
3 changes: 3 additions & 0 deletions seqio/dataset_providers.py
@@ -495,6 +495,7 @@ def __init__(
       caching_permitted: bool = True,
       decoders: Optional[tfds.typing.TreeDict[tfds.decode.Decoder]] = None,
       tfds_builder_kwargs: Optional[dict[str, Any]] = None,
+      read_only: bool = False,
   ):
     """TfdsTask constructor.
 
@@ -519,6 +520,7 @@ def __init__(
       tfds_builder_kwargs: `dict` (optional), keyword arguments to be passed to
         the `tfds.core.DatasetBuilder` constructor through `tfds.load()` and
         `tfds.builder()`.
+      read_only: if True, `get_dataset` will not trigger dataset generation.
     """
     if splits and not isinstance(splits, dict):
       splits = {k: k for k in splits}
@@ -529,6 +531,7 @@
         split_map=splits if isinstance(splits, dict) else None,
         decoders=decoders,
         builder_kwargs=tfds_builder_kwargs,
+        read_only=read_only,
     )
 
     # If splits are not provided, we pass an empty tuple and use the lazy
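For context, a minimal sketch of how the new flag could be used when defining a data source (not part of this PR; the dataset name, data directory, and splits below are placeholders, and the remaining arguments follow the existing `seqio.TfdsDataSource` signature):

import seqio

# Hypothetical dataset and data_dir. With read_only=True, get_dataset() only
# reads data that has already been prepared under tfds_data_dir and will not
# download or generate it.
source = seqio.TfdsDataSource(
    tfds_name="my_dataset/my_config:1.0.0",
    tfds_data_dir="/path/to/tfds_data",
    splits=("train", "validation"),
    read_only=True,
)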
33 changes: 22 additions & 11 deletions seqio/utils.py
@@ -137,6 +137,7 @@ def __init__(
       split_map: Union[Mapping[str, str], Mapping[str, TfdsSplit], None] = None,
       decoders=None,
       builder_kwargs: Optional[dict[str, Any]] = None,
+      read_only: bool = False,
   ):
     """LazyTfdsLoader constructor.
 
@@ -153,6 +154,7 @@
       builder_kwargs: `dict` (optional), keyword arguments to be passed to the
         `tfds.core.DatasetBuilder` constructor through `tfds.load()` and
         `tfds.builder()`.
+      read_only: if True, loading will not trigger dataset generation.
     """
     _validate_tfds_name(name)
     self._name = name
@@ -161,6 +163,7 @@
     self._split_map = split_map
     self._decoders = decoders
     self._builder_kwargs = builder_kwargs
+    self._read_only = read_only
 
     self._is_custom_split_map = False
     if split_map:
@@ -389,17 +392,25 @@ def load(
     )
     read_config.shuffle_seed = seed
     read_config.skip_prefetch = True
-    return tfds.load(
-        dataset,
-        split=dataset_split,
-        data_dir=data_dir,
-        shuffle_files=shuffle_files,
-        download=True,
-        try_gcs=True,
-        read_config=read_config,
-        decoders=self._decoders,
-        builder_kwargs=self._builder_kwargs,
-    )
+    if self._read_only:
+      return self.builder.as_dataset(
+          split=dataset_split,
+          shuffle_files=shuffle_files,
+          read_config=read_config,
+          decoders=self._decoders,
+      )
+    else:
+      return tfds.load(
+          dataset,
+          split=dataset_split,
+          data_dir=data_dir,
+          shuffle_files=shuffle_files,
+          download=True,
+          try_gcs=True,
+          read_config=read_config,
+          decoders=self._decoders,
+          builder_kwargs=self._builder_kwargs,
+      )
 
   def load_shard(
       self,
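To spell out the behavioural difference introduced above in plain TFDS terms (a sketch, not seqio code; `load_split` is a hypothetical helper): `tfds.load(..., download=True)` may call `download_and_prepare()` and generate the dataset on first use, while `builder.as_dataset()` only reads data that already exists under `data_dir` and fails if the dataset has not been prepared.

import tensorflow_datasets as tfds

def load_split(name: str, split: str, data_dir: str, read_only: bool):
  if read_only:
    # Only read previously prepared data; this raises if the dataset is not
    # already present under data_dir.
    builder = tfds.builder(name, data_dir=data_dir)
    return builder.as_dataset(split=split, shuffle_files=False)
  # Otherwise, TFDS may download and prepare the dataset into data_dir.
  return tfds.load(
      name, split=split, data_dir=data_dir, download=True, try_gcs=True
  )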
17 changes: 17 additions & 0 deletions seqio/utils_test.py
@@ -145,6 +145,23 @@ def test_builder_cls_non_existing(self):
     actual = ds.builder_cls()
     self.assertIsNone(actual)
 
+  @mock.patch("tensorflow_datasets.load")
+  @mock.patch("tensorflow_datasets.builder")
+  def test_read_only(self, mock_tfds_builder, mock_tfds_load):
+    mock_builder = mock.create_autospec(tfds.core.DatasetBuilder)
+    mock_tfds_builder.return_value = mock_builder
+    loader = utils.LazyTfdsLoader(
+        "ds/cfg:1.2.3", data_dir="/data", read_only=True
+    )
+    _ = loader.load(split="train", shuffle_files=False)
+    mock_tfds_load.assert_not_called()
+    mock_builder.as_dataset.assert_called_once_with(
+        split="train",
+        shuffle_files=False,
+        read_config=AnyArg(),
+        decoders=None,
+    )
+
   @mock.patch("tensorflow_datasets.load")
   def test_split_map(self, mock_tfds_load):
     seed = 0
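Outside the mocked test above, the loader would be exercised roughly as follows (a sketch; the dataset name and data directory are placeholders, and the data must already be prepared, since read_only=True prevents generation):

from seqio import utils

loader = utils.LazyTfdsLoader(
    "my_dataset/my_config:1.0.0", data_dir="/path/to/tfds_data", read_only=True
)
# Reads the prepared data via builder.as_dataset(); never calls tfds.load()
# with download=True, so nothing is generated.
ds = loader.load(split="train", shuffle_files=False)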