Skip to content

Commit

Permalink
Add a function to easily override a TFDS data dir in an instantiated …
Browse files Browse the repository at this point in the history
…source

PiperOrigin-RevId: 670950104
  • Loading branch information
tomvdw authored and SeqIO committed Sep 4, 2024
1 parent cf8c3e2 commit a4183b6
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 15 deletions.
2 changes: 1 addition & 1 deletion seqio/dataset_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ def __init__(
@property
def splits(self):
"""Overrides since we can't call `info.splits` until after init."""
return self._splits or self._tfds_dataset.info.splits
return self._splits or self.tfds_dataset.info.splits

@property
def tfds_dataset(self) -> utils.LazyTfdsLoader:
Expand Down
37 changes: 25 additions & 12 deletions seqio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,15 @@ def _validate_tfds_name(name: str) -> None:
raise ValueError(f"TFDS name must contain a version number, got: {name}")


def _get_data_dir_override(tfds_name: Optional[str]) -> Optional[str]:
"""Returns the data dir in case it is overridden."""
if (
_TFDS_DATA_DIR_OVERRIDE
):
return _TFDS_DATA_DIR_OVERRIDE
return None


@dataclasses.dataclass(frozen=True)
class TfdsSplit:
"""Points to a specific TFDS split.
Expand Down Expand Up @@ -148,6 +157,7 @@ def __init__(
_validate_tfds_name(name)
self._name = name
self._data_dir = data_dir
self._data_dir_override = None
self._split_map = split_map
self._decoders = decoders
self._builder_kwargs = builder_kwargs
Expand Down Expand Up @@ -238,19 +248,22 @@ def data_dir(self) -> Optional[str]:
)
return None

if self._data_dir_override is None:
if data_dir_override := _get_data_dir_override(tfds_name=self.name):
self._data_dir_override = data_dir_override

if (
_TFDS_DATA_DIR_OVERRIDE
):
if self._data_dir:
logging.warning(
"Overriding TFDS data directory '%s' with '%s' for dataset '%s'.",
self._data_dir,
_TFDS_DATA_DIR_OVERRIDE,
self.name,
)
return _TFDS_DATA_DIR_OVERRIDE
return self._data_dir
if self._data_dir_override and self._data_dir:
logging.warning(
"Overriding TFDS data directory '%s' with '%s' for dataset '%s'.",
self._data_dir,
self._data_dir_override,
self.name,
)

return self._data_dir_override or self._data_dir

def override_data_dir(self, data_dir: str) -> None:
self._data_dir_override = data_dir

@property
def read_config(self):
Expand Down
37 changes: 35 additions & 2 deletions seqio/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,41 @@ def test_read_config_override(self, mock_tfds_load):
# reset to default global override
utils.set_tfds_read_config_override(None)

@mock.patch("tensorflow_datasets.builder")
def test_override_data_dir(self, mock_tfds_builder):
mock_builder1 = mock.create_autospec(tfds.core.DatasetBuilder)
mock_builder2 = mock.create_autospec(tfds.core.DatasetBuilder)
mock_builder3 = mock.create_autospec(tfds.core.DatasetBuilder)
mock_tfds_builder.side_effect = [
mock_builder1,
mock_builder2,
mock_builder3,
]

orig_data_dir = "/data"
override1 = "/override1"
override2 = "/override2"

utils.set_tfds_data_dir_override(override1)

# Should use `override1` that was set globally.
loader = utils.LazyTfdsLoader(name="a/b:1.0.0", data_dir=orig_data_dir)
self.assertEqual(override1, loader.data_dir)
self.assertEqual(loader.builder, mock_builder1)

loader.override_data_dir(override2)
self.assertEqual(override2, loader.data_dir)
self.assertEqual(loader.builder, mock_builder2)

# Set back to original data dir and check whether the cache works.
loader.override_data_dir(orig_data_dir)
self.assertEqual(orig_data_dir, loader.data_dir)
self.assertEqual(loader.builder, mock_builder3)
self.assertEqual(mock_tfds_builder.call_count, 3)

# Unset it to not influence other tests.
utils.set_tfds_data_dir_override(None)



class TransformUtilsTest(parameterized.TestCase):
Expand Down Expand Up @@ -375,7 +410,6 @@ def fn(ex, seed):
mapped_ds = fn(ds) # pylint: disable=no-value-for-parameter
results = [7, 5, 6, 6, 7, 11, 12, 16, 15, 15]
expected_ds = [{"field": results[i]} for i in range(10)]
print("gaurav", list(mapped_ds.as_numpy_iterator()))
self.assertListEqual(list(mapped_ds.as_numpy_iterator()), expected_ds)

def test_random_map_fn_with_kwargs(self):
Expand Down Expand Up @@ -432,7 +466,6 @@ def fn(ex, seeds, val, sequence_length):
mapped_ds = map_fn(ds, sequence_length={"field": -1}) # pylint: disable=no-value-for-parameter
results = [13, 15, 16, 13, 9, 16, 15, 26, 18, 16]
expected_ds = [{"field": results[i]} for i in range(10)]
print("gaurav", list(mapped_ds.as_numpy_iterator()))
self.assertListEqual(list(mapped_ds.as_numpy_iterator()), expected_ds)


Expand Down

0 comments on commit a4183b6

Please sign in to comment.