
Nested Dataset Slowdown #593

Open
vyeevani opened this issue Oct 3, 2024 · 1 comment


vyeevani commented Oct 3, 2024

TLDR: I'm working with robotics datasets, which are expressed as nested datasets. Grain calls `repr` on the inner dataset, which slows down the loops significantly.

I think grain has serious benefits over tfds for nested datasets like this. Currently, rlds is the best way of handling these types of datasets, but it somewhat abuses tfds dataset-manipulation techniques because there isn't a better way of handling them. I think grain could simplify a lot of these workflows and allow much more complex pipelines to be built, e.g., stitching multiple robot episodes together in unique ways that aren't easily expressible in tfds or rlds.
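To make the stitching idea a bit more concrete, here is a minimal sketch of one simple form of it: flattening a random-access collection of episodes into a single step-level source, so that an index sampler could shuffle steps across episode boundaries. This is not part of the Grain or RLDS API; `FlatStepsSource` is a hypothetical name, and the sketch assumes each episode's steps support `len()` and integer indexing.

```python
import bisect
from typing import Any, Sequence


class FlatStepsSource:
    """Hypothetical step-level view over a sequence of episodes.

    Each episode is assumed to be a sequence of steps (supports len() and
    integer indexing). A global index is mapped to (episode, step) using the
    cumulative episode lengths.
    """

    def __init__(self, episodes: Sequence[Sequence[Any]]):
        self._episodes = episodes
        # _offsets[k] is the global index of the first step of episode k.
        self._offsets = []
        total = 0
        for episode in episodes:
            self._offsets.append(total)
            total += len(episode)
        self._len = total

    def __len__(self) -> int:
        return self._len

    def __getitem__(self, index: int) -> Any:
        if not 0 <= index < self._len:
            raise IndexError(index)
        # Rightmost episode whose first global index is <= index
        # (this also skips over empty episodes).
        episode_idx = bisect.bisect_right(self._offsets, index) - 1
        return self._episodes[episode_idx][index - self._offsets[episode_idx]]

    def __repr__(self) -> str:
        # Keep repr cheap: never format the underlying episodes.
        return f"FlatStepsSource(num_episodes={len(self._episodes)}, num_steps={self._len})"


source = FlatStepsSource([["a0", "a1"], ["b0"], ["c0", "c1", "c2"]])
print(len(source), source[2], source[5])  # 6 b0 c2
```

Because such an object only needs `__len__` and `__getitem__`, it could in principle be passed as `data_source=` to a single `grain.DataLoader` with an `IndexSampler` whose `num_records` is `len(source)`, instead of building one inner loader per episode.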

Fixing this would be hugely beneficial, and the `repr` call itself doesn't seem strictly necessary.

Simple example to showcase the problem:

```python
import grain.python as grain
import pyinstrument
import rlds
import tensorflow_datasets as tfds

# `dataset_path` is the directory of an RLDS dataset in TFDS format (not shown here).
builder = tfds.builder_from_directory(builder_dir=dataset_path)
episode_data_source = builder.as_data_source(
    "train",
    deserialize_method=tfds.core.decode.DeserializeMethod.DESERIALIZE_AND_DECODE,
)

# Sampler over episodes (only 2 records, to keep the repro small).
episode_index_sampler = grain.IndexSampler(
    num_records=2,
    num_epochs=1,
    shard_options=grain.ShardOptions(shard_index=0, shard_count=1, drop_remainder=True),
    shuffle=True,
    seed=0,
)

# Sampler over the steps of each episode (reused for every inner loader).
steps_index_sampler = grain.IndexSampler(
    num_records=2,
    num_epochs=1,
    shard_options=grain.ShardOptions(shard_index=0, shard_count=1, drop_remainder=True),
    shuffle=True,
    seed=0,
)

profiler = pyinstrument.Profiler()
profiler.start()

episode_data_loader = grain.DataLoader(data_source=episode_data_source, sampler=episode_index_sampler)
for episode_data in episode_data_loader:
    # Each episode's steps are used as the data source of a new inner loader.
    steps_data_source = episode_data[rlds.STEPS]
    steps_data_loader = grain.DataLoader(data_source=steps_data_source, sampler=steps_index_sampler)
    for steps_data in steps_data_loader:
        pass

profiler.stop()

# Save the profile as an interactive HTML report
with open("flamegraph.html", "w") as f:
    f.write(profiler.output_html())
```

The specific slowdowns happen when the loader creates and validates its iterator state.
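The cost pattern can be illustrated with a simplified, self-contained model. It assumes (from the description above, not from Grain's source) that a loader stores `repr(data_source)` in its iterator state and computes it again when validating that state, so a source with an expensive `repr` pays that cost for every inner loader. `SlowReprSource`, `CheapReprSource`, `make_state`, `validate_state` and `time_per_loader` are all hypothetical names. Wrapping the inner source so that `__repr__` returns a short string is one possible workaround on the user side.

```python
import time
from typing import Any, Sequence


class SlowReprSource:
    """Stand-in for an inner (steps) data source whose repr is expensive."""

    def __init__(self, records: Sequence[Any], repr_delay_s: float = 0.05):
        self._records = records
        self._repr_delay_s = repr_delay_s

    def __len__(self) -> int:
        return len(self._records)

    def __getitem__(self, index: int) -> Any:
        return self._records[index]

    def __repr__(self) -> str:
        time.sleep(self._repr_delay_s)  # simulate formatting a large nested structure
        return f"SlowReprSource({self._records!r})"


class CheapReprSource:
    """Hypothetical wrapper: delegates reads, but returns a short repr."""

    def __init__(self, source: Any, name: str):
        self._source = source
        self._name = name

    def __len__(self) -> int:
        return len(self._source)

    def __getitem__(self, index: int) -> Any:
        return self._source[index]

    def __repr__(self) -> str:
        return f"CheapReprSource(name={self._name!r}, len={len(self._source)})"


def make_state(data_source: Any) -> dict:
    # Simplified model: the state records which source it was created for.
    return {"data_source": repr(data_source), "last_seen_index": -1}


def validate_state(state: dict, data_source: Any) -> None:
    # Simplified model: validation recomputes repr and compares it.
    if state["data_source"] != repr(data_source):
        raise ValueError("state was created for a different data source")


def time_per_loader(data_source: Any) -> float:
    start = time.perf_counter()
    validate_state(make_state(data_source), data_source)
    return time.perf_counter() - start


slow = SlowReprSource(list(range(1000)))
print(f"slow repr:  {time_per_loader(slow):.3f}s per inner loader")
print(f"cheap repr: {time_per_loader(CheapReprSource(slow, 'episode_0')):.6f}s per inner loader")
```

The trade-off is that a short `repr` weakens the check that a saved state belongs to the same source, so the `name` should be unique per episode (for example, the episode index).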

vyeevani commented Oct 3, 2024

[Screenshot from 2024-10-03, 9:16:48 AM; image not included]
