diff --git a/vit_jax/input_pipeline.py b/vit_jax/input_pipeline.py
index 9161a61..b3079d8 100644
--- a/vit_jax/input_pipeline.py
+++ b/vit_jax/input_pipeline.py
@@ -243,7 +243,7 @@ def _shard(data):
 def prefetch(dataset, n_prefetch):
   """Prefetches data to device and converts to numpy array."""
   ds_iter = iter(dataset)
-  ds_iter = map(lambda x: jax.tree_map(lambda t: np.asarray(memoryview(t)), x),
+  ds_iter = map(lambda x: jax.tree.map(lambda t: np.asarray(memoryview(t)), x),
                 ds_iter)
   if n_prefetch:
     ds_iter = flax.jax_utils.prefetch_to_device(ds_iter, n_prefetch)
diff --git a/vit_jax/models_test.py b/vit_jax/models_test.py
index 40f0fcd..0e9bc64 100644
--- a/vit_jax/models_test.py
+++ b/vit_jax/models_test.py
@@ -79,7 +79,7 @@ def test_can_instantiate(self, name, size):
       self.assertEqual((2, 196, 1000), outputs.shape)
     else:
       self.assertEqual((2, 1000), outputs.shape)
-    param_count = sum(p.size for p in jax.tree_flatten(variables)[0])
+    param_count = sum(p.size for p in jax.tree.flatten(variables)[0])
     self.assertEqual(
         size, param_count,
         f'Expected {name} to have {size} params, found {param_count}.')
diff --git a/vit_jax/test_utils.py b/vit_jax/test_utils.py
index d1256cf..93e3836 100644
--- a/vit_jax/test_utils.py
+++ b/vit_jax/test_utils.py
@@ -52,7 +52,7 @@ def _tree_flatten_with_names(tree):
   Returns:
     A list of values with names: [(name, value), ...]
   """
-  vals, tree_def = jax.tree_flatten(tree)
+  vals, tree_def = jax.tree.flatten(tree)
 
   # "Fake" token tree that is use to track jax internal tree traversal and
   # adjust our custom tree traversal to be compatible with it.
diff --git a/vit_jax/train.py b/vit_jax/train.py
index 365adef..8827b51 100644
--- a/vit_jax/train.py
+++ b/vit_jax/train.py
@@ -60,7 +60,7 @@ def loss_fn(params, images, labels):
     l, g = utils.accumulate_gradient(
         jax.value_and_grad(loss_fn), params, batch['image'], batch['label'],
         accum_steps)
-    g = jax.tree_map(lambda x: jax.lax.pmean(x, axis_name='batch'), g)
+    g = jax.tree.map(lambda x: jax.lax.pmean(x, axis_name='batch'), g)
     updates, opt_state = tx.update(g, opt_state)
     params = optax.apply_updates(params, updates)
     l = jax.lax.pmean(l, axis_name='batch')
diff --git a/vit_jax/utils.py b/vit_jax/utils.py
index c880f0b..526c810 100644
--- a/vit_jax/utils.py
+++ b/vit_jax/utils.py
@@ -111,9 +111,9 @@ def acc_grad_and_loss(i, l_and_g):
                                    (step_size, labels.shape[1]))
       li, gi = loss_and_grad_fn(params, imgs, lbls)
       l, g = l_and_g
-      return (l + li, jax.tree_map(lambda x, y: x + y, g, gi))
+      return (l + li, jax.tree.map(lambda x, y: x + y, g, gi))
 
     l, g = jax.lax.fori_loop(1, accum_steps, acc_grad_and_loss, (l, g))
-    return jax.tree_map(lambda x: x / accum_steps, (l, g))
+    return jax.tree.map(lambda x: x / accum_steps, (l, g))
   else:
     return loss_and_grad_fn(params, images, labels)
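
For context, a minimal standalone sketch (not part of the patch) of why these hunks are mechanical: in newer JAX releases the `jax.tree.*` namespace functions are drop-in replacements for the deprecated top-level `jax.tree_map` / `jax.tree_flatten` aliases. The parameter tree below is a hypothetical example, not code from the repository.

import jax
import numpy as np

# Hypothetical parameter tree, used only to illustrate the API change.
params = {'kernel': np.ones((3, 4)), 'bias': np.zeros((4,))}

# jax.tree.map replaces the deprecated jax.tree_map: same signature,
# same tree-structured result.
doubled = jax.tree.map(lambda x: 2 * x, params)

# jax.tree.flatten replaces the deprecated jax.tree_flatten and returns the
# same (leaves, treedef) pair, so downstream code such as the parameter count
# in models_test.py needs no further changes.
leaves, treedef = jax.tree.flatten(params)
param_count = sum(p.size for p in leaves)
assert param_count == 3 * 4 + 4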