From 7cdea1638ceb2a3c0019e95c0a63f0c36605064a Mon Sep 17 00:00:00 2001 From: init-22 Date: Sat, 30 Nov 2024 23:07:52 +0530 Subject: [PATCH] adding key_data to check the CI tests --- submission_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 551173bf5..0024c35d4 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -210,7 +210,7 @@ def train_once( ) -> Tuple[spec.Timing, Dict[str, Any]]: _reset_cuda_mem() data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4) - + data_rng = jax.random.key_data(data_rng) # Workload setup. logging.info('Initializing dataset.') if hasattr(workload, '_eval_num_workers'): @@ -336,7 +336,7 @@ def train_once( step_rng = prng.fold_in(rng, global_step) data_select_rng, update_rng, eval_rng = prng.split(step_rng, 3) - + eval_rng = jax.random.key_data(eval_rng) with profiler.profile('Data selection'): batch = data_selection(workload, input_queue,