From 8367b39af87e06292a2ed51ca99e454e13790660 Mon Sep 17 00:00:00 2001 From: KoyamaSohei Date: Tue, 13 Jun 2023 22:00:22 +0900 Subject: [PATCH] Fix typo (#935) --- examples/mnist/train_ray.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/mnist/train_ray.py b/examples/mnist/train_ray.py index 1ec664ff0..6bb8666de 100644 --- a/examples/mnist/train_ray.py +++ b/examples/mnist/train_ray.py @@ -118,7 +118,7 @@ def create_train_state(rng, config): apply_fn=cnn.apply, params=params, tx=tx) -def get_train_data_laoder(train_ds, state, batch_size): +def get_train_data_loader(train_ds, state, batch_size): images_np = train_ds['image'] labels_np = train_ds['label'] steps_per_epoch = len(images_np) // batch_size @@ -163,7 +163,7 @@ def train_and_evaluate(config: ml_collections.ConfigDict, rng = jax.random.PRNGKey(0) state = create_train_state(rng, config) - train_data_loader, steps_per_epoch = get_train_data_laoder( + train_data_loader, steps_per_epoch = get_train_data_loader( train_ds, state, config.batch_size) for epoch in range(1, config.num_epochs + 1):