Skip to content

Commit

Permalink
adding key_data to check the CI tests
Browse files Browse the repository at this point in the history
  • Loading branch information
init-22 committed Nov 30, 2024
1 parent 7a0fee3 commit 7cdea16
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions submission_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def train_once(
) -> Tuple[spec.Timing, Dict[str, Any]]:
_reset_cuda_mem()
data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4)

data_rng = jax.random.key_data(data_rng)
# Workload setup.
logging.info('Initializing dataset.')
if hasattr(workload, '_eval_num_workers'):
Expand Down Expand Up @@ -336,7 +336,7 @@ def train_once(

step_rng = prng.fold_in(rng, global_step)
data_select_rng, update_rng, eval_rng = prng.split(step_rng, 3)

eval_rng = jax.random.key_data(eval_rng)
with profiler.profile('Data selection'):
batch = data_selection(workload,
input_queue,
Expand Down

0 comments on commit 7cdea16

Please sign in to comment.