Support different datasets per split in TfdsDataSource
PiperOrigin-RevId: 545226895
fineguy authored and SeqIO committed Aug 10, 2023
1 parent 4ede81c commit 93f0dd2
Showing 4 changed files with 214 additions and 41 deletions.
15 changes: 10 additions & 5 deletions seqio/dataset_providers.py
@@ -471,25 +471,30 @@ class TfdsDataSource(DataSource):

def __init__(
self,
tfds_name: str,
tfds_name: Optional[str] = None,
tfds_data_dir: Optional[str] = None,
splits: Optional[Union[Iterable[str], Mapping[str, str]]] = None,
splits: Optional[
Union[Iterable[str], Mapping[str, str], Mapping[str, utils.TfdsSplit]]
] = None,
caching_permitted: bool = True,
decoders: Optional[tfds.typing.TreeDict[tfds.decode.Decoder]] = None,
):
"""TfdsTask constructor.
Args:
tfds_name: The name and version number of a TFDS dataset, optionally with
a config.
a config. If `tfds_name` is not specified, then `splits` values must be
instances of `TfdsSplit`.
tfds_data_dir: An optional path to a specific TFDS data directory to use.
If provided, `tfds_name` must be a valid dataset in the directory. If
`tfds_name` is empty, `tfds_data_dir` must point to a directory containing
a single dataset.
splits: an iterable of allowable string split names, a dict mapping
allowable canonical splits (e.g., 'validation') to TFDS splits or slices
(e.g., 'train[:1%]'), or None. The default, None, uses all available
splits from the TFDS dataset info.
(e.g., 'train[:1%]'), or `TfdsSplit` (e.g. `TfdsSplit(dataset='mnist:3.0.1',
split='train')`), or None. The default, None, uses all available splits
from the TFDS dataset info. If `TfdsSplit` instances are used, then
`tfds_name` must be empty.
caching_permitted: indicates whether this data source may be cached.
Default True.
decoders: dict (optional), mapping from features to tfds.decode.Decoders,
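For illustration, a minimal sketch of the two constructor styles, using the fake dataset names from the tests below:

```python
from seqio import dataset_providers, utils

# Single-dataset usage: one TFDS dataset, with a canonical split remapped
# to a different underlying TFDS split.
source = dataset_providers.TfdsDataSource(
    tfds_name="fake:0.0.0", splits={"validation": "train"}
)

# New multi-dataset usage: each canonical split carries its own dataset via
# `utils.TfdsSplit`, and `tfds_name` must be left unset.
multi_source = dataset_providers.TfdsDataSource(
    splits={
        "train": utils.TfdsSplit(dataset="fake:0.0.0", split="validation"),
    }
)
```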
10 changes: 10 additions & 0 deletions seqio/dataset_providers_test.py
@@ -1960,6 +1960,16 @@ def test_tfds_splits(self):
tfds_name="fake:0.0.0", splits={"validation": "train"}
).splits,
)
self.assertSameElements(
["train"],
dataset_providers.TfdsDataSource(
splits={
"train": utils.TfdsSplit(
dataset="fake:0.0.0", split="validation"
)
}
).splits,
)

def test_tfds_source_splits(self):
default_splits_src = dataset_providers.TfdsDataSource("fake:0.0.0")
160 changes: 127 additions & 33 deletions seqio/utils.py
@@ -81,6 +81,33 @@ def add_global_cache_dirs(global_cache_dirs):
_GLOBAL_CACHE_DIRECTORIES += global_cache_dirs


def _validate_tfds_name(name: str) -> None:
"""Validates TFDS dataset name."""
if (
name
and ":" not in name
):
raise ValueError(f"TFDS name must contain a version number, got: {name}")


@dataclasses.dataclass(frozen=True)
class TfdsSplit:
"""Points to a specific TFDS split.
Attributes:
dataset: dataset name.
split: TFDS split (e.g. 'train'), or slice (e.g. 'train[":1%"]').
data_dir: directory to read/write TFDS data.
"""

dataset: str
split: Optional[str]
data_dir: Optional[tfds.typing.PathLike] = None

def __post_init__(self):
_validate_tfds_name(self.dataset)
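# Illustrative usage, not part of the diff: a `TfdsSplit` pins one canonical
# split to a specific dataset, and the dataset name must carry a version or
# `_validate_tfds_name` raises.
#
#   TfdsSplit(dataset="mnist:3.0.1", split="train")       # ok
#   TfdsSplit(dataset="mnist:3.0.1", split="train[:1%]")  # slices work too
#   TfdsSplit(dataset="mnist", split="train")             # ValueError: no version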


class LazyTfdsLoader(object):
"""Wrapper for TFDS datasets with memoization and additional functionality.
@@ -92,31 +119,46 @@ class LazyTfdsLoader(object):

def __init__(
self,
name: str,
name: Optional[str] = None,
data_dir=None,
split_map=None,
split_map: Optional[
Union[
Mapping[str, str],
Mapping[str, "TfdsSplit"],
]
] = None,
decoders=None,
):
"""LazyTfdsLoader constructor.
Args:
name: str, the name of the TFDS dataset.
name: str (optional), the name of the TFDS dataset. If `name` is not
specified, then `split_map` values must be instances of `TfdsSplit`.
data_dir: str (optional), directory to read/write TFDS data.
split_map: dict (optional), mapping from canonical splits (e.g.,
'validation') to TFDS splits or slices (e.g., 'train[:1%]').
'validation') to TFDS splits (e.g. 'train'), or slices (e.g.,
'train[:1%]'), or `TfdsSplit` (e.g. `TfdsSplit(dataset='mnist:3.0.1',
split='train')`). If `TfdsSplit` instances are used, then `name` must
be empty.
decoders: dict (optional), mapping from features to tfds.decode.Decoders,
such as tfds.decode.SkipDecoding() for skipping image byte decoding.
"""
if (
name
and ":" not in name
):
raise ValueError(f"TFDS name must contain a version number, got: {name}")
_validate_tfds_name(name)
self._name = name
self._data_dir = data_dir
self._split_map = split_map
self._decoders = decoders

self._is_custom_split_map = False
if split_map:
random_split_value = next(iter(split_map.values()))
if isinstance(random_split_value, TfdsSplit):
self._is_custom_split_map = True
if self._name or self._data_dir:
raise ValueError(
"If split values are instances of `TfdsSplit`, `name` and"
" `data_dir` must be `None`."
)
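# Illustrative sketch, mirroring the tests below: with `TfdsSplit` values the
# loader spans several datasets, so `name` and `data_dir` must stay unset.
#
#   split_map = {
#       "train": TfdsSplit(dataset="ds/c1:1.0.0", split="validation"),
#       "test": TfdsSplit(dataset="ds/c2:1.0.0", split="test"),
#   }
#   LazyTfdsLoader("ds/c1:1.0.0", split_map=split_map)  # raises ValueError
#   loader = LazyTfdsLoader(split_map=split_map)        # ok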

@property
def name(self):
return self._name
@@ -138,7 +180,13 @@ def __repr__(self):

@property
def data_dir(self):
"""Returns the data directory fot this TFDS dataset."""
"""Returns the data directory for this TFDS dataset."""

if self._is_custom_split_map:
logging.warning(
"`LazyTfdsLoader` refers to multiple datasets, `data_dir` is unknown."
)
return None


if (
@@ -160,45 +208,85 @@ def read_config(self):
return _TFDS_DATA_READ_CONFIG_OVERRIDE
return tfds.ReadConfig()

@functools.cached_property
def _builder_key(self) -> Tuple[str, Optional[str]]:
return (self.name, self.data_dir)
def _get_builder_key(
self, dataset: Optional[str], data_dir: Optional[str]
) -> Tuple[Optional[str], Optional[str]]:
return (dataset, data_dir)

@property
def is_memoized(self) -> bool:
return self._builder_key in LazyTfdsLoader._MEMOIZED_BUILDERS
if self._is_custom_split_map:
return all(
self._get_builder_key(tfds_split.dataset, tfds_split.data_dir)
in LazyTfdsLoader._MEMOIZED_BUILDERS
for tfds_split in self._split_map.values()
)
else:
return (
self._get_builder_key(self.name, self.data_dir)
in LazyTfdsLoader._MEMOIZED_BUILDERS
)

@property
def builder(self):
return self._get_builder()

def _get_builder(self, split: Optional[str] = None):
"""Returns the DatasetBuilder for this TFDS dataset."""
if not self.is_memoized:
if self.name:
LazyTfdsLoader._MEMOIZED_BUILDERS[self._builder_key] = tfds.builder(
self.name, data_dir=self.data_dir
if self._is_custom_split_map:
if mapped_split := self._split_map.get(split):
dataset = mapped_split.dataset
data_dir = mapped_split.data_dir
else:
raise ValueError(
"`LazyTfdsLoader` refers to multiple datasets, pass `split` to"
" `_get_builder()`."
)
else:
dataset = self.name
data_dir = self.data_dir
builder_key = self._get_builder_key(dataset, data_dir)
if builder_key not in LazyTfdsLoader._MEMOIZED_BUILDERS:
if dataset:
LazyTfdsLoader._MEMOIZED_BUILDERS[builder_key] = tfds.builder(
dataset, data_dir=data_dir
)
else:
LazyTfdsLoader._MEMOIZED_BUILDERS[self._builder_key] = (
tfds.builder_from_directory(self.data_dir)
LazyTfdsLoader._MEMOIZED_BUILDERS[builder_key] = (
tfds.builder_from_directory(data_dir)
)
return LazyTfdsLoader._MEMOIZED_BUILDERS[self._builder_key]
return LazyTfdsLoader._MEMOIZED_BUILDERS[builder_key]
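# Note, not part of the diff: builders stay memoized per (dataset, data_dir)
# key, so loaders that resolve to the same pair share one DatasetBuilder; the
# tests below pre-populate e.g. _MEMOIZED_BUILDERS[("ds/c1:1.0.0", None)].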

@property
def info(self):
return self.builder.info

def _map_split(self, split: str):
return self._split_map[split] if self._split_map else split
def _map_split(self, split: str) -> Optional[str]:
"""Maps the given split to a dataset split."""
if self._is_custom_split_map:
self._split_map: Mapping[str, TfdsSplit]
return self._split_map[split].split
elif self._split_map:
self._split_map: Mapping[str, str]
return self._split_map[split]
else:
return split

def files(self, split: str):
"""Returns set of instructions for reading TFDS files for the dataset."""
split = self._map_split(split)
dataset_split = self._map_split(split)
builder = self._get_builder(split)

if "/" not in self.name and self.builder.BUILDER_CONFIGS:
if (
self.name is not None
and "/" not in self.name
and builder.BUILDER_CONFIGS
):
# If builder has multiple configs, and no particular config was
# requested, raise an error.
raise ValueError("Dataset '%s' has multiple configs." % self.name)

split_info = self.builder.info.splits[split]
split_info = builder.info.splits[dataset_split]
files = split_info.file_instructions

if not files:
@@ -213,7 +301,13 @@ def load(
shard_info=None,
):
"""Returns a tf.data.Dataset for the given split."""
split = self._map_split(split)
dataset_split = self._map_split(split)
if self._is_custom_split_map:
name = self._split_map[split].dataset
data_dir = self._split_map[split].data_dir
else:
name = self.name
data_dir = self.data_dir
read_config = self.read_config
read_config.input_context = (
tf.distribute.InputContext( # pylint: disable=g-long-ternary
@@ -226,9 +320,9 @@
read_config.shuffle_seed = seed
read_config.skip_prefetch = True
return tfds.load(
self._name,
split=split,
data_dir=self.data_dir,
name,
split=dataset_split,
data_dir=data_dir,
shuffle_files=shuffle_files,
download=True,
try_gcs=True,
Expand All @@ -254,9 +348,9 @@ def load_shard(

def size(self, split: str) -> Optional[int]:
"""Returns the number of examples in the split."""
split = self._map_split(split)
ds_splits = self.info.splits
dataset_size = ds_splits[split].num_examples
dataset_split = self._map_split(split)
ds_splits = self._get_builder(split).info.splits
dataset_size = ds_splits[dataset_split].num_examples
# Very large datasets have num_examples = 0; default instead to np.inf
dataset_size = dataset_size if dataset_size > 0 else np.inf
return dataset_size
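Taken together, a minimal end-to-end sketch of the multi-dataset loader, reusing the illustrative dataset names from the new test below:

```python
from seqio import utils

split_map = {
    "train": utils.TfdsSplit(dataset="ds/c1:1.0.0", split="validation"),
    "test": utils.TfdsSplit(dataset="ds/c2:1.0.0", split="test"),
}
loader = utils.LazyTfdsLoader(split_map=split_map)

loader.load("train", shuffle_files=False, seed=42)  # reads ds/c1's validation split
loader.size("train")  # example count from ds/c1's validation split
loader.files("test")  # file instructions from ds/c2's test split
loader.data_dir       # None; logs a warning since the loader spans datasets
```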
70 changes: 67 additions & 3 deletions seqio/utils_test.py
@@ -117,9 +117,13 @@ def test_split_map(self, mock_tfds_load):
info=mock.Mock(
splits={
"validation": mock.Mock(
num_examples=420, file_instructions=["f1", "f2"]
name="validation",
num_examples=420,
file_instructions=["f1", "f2"],
),
"test": mock.Mock(
name="test", num_examples=42, file_instructions=["f3"]
),
"test": mock.Mock(num_examples=42, file_instructions=["f3"]),
}
)
)
@@ -153,6 +157,66 @@ def test_split_map(self, mock_tfds_load):
with self.assertRaises(KeyError):
ds.files(split="test")

@mock.patch("tensorflow_datasets.load")
def test_tfds_split(self, mock_tfds_load):
utils.LazyTfdsLoader._MEMOIZED_BUILDERS[("ds/c1:1.0.0", None)] = mock.Mock(
info=mock.Mock(
splits={
"validation": mock.Mock(
name="validation",
num_examples=420,
file_instructions=["f1", "f2"],
),
}
)
)
utils.LazyTfdsLoader._MEMOIZED_BUILDERS[("ds/c2:1.0.0", None)] = mock.Mock(
info=mock.Mock(
splits={
"test": mock.Mock(
name="test", num_examples=42, file_instructions=["f3"]
),
}
)
)
split_map = {
"train": utils.TfdsSplit(dataset="ds/c1:1.0.0", split="validation"),
"test": utils.TfdsSplit(dataset="ds/c2:1.0.0", split="test"),
}

with self.assertRaisesWithLiteralMatch(
ValueError,
"If split values are instances of `TfdsSplit`, `name` and"
" `data_dir` must be `None`.",
):
utils.LazyTfdsLoader("ds/c1:1.0.0", split_map=split_map)
ds = utils.LazyTfdsLoader(split_map=split_map)

# test .load()
ds.load("train", shuffle_files=False, seed=42)
mock_tfds_load.assert_called_once_with(
"ds/c1:1.0.0",
split="validation",
data_dir=None,
shuffle_files=False,
download=True,
try_gcs=True,
read_config=AnyArg(),
decoders=None,
)

# test .size()
self.assertEqual(420, ds.size(split="train"))
self.assertEqual(42, ds.size(split="test"))
with self.assertRaises(KeyError):
ds.size(split="validation")

# test .files()
self.assertListEqual(["f1", "f2"], ds.files(split="train"))
self.assertListEqual(["f3"], ds.files(split="test"))
with self.assertRaises(KeyError):
ds.files(split="validation")

@mock.patch("tensorflow_datasets.load")
def test_read_config_override_default(self, mock_tfds_load):
ds = utils.LazyTfdsLoader(
@@ -735,7 +799,7 @@ def test_trim_and_pad_dataset(self):
"idx": [1, 2],
},
]
ds = ds = tf.data.Dataset.from_generator(
ds = tf.data.Dataset.from_generator(
lambda: x,
output_signature={
"inputs": tf.TensorSpec([None], tf.int32),
