Expose data via iterator #39
Hi @apaleyes, thanks for checking out APEBench. Initially, I designed APEBench for emulator training use cases in which everything fits into memory (for the paper, I used 24 GB GPUs). Hence, there is not yet an interface for iterating over the dataset without fully instantiating it. However, I am open to adding one if you have a cool idea 😊.

Generally speaking, the problem is that for mini-batching across time, one must produce the entire trajectory, since later points in time obviously depend on earlier ones. Still, since the reference simulator is sufficiently fast, I can imagine a setting in which minibatch entries are procedurally regenerated from their trajectory and time index. Below is a quick hack for such a dynamic generator (I haven't had the chance to test it much yet, let me know what you think):

```python
import apebench
import equinox as eqx
import jax
import jax.numpy as jnp


class DynamicDataGenerator(eqx.Module):
    scenario: apebench.BaseScenario
    mixer: apebench.trainax.PermutationMixer
    keys: jax.Array

    def __init__(self, scenario, shuffle_key):
        self.scenario = scenario
        self.mixer = apebench.trainax.PermutationMixer(
            # Valid for one-step supervised training: one sample per
            # (trajectory, time) pair
            num_total_samples=scenario.train_temporal_horizon * scenario.num_train_samples,
            num_minibatches=scenario.num_training_steps,
            batch_size=scenario.batch_size,
            shuffle_key=shuffle_key,
        )
        # Splitting into the relevant keys directly would have also worked,
        # but this correctly replicates the behavior of
        # apebench.exponax.build_ic_set
        _, self.keys = jax.lax.scan(
            lambda key, _: (*jax.random.split(key),),
            jax.random.key(scenario.train_seed),
            None,
            length=scenario.num_train_samples,
        )
def generate_sample(self, trajectory_index, time_index):
ic = self.scenario.get_ic_generator()(self.scenario.num_points, key=self.keys[trajectory_index])
ref_stepper = self.scenario.get_ref_stepper()
sample_state_start = apebench.exponax.repeat(ref_stepper, time_index)(ic)
# Again assuming one-step supervised training, needs a window of two
# consecutive snapshots
sample_window = apebench.exponax.rollout(ref_stepper, 1, include_init=True)(sample_state_start)
return sample_window
    def __call__(self, batch_index):
        sample_indices = self.mixer(batch_index)
        trajectory_indices = sample_indices // self.scenario.train_temporal_horizon
        time_indices = sample_indices % self.scenario.train_temporal_horizon
        # Computing this in parallel would be optimal. Unfortunately, vmap
        # will not work because the `time_indices` affect the scan length.
        # Maybe there is a workaround.
        # samples = jax.vmap(self.generate_sample)(trajectory_indices, time_indices)
        # A quite slow alternative is a Python loop; the indices are converted
        # to concrete ints because they set the scan length
        samples = jnp.stack(
            [
                self.generate_sample(int(trajectory_index), int(time_index))
                for trajectory_index, time_index in zip(
                    trajectory_indices, time_indices
                )
            ]
        )
        return samples
```
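On the "maybe there is a workaround" note in `__call__`: one direction (an untested sketch, not a verified drop-in) would be to advance time with `jax.lax.fori_loop`, whose trip count may be a traced value, so that the per-sample function becomes vmappable. The price is that the loop lowers to a while loop, which is not reverse-mode differentiable, but data generation does not need gradients. The method below would live inside `DynamicDataGenerator`:

```python
    def generate_sample_vmappable(self, trajectory_index, time_index):
        # Hypothetical variant of `generate_sample`: `fori_loop` accepts a
        # traced upper bound, so `jax.vmap` can batch over `time_index`
        ic = self.scenario.get_ic_generator()(
            self.scenario.num_points, key=self.keys[trajectory_index]
        )
        ref_stepper = self.scenario.get_ref_stepper()
        # Apply the reference stepper `time_index` times
        state = jax.lax.fori_loop(0, time_index, lambda _, u: ref_stepper(u), ic)
        return apebench.exponax.rollout(ref_stepper, 1, include_init=True)(state)
```

With that, `__call__` could replace the Python loop by `samples = jax.vmap(self.generate_sample_vmappable)(trajectory_indices, time_indices)`. I haven't verified this against all steppers, so treat it as a direction rather than a finished solution.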
You can use it like:

```python
advection_scenario = apebench.scenarios.difficulty.Advection()
dynamic_generator = DynamicDataGenerator(advection_scenario, jax.random.PRNGKey(0))

for i in range(advection_scenario.num_training_steps):
    batch = dynamic_generator(i)
    # ... forward, backward, parameter update, etc.
```

Alternatively, you can generate and hold the data in CPU memory and stream it to the GPU/TPU with a custom dataloader (using …
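A minimal sketch of such a host-side loader, under the assumption that the data was generated upfront into a NumPy array `train_data` of shape `(num_samples, ...)`; the function and its arguments are illustrative, not an existing APEBench API:

```python
import numpy as np
import jax


def host_dataloader(train_data: np.ndarray, batch_size: int, seed: int = 0):
    """Yield shuffled minibatches from host memory, one batch on device at a time."""
    rng = np.random.default_rng(seed)
    num_samples = train_data.shape[0]
    permutation = rng.permutation(num_samples)
    for start in range(0, num_samples, batch_size):
        batch_indices = permutation[start : start + batch_size]
        # Only this minibatch is transferred to the GPU/TPU
        yield jax.device_put(train_data[batch_indices])
```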
Hi team, thanks for sharing a great library!

I am looking into using APEBench to benchmark some models implemented with flax.linen. The tutorial here suggests generating the train/test data upfront. I suspect this might be problematic memory-wise, e.g. 7.8 GB in memory already! And this can easily grow if one would like to consider more time steps, more points in the domain, or more train/test samples.

Have you considered an iterator interface to access the data, so that external models only see one batch of data samples at a time? Or maybe this interface already exists and I just couldn't locate it?
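To make the request concrete, here is a purely hypothetical sketch of what I have in mind (`iter_train_batches` does not exist in APEBench; the name and signature are made up):

```python
# Hypothetical usage, not an existing APEBench API
for batch in scenario.iter_train_batches(batch_size=32):
    ...  # feed `batch` to an external (e.g., flax.linen) model
```

This way, only one batch would need to be materialized in accelerator memory at a time.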