Expose data via iterator #39

Open
apaleyes opened this issue Dec 20, 2024 · 1 comment

@apaleyes

Hi team, thanks for sharing a great library!

I am looking into using APEBench to benchmark some models implemented with flax.linen. The tutorial here suggests generating the train/test data upfront, which I suspect might be problematic memory-wise, e.g.:

import apebench

advection_scenario = apebench.scenarios.physical.Advection(
    num_train_samples=10,
    train_temporal_horizon=50,
    num_spatial_dims=3
)

train_data = advection_scenario.get_train_data()
train_data.nbytes / (1024**3)

> 7.781982421875

7.8 GB in memory already! And this can easily grow if one considers more time steps, more points in the domain, or more train/test samples.

Have you considered an iterator interface for accessing the data, so that external models only see one batch of samples at a time? Or maybe such an interface already exists and I just couldn't locate it?

@Ceyron
Collaborator

Ceyron commented Dec 21, 2024

Hi @apaleyes,

Thanks for checking out APEBench.

Initially, I designed APEBench for emulator training use cases in which everything fits into memory (for the paper, I used 24 GB GPUs). Hence, there is not yet an interface for iterating over the dataset without fully instantiating it. However, I am open to adding one if you have a cool idea 😊. Generally speaking, the problem is that for minibatching across time, one must produce the entire trajectory, since later points in time obviously depend on earlier ones. Still, since the reference simulator is sufficiently fast, I can imagine a setting in which minibatch entries are procedurally regenerated based on trajectory and time index.

Below is a quick hack for such a dynamic generator (I haven't had the chance to test it much yet, let me know what you think):

import apebench
import equinox as eqx
import jax
import jax.numpy as jnp


class DynamicDataGenerator(eqx.Module):
    scenario: apebench.BaseScenario
    mixer: apebench.trainax.PermutationMixer
    keys: jax.Array

    def __init__(self, scenario, shuffle_key):
        self.scenario = scenario
        self.mixer = apebench.trainax.PermutationMixer(
            # Valid for one-step supervised training
            num_total_samples=scenario.train_temporal_horizon * scenario.num_train_samples,
            num_minibatches=scenario.num_training_steps,
            batch_size=scenario.batch_size,
            shuffle_key=shuffle_key,
        )
        # splitting into the relevant keys would have also worked, but this is
        # to correctly replicate the behavior of apebench.exponax.build_ic_set
        _, self.keys = jax.lax.scan(
            lambda key, _: (*jax.random.split(key),),
            jax.random.key(scenario.train_seed),
            None,
            length=scenario.num_train_samples,
        )

    def generate_sample(self, trajectory_index, time_index):
        ic = self.scenario.get_ic_generator()(self.scenario.num_points, key=self.keys[trajectory_index])
        ref_stepper = self.scenario.get_ref_stepper()
        sample_state_start = apebench.exponax.repeat(ref_stepper, time_index)(ic)
        # Again assuming one-step supervised training, needs a window of two
        # consecutive snapshots
        sample_window = apebench.exponax.rollout(ref_stepper, 1, include_init=True)(sample_state_start)
        return sample_window

    def __call__(self, batch_index):
        sample_indices = self.mixer(batch_index)
        trajectory_indices = sample_indices // self.scenario.train_temporal_horizon
        time_indices = sample_indices % self.scenario.train_temporal_horizon

        # Below would be the optimal way of computing this in parallel.
        # Unfortunately, vmap will not work because the `time_indices` affect
        # the scan length; one possible workaround is sketched after the usage
        # example below.

        # samples = jax.vmap(self.generate_sample)(trajectory_indices, time_indices)

        # A (quite slow) alternative is a Python loop
        samples = jnp.stack(
            [
                self.generate_sample(trajectory_index, time_index)
                for trajectory_index, time_index in zip(trajectory_indices, time_indices)
            ]
        )

        return samples

You can use it like this:

advection_scenario = apebench.scenarios.difficulty.Advection()
dynamic_generator = DynamicDataGenerator(advection_scenario, jax.random.PRNGKey(0))

for i in range(advection_scenario.num_training_steps):
    batch = dynamic_generator(i)
    # ... forward, backward, parameter update etc.
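
Regarding the vmap limitation noted in the comments of __call__: one possible workaround (an untested sketch, not something APEBench currently provides) is to roll out the full training horizon for the selected trajectory and slice out the two-snapshot window with jax.lax.dynamic_slice_in_dim. The rollout length is then static, so jax.vmap over the batch should work, at the cost of simulating the full trajectory for every sample. A drop-in replacement for generate_sample could look like:

    def generate_sample(self, trajectory_index, time_index):
        # Untested sketch: simulate the full training horizon once and pick the
        # two-snapshot window at `time_index` via a dynamic slice. The rollout
        # length is static, so this version can be wrapped in jax.vmap.
        ic = self.scenario.get_ic_generator()(
            self.scenario.num_points, key=self.keys[trajectory_index]
        )
        ref_stepper = self.scenario.get_ref_stepper()
        full_trajectory = apebench.exponax.rollout(
            ref_stepper, self.scenario.train_temporal_horizon, include_init=True
        )(ic)
        return jax.lax.dynamic_slice_in_dim(full_trajectory, time_index, 2, axis=0)

With that version, the commented-out jax.vmap(self.generate_sample)(...) line in __call__ should replace the Python loop.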

Alternatively, you can generate and hold the data in CPU memory and stream it to the GPU/TPU with a custom dataloader (using jax.device_put).
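
Here is a minimal sketch of that route (illustrative only, not an APEBench API; it assumes the full dataset fits in host RAM and, for simplicity, batches only over the leading sample axis):

import jax
import numpy as np

cpu = jax.devices("cpu")[0]
accel = jax.devices()[0]  # first GPU/TPU if present, otherwise the CPU

# Generate the data once, keeping the full array in host memory
with jax.default_device(cpu):
    train_data = advection_scenario.get_train_data()

def stream_batches(data, batch_size, key):
    # Illustrative helper (not part of APEBench): shuffle sample indices on the
    # host and transfer one minibatch at a time to the accelerator
    perm = np.asarray(jax.random.permutation(key, data.shape[0]))
    for start in range(0, data.shape[0], batch_size):
        idx = perm[start:start + batch_size]
        yield jax.device_put(data[idx], accel)

for batch in stream_batches(train_data, batch_size=32, key=jax.random.PRNGKey(0)):
    ...  # forward, backward, parameter update etc.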
