Make seqio caching work when some dataset elements are ragged tensors.
PiperOrigin-RevId: 556835998
rhofour authored and SeqIO committed Aug 17, 2023
1 parent 6594a3c commit 9a29d46
Showing 10 changed files with 361 additions and 19 deletions.
19 changes: 14 additions & 5 deletions seqio/beam_utils.py
@@ -142,7 +142,8 @@ def _emit_examples(self, shard: Tuple[int, str]):
     ds = task.preprocess_precache(ds, seed=shard_preprocessors_seed)
     ds = ds.prefetch(tf.data.AUTOTUNE)

-    def _add_provenance(index_within_shard: int, ex: Dict[str, Any]):
+    def _add_provenance(
+        index_within_shard: int, ex: Dict[str, Any]) -> Dict[str, Any]:
       ex.update({
           TASK_PROVENANCE_KEY: self._task_name,
           SOURCE_SHARD_PROVENANCE_KEY: shard_name,
@@ -153,7 +154,7 @@ def _add_provenance(index_within_shard: int, ex: Dict[str, Any]):
       ex.update({PREPROCESSORS_SEED_PROVENANCE_KEY: self._preprocessors_seed})
       return ex

-    for i, ex in enumerate(ds.as_numpy_iterator()):
+    for i, ex in enumerate(ds):
       if self._add_provenance:
         ex = _add_provenance(i, ex)
       self._increment_counter("examples")
@@ -255,7 +256,10 @@ def _info_dict(self, ex: List[Dict[str, Any]]):
     for k, v in ex.items():
       if self._exclude_provenance and k.startswith(PROVENANCE_PREFIX):
         continue
-      t = tf.constant(v)
+      if isinstance(v, tf.RaggedTensor):
+        t = v
+      else:
+        t = tf.constant(v)
       dtype = t.dtype.name
       shape = t.shape.as_list()
       # Keep all the dimensions but the first if t is not a scalar.
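For context on the branch above: tf.constant cannot build a dense tensor from a tf.RaggedTensor, so ragged values are passed through unchanged, and a ragged tensor's TensorShape already reports ragged dimensions as None. A minimal standalone sketch of the shape/dtype this yields for the ragged input used in the tests below (mirroring the "keep all the dimensions but the first" logic; not the pipeline code itself):

import tensorflow as tf

# Two ragged rows of [x, 2]-shaped data: row 0 has one pair, row 1 has two.
rt = tf.RaggedTensor.from_row_splits(
    tf.constant([[3, 1], [4, 1], [5, 9]]), row_splits=[0, 1, 3])

print(rt.dtype.name)       # int32
print(rt.shape.as_list())  # [2, None, 2] -- the ragged dim reports as None

# Dropping the leading (per-example) dimension, as _info_dict does for
# non-scalars, is how the cached info files arrive at "shape": [None, None, 2].
shape = rt.shape.as_list()
shape[0] = None
print(shape)               # [None, None, 2]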
@@ -291,10 +295,14 @@ def process(self, ex: Mapping[str, Any]) -> Iterable[Tuple[str, int]]:
     for name, feat in self._output_features.items():
       if (
           name in ex
-          and isinstance(ex[name], np.ndarray)
+          and (isinstance(ex[name], np.ndarray)
+               or isinstance(ex[name], tf.Tensor))
           and ex[name].dtype in (np.int32, np.int64)
       ):
-        values = ex[name]
+        if isinstance(ex[name], tf.Tensor):
+          values = ex[name].numpy()
+        else:
+          values = ex[name]
         conditions = []
         if feat.vocabulary.eos_id is not None:
           conditions.append((values != feat.vocabulary.eos_id))
@@ -414,6 +422,7 @@ def __init__(
     self._output_features = output_features
     self._task_ids = task_ids or {}
     self._enable_char_counts = enable_char_counts
+    logging.info("Getting stats for output features: %s", str(output_features))

   def expand(self, pcoll):
     example_counts = (
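The switch from ds.as_numpy_iterator() to iterating the dataset directly (first hunk above) is what lets ragged elements flow through this file: as_numpy_iterator() has no faithful NumPy representation for ragged values, while direct iteration yields tf.Tensor / tf.RaggedTensor objects that later stages convert with .numpy() only where that is safe. A rough standalone sketch of the difference, assuming a TF build that rejects ragged elements in as_numpy_iterator():

import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices({
    "dense": tf.constant([[1, 2], [3, 4]]),
    "ragged": tf.ragged.constant([[[1, 2]], [[3, 4], [5, 6]]], ragged_rank=1),
})

# Direct iteration keeps each value as tf.Tensor or tf.RaggedTensor.
for ex in ds:
  for k, v in ex.items():
    if isinstance(v, tf.Tensor):
      v = v.numpy()  # dense values can still become NumPy downstream
    print(k, type(v).__name__)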
13 changes: 12 additions & 1 deletion seqio/beam_utils_test.py
@@ -30,8 +30,14 @@ class BeamUtilsTest(seqio.test_utils.FakeTaskTest):

   def test_preprocess_task(self):
     def _np_to_list(ex):
+      def _convert_value(v):
+        if isinstance(v, tf.Tensor):
+          v = v.numpy()
+        if isinstance(v, np.ndarray):
+          v = v.tolist()
+        return v
       return {
-          k: v.tolist() if isinstance(v, np.ndarray) else v
+          k: _convert_value(v)
           for k, v in ex.items()
       }

@@ -125,6 +131,10 @@ def test_get_info(self):
         "inputs": "test",
         "2d_shape": np.ones((1, 3), np.int32),
         "3d_shape": np.ones((1, 2, 3), np.int32),
+        "ragged_shape": tf.RaggedTensor.from_row_splits(
+            tf.constant([[3, 1], [4, 1], [5, 9]]),
+            row_splits=[0, 1, 3],
+            validate=True),
     }]
     with TestPipeline() as p:
       pcoll = p | beam.Create(input_examples) | beam_utils.GetInfo(num_shards=3)
@@ -141,6 +151,7 @@ def test_get_info(self):
                 "shape": [None, 2, 3],
                 "dtype": "int32",
             },
+            "ragged_shape": {"shape": [None, None, 2], "dtype": "int32"},
         },
         "seqio_version": seqio.__version__,
     }]),
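The updated _np_to_list normalizes both eager tensors and NumPy arrays to plain lists before comparison, since the pipeline now emits tf.Tensor values. A quick standalone illustration of the conversion chain (assuming eager mode):

import numpy as np
import tensorflow as tf

def _convert_value(v):
  if isinstance(v, tf.Tensor):
    v = v.numpy()  # tf.Tensor -> np.ndarray (or bytes/scalar for rank 0)
  if isinstance(v, np.ndarray):
    v = v.tolist()  # np.ndarray -> plain Python list
  return v

print(_convert_value(tf.constant([1, 2, 3])))   # [1, 2, 3]
print(_convert_value(np.ones((2,), np.int32)))  # [1, 1]
print(_convert_value("already plain"))          # 'already plain'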
27 changes: 25 additions & 2 deletions seqio/dataset_providers.py
@@ -926,19 +926,42 @@ def __init__(
       feat["dtype"] = "int32"

     # Use `FixedLenSequenceFeature` for sequences with variable length.
-    def _feature_config(shape, dtype):
+    def _feature_config(
+        key: str,
+        shape,
+        dtype: str,
+    ) -> Union[tf.io.FixedLenFeature, tf.io.RaggedFeature]:
       if dtype in ("int32", "bool"):
         # int32 and bool are stored as int64 in the tf.train.Example protobuf.
         # TODO(adarob): Support other conversions.
         dtype = "int64"
+      if shape:
+        num_none_components = 0
+        for x in shape[1:]:
+          if x is None:
+            num_none_components += 1
+        if num_none_components > 0:  # Parse as a ragged feature.
+          partitions = []
+          ragged_idx = 0
+          for x in shape[1:]:
+            if x is None:
+              partitions.append(tf.io.RaggedFeature.RowLengths(
+                  utils.tfexample_ragged_length_key(key, ragged_idx)))
+              ragged_idx += 1
+            else:
+              partitions.append(tf.io.RaggedFeature.UniformRowLength(x))
+          return tf.io.RaggedFeature(
+              value_key=key,
+              partitions=partitions,
+              dtype=dtype)
       if shape and shape[0] is None:
         return tf.io.FixedLenSequenceFeature(
             shape[1:], dtype, allow_missing=True
         )
       return tf.io.FixedLenFeature(shape, dtype)

     feature_description = {
-        feat: _feature_config(**desc) for feat, desc in features.items()
+        feat: _feature_config(feat, **desc) for feat, desc in features.items()
     }

     def read_file_fn(filepattern):
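The _feature_config change above is the read side of ragged caching: each None in a feature's shape (beyond the leading dimension) becomes a tf.io.RaggedFeature.RowLengths partition keyed by utils.tfexample_ragged_length_key, and each fixed dimension becomes a UniformRowLength. Below is a rough round-trip sketch; the length-key string is a made-up stand-in for whatever tfexample_ragged_length_key actually returns, and the real cache-writing code differs:

import tensorflow as tf

# Hypothetical companion key; the real one comes from
# utils.tfexample_ragged_length_key(key, ragged_idx).
LENGTH_KEY = "ragged_feature/row_lengths_0"

# Write side: store flattened values plus the row lengths of the ragged dim.
rt = tf.RaggedTensor.from_row_lengths(
    tf.constant([[3, 1], [4, 1], [5, 9]], tf.int64), row_lengths=[1, 2])
example = tf.train.Example(features=tf.train.Features(feature={
    "ragged_feature": tf.train.Feature(int64_list=tf.train.Int64List(
        value=rt.flat_values.numpy().flatten().tolist())),
    LENGTH_KEY: tf.train.Feature(int64_list=tf.train.Int64List(
        value=rt.row_lengths().numpy().tolist())),
}))

# Read side, as _feature_config would build it for shape [None, None, 2]:
# a RowLengths partition for the ragged dim, then UniformRowLength(2) for
# the fixed inner dim.
parsed = tf.io.parse_single_example(
    example.SerializeToString(),
    {"ragged_feature": tf.io.RaggedFeature(
        dtype=tf.int64,
        value_key="ragged_feature",
        partitions=[
            tf.io.RaggedFeature.RowLengths(LENGTH_KEY),
            tf.io.RaggedFeature.UniformRowLength(2),
        ])})
print(parsed["ragged_feature"])  # ragged tensor of shape [2, None, 2]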
37 changes: 37 additions & 0 deletions seqio/scripts/cache_tasks_test.py
@@ -39,6 +39,7 @@ def validate_pipeline(
       expected_task_dir="cached_task",
       token_preprocessed=False,
       ndfeatures=False,
+      ragged_features=False,
       num_shards=2,
   ):
     self.assertTrue(TaskRegistry.get("cached_task").cache_dir)
@@ -107,6 +108,8 @@ def validate_pipeline(
         splits=task.splits,
         token_preprocessed=token_preprocessed,
         ndfeatures=ndfeatures,
+        ragged_features=ragged_features,
+        num_shards=num_shards,
     )

   def test_tfds_pipeline(self):
@@ -163,6 +166,40 @@ def test_cache_before_tokenization_ndfeatures_pipeline(self):
         ndfeatures=True,
     )

+  def test_cache_before_tokenization_ragged_features_pipeline(self):
+    self.add_task(
+        "task_tokenized_postcache_ragged_features",
+        seqio.dataset_providers.FunctionDataSource(
+            dataset_fn=functools.partial(
+                test_utils.get_fake_dataset, ragged_features=True
+            ),
+            splits=["train", "validation"],
+        ),
+        output_features={
+            "inputs": seqio.Feature(test_utils.sentencepiece_vocab()),
+            "targets": seqio.Feature(test_utils.sentencepiece_vocab()),
+            "ragged_feature": seqio.Feature(
+                seqio.PassThroughVocabulary(1000, eos_id=0),
+                add_eos=False,
+                rank=3,
+            ),
+        },
+        preprocessors=[
+            test_utils.test_text_preprocessor,
+            seqio.CacheDatasetPlaceholder(),
+            seqio.preprocessors.tokenize,
+            test_utils.token_preprocessor_no_sequence_length,
+            seqio.preprocessors.append_eos_after_trim,
+        ],
+    )
+    self.validate_pipeline(
+        "task_tokenized_postcache_ragged_features",
+        expected_task_dir="cached_untokenized_ragged_features_task",
+        num_shards=1,
+        token_preprocessed=True,
+        ragged_features=True,
+    )
+
   def test_cache_before_tokenization_pipeline(self):
     self.add_task(
         "task_tokenized_postcache",
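The new test declares the ragged output feature with rank=3 and a PassThroughVocabulary, matching the [None, None, 2] shape recorded in the cached info files below. A hypothetical stand-in for what test_utils.get_fake_dataset(ragged_features=True) might yield; the real fake data lives in test_utils and will differ:

import tensorflow as tf

def fake_ragged_dataset(split, shuffle_files=False, seed=None):
  # Each example carries a rank-2 ragged value ([None, 2]), so the
  # dataset-level feature is rank 3: [None, None, 2].
  del split, shuffle_files, seed  # unused in this sketch
  return tf.data.Dataset.from_tensor_slices({
      "inputs": tf.constant(["this", "is", "a test"]),
      "targets": tf.constant(["this", "is", "a test"]),
      "ragged_feature": tf.ragged.constant(
          [[[3, 1]], [[4, 1], [5, 9]], [[2, 7]]], ragged_rank=1),
  })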
Empty file.
@@ -0,0 +1,22 @@
+{
+  "features": {
+    "inputs": {
+      "dtype": "string",
+      "shape": []
+    },
+    "ragged_feature": {
+      "dtype": "int32",
+      "shape": [
+        null,
+        null,
+        2
+      ]
+    },
+    "targets": {
+      "dtype": "string",
+      "shape": []
+    }
+  },
+  "num_shards": 1,
+  "seqio_version": "0.0.0"
+}
@@ -0,0 +1,42 @@
+{
+  "features": {
+    "id": {
+      "dtype": "string",
+      "shape": []
+    },
+    "ids": {
+      "dtype": "string",
+      "shape": [
+        null
+      ]
+    },
+    "idx": {
+      "dtype": "int64",
+      "shape": []
+    },
+    "idxs": {
+      "dtype": "int32",
+      "shape": [
+        null
+      ]
+    },
+    "inputs": {
+      "dtype": "string",
+      "shape": []
+    },
+    "ragged_feature": {
+      "dtype": "int32",
+      "shape": [
+        null,
+        null,
+        2
+      ]
+    },
+    "targets": {
+      "dtype": "string",
+      "shape": []
+    }
+  },
+  "num_shards": 1,
+  "seqio_version": "0.0.0"
+}
@@ -0,0 +1,3 @@
+{
+  "examples": 3
+}
@@ -0,0 +1,3 @@
+{
+  "examples": 2
+}