diff --git a/submission_runner.py b/submission_runner.py index 551173bf5..0024c35d4 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -210,7 +210,7 @@ def train_once( ) -> Tuple[spec.Timing, Dict[str, Any]]: _reset_cuda_mem() data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4) - + data_rng = jax.random.key_data(data_rng) # Workload setup. logging.info('Initializing dataset.') if hasattr(workload, '_eval_num_workers'): @@ -336,7 +336,7 @@ def train_once( step_rng = prng.fold_in(rng, global_step) data_select_rng, update_rng, eval_rng = prng.split(step_rng, 3) - + eval_rng = jax.random.key_data(eval_rng) with profiler.profile('Data selection'): batch = data_selection(workload, input_queue,