From d21d8205d565c94d82b312709491deac0b31de31 Mon Sep 17 00:00:00 2001 From: init-22 Date: Sun, 22 Dec 2024 16:11:59 +0530 Subject: [PATCH] fix: changes jax.tree_map to jax.tree.map --- algorithmic_efficiency/data_utils.py | 2 +- algorithmic_efficiency/param_utils.py | 2 +- .../workloads/cifar/cifar_jax/workload.py | 2 +- .../imagenet_resnet/imagenet_jax/workload.py | 2 +- .../workloads/mnist/mnist_jax/workload.py | 2 +- .../workloads/ogbg/input_pipeline.py | 4 +-- .../workloads/ogbg/ogbg_pytorch/workload.py | 6 ++-- .../workloads/wmt/wmt_jax/decode.py | 8 ++--- .../workloads/wmt/wmt_jax/workload.py | 4 +-- .../workloads/wmt/wmt_pytorch/decode.py | 8 ++--- .../workloads/wmt/wmt_pytorch/workload.py | 2 +- .../external_tuning/jax_nadamw_full_budget.py | 16 +++++----- .../jax_nadamw_target_setting.py | 16 +++++----- .../self_tuning/jax_nadamw_full_budget.py | 16 +++++----- .../self_tuning/jax_nadamw_target_setting.py | 16 +++++----- .../cifar/cifar_jax/submission.py | 2 +- .../mnist/mnist_jax/submission.py | 2 +- .../adafactor/jax/sharded_adafactor.py | 16 +++++----- .../adafactor/jax/submission.py | 6 ++-- .../paper_baselines/adamw/jax/submission.py | 6 ++-- .../paper_baselines/lamb/jax/submission.py | 6 ++-- .../momentum/jax/submission.py | 6 ++-- .../paper_baselines/nadamw/jax/submission.py | 16 +++++----- .../nesterov/jax/submission.py | 6 ++-- .../paper_baselines/sam/jax/submission.py | 10 +++---- .../shampoo/jax/distributed_shampoo.py | 30 +++++++++---------- .../paper_baselines/shampoo/jax/submission.py | 6 ++-- .../target_setting_algorithms/jax_adamw.py | 2 +- .../target_setting_algorithms/jax_momentum.py | 2 +- .../target_setting_algorithms/jax_nadamw.py | 12 ++++---- .../target_setting_algorithms/jax_nesterov.py | 2 +- .../jax_submission_base.py | 4 +-- tests/modeldiffs/vanilla_sgd_jax.py | 2 +- tests/reference_algorithm_tests.py | 4 +-- .../imagenet_jax/workload_test.py | 2 +- 35 files changed, 124 insertions(+), 124 deletions(-) diff --git a/algorithmic_efficiency/data_utils.py b/algorithmic_efficiency/data_utils.py index 901f0b582..38a76381f 100644 --- a/algorithmic_efficiency/data_utils.py +++ b/algorithmic_efficiency/data_utils.py @@ -65,7 +65,7 @@ def _prepare(x): # Assumes that `global_batch_size % local_device_count == 0`. return x.reshape((local_device_count, -1, *x.shape[1:])) - return jax.tree_map(_prepare, batch) + return jax.tree.map(_prepare, batch) def pad(tensor: np.ndarray, diff --git a/algorithmic_efficiency/param_utils.py b/algorithmic_efficiency/param_utils.py index b430366b1..916eb8728 100644 --- a/algorithmic_efficiency/param_utils.py +++ b/algorithmic_efficiency/param_utils.py @@ -66,7 +66,7 @@ def pytorch_param_types( def jax_param_shapes( params: spec.ParameterContainer) -> spec.ParameterShapeTree: - return jax.tree_map(lambda x: spec.ShapeTuple(x.shape), params) + return jax.tree.map(lambda x: spec.ShapeTuple(x.shape), params) def jax_param_types(param_shapes: spec.ParameterShapeTree, diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py index 6bbf9c64b..60f15c2f0 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py @@ -207,4 +207,4 @@ def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any]) -> Dict[str, float]: """Normalize eval metrics.""" - return jax.tree_map(lambda x: float(x[0] / num_examples), total_metrics) + return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py index 91cdec60a..4366fcf25 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -264,7 +264,7 @@ def _eval_model_on_split(self, eval_metrics[metric_name] = 0.0 eval_metrics[metric_name] += metric_value - eval_metrics = jax.tree_map(lambda x: float(x[0] / num_examples), + eval_metrics = jax.tree.map(lambda x: float(x[0] / num_examples), eval_metrics) return eval_metrics diff --git a/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py b/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py index efbd73e33..dcb0b6f36 100644 --- a/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py +++ b/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py @@ -132,4 +132,4 @@ def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any]) -> Dict[str, float]: """Normalize eval metrics.""" - return jax.tree_map(lambda x: float(x[0] / num_examples), total_metrics) + return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics) diff --git a/algorithmic_efficiency/workloads/ogbg/input_pipeline.py b/algorithmic_efficiency/workloads/ogbg/input_pipeline.py index a301d677a..3cb6f51de 100644 --- a/algorithmic_efficiency/workloads/ogbg/input_pipeline.py +++ b/algorithmic_efficiency/workloads/ogbg/input_pipeline.py @@ -51,7 +51,7 @@ def _load_dataset(split, should_shuffle, data_rng, data_dir): def _to_jraph(example): """Converts an example graph to jraph.GraphsTuple.""" - example = jax.tree_map(lambda x: x._numpy(), example) # pylint: disable=protected-access + example = jax.tree.map(lambda x: x._numpy(), example) # pylint: disable=protected-access edge_feat = example['edge_feat'] node_feat = example['node_feat'] edge_index = example['edge_index'] @@ -150,7 +150,7 @@ def _get_batch_iterator(dataset_iter, global_batch_size, num_shards=None): if count == num_shards: def f(x): - return jax.tree_map(lambda *vals: np.stack(vals, axis=0), x[0], *x[1:]) + return jax.tree.map(lambda *vals: np.stack(vals, axis=0), x[0], *x[1:]) graphs_shards = f(graphs_shards) labels_shards = f(labels_shards) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py index d4817226d..e66a7a151 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py @@ -20,8 +20,8 @@ def _pytorch_map(inputs: Any) -> Any: if USE_PYTORCH_DDP: - return jax.tree_map(lambda a: torch.as_tensor(a, device=DEVICE), inputs) - return jax.tree_map( + return jax.tree.map(lambda a: torch.as_tensor(a, device=DEVICE), inputs) + return jax.tree.map( lambda a: torch.as_tensor(a, device=DEVICE).view(-1, a.shape[-1]) if len(a.shape) == 3 else torch.as_tensor(a, device=DEVICE).view(-1), inputs) @@ -30,7 +30,7 @@ def _pytorch_map(inputs: Any) -> Any: def _shard(inputs: Any) -> Any: if not USE_PYTORCH_DDP: return inputs - return jax.tree_map(lambda tensor: tensor[RANK], inputs) + return jax.tree.map(lambda tensor: tensor[RANK], inputs) def _graph_map(function: Callable, graph: GraphsTuple) -> GraphsTuple: diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/decode.py b/algorithmic_efficiency/workloads/wmt/wmt_jax/decode.py index 85d0eaac4..dfead5918 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_jax/decode.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/decode.py @@ -86,7 +86,7 @@ def gather_fn(x): return x return x[batch_indices, beam_indices] - return jax.tree_map(gather_fn, nested) + return jax.tree.map(gather_fn, nested) def gather_topk_beams(nested, score_or_log_prob, batch_size, new_beam_size): @@ -139,7 +139,7 @@ def beam_init(batch_size, beam_size, max_decode_len, cache): finished_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32) finished_flags0 = jnp.zeros((batch_size, beam_size), jnp.bool_) # add beam dimension to attention cache pytree elements - beam_cache0 = jax.tree_map(lambda x: add_beam_dim(x, beam_size), cache) + beam_cache0 = jax.tree.map(lambda x: add_beam_dim(x, beam_size), cache) return BeamState( cur_index=cur_index0, live_logprobs=live_logprobs0, @@ -225,7 +225,7 @@ def beam_search_loop_body_fn(state): (batch_size, beam_size, 1))) # Flatten beam dimension into batch to be compatible with model. # {[batch, beam, ...], ...} --> {[batch * beam, ...], ...} - flat_cache = jax.tree_map(flatten_beam_dim, state.cache) + flat_cache = jax.tree.map(flatten_beam_dim, state.cache) # Call fast-decoder model on current tokens to get next-position logits. # --> [batch * beam, vocab] @@ -236,7 +236,7 @@ def beam_search_loop_body_fn(state): logits = unflatten_beam_dim(flat_logits, batch_size, beam_size) # Unflatten beam dimension in attention cache arrays # {[batch * beam, ...], ...} --> {[batch, beam, ...], ...} - new_cache = jax.tree_map( + new_cache = jax.tree.map( lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache) # Gather log probabilities from logits diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py index 046d5e469..dd6728450 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py @@ -94,7 +94,7 @@ def eval_step(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]: replicated_eval_metrics = self.eval_step_pmapped(params, batch) - return jax.tree_map(lambda x: jnp.sum(x, axis=0), replicated_eval_metrics) + return jax.tree.map(lambda x: jnp.sum(x, axis=0), replicated_eval_metrics) @functools.partial( jax.pmap, axis_name='batch', static_broadcasted_argnums=(0,)) @@ -291,7 +291,7 @@ def _normalize_eval_metrics( """Normalize eval metrics.""" del num_examples eval_denominator = total_metrics.pop('denominator') - return jax.tree_map(lambda x: float(x / eval_denominator), total_metrics) + return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics) class WmtWorkloadPostLN(WmtWorkload): diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/decode.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/decode.py index 0488a144f..078560c36 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/decode.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/decode.py @@ -98,7 +98,7 @@ def gather_fn(x): return x return x[batch_indices, beam_indices] - return jax.tree_map(gather_fn, nested) + return jax.tree.map(gather_fn, nested) def gather_topk_beams(nested: Dict[str, Any], @@ -164,7 +164,7 @@ def beam_init(batch_size: int, dtype=torch.bool, device=DEVICE) # add beam dimension to attention cache pytree elements - beam_cache0 = jax.tree_map(lambda x: add_beam_dim(x, beam_size), cache) + beam_cache0 = jax.tree.map(lambda x: add_beam_dim(x, beam_size), cache) return BeamState( cur_index=cur_index0, live_logprobs=live_logprobs0, @@ -251,7 +251,7 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState: state.live_seqs[:batch_size, :beam_size, cur_index:cur_index + 1]) # Flatten beam dimension into batch to be compatible with model. # {[batch, beam, ...], ...} --> {[batch * beam, ...], ...} - flat_cache = jax.tree_map(flatten_beam_dim, state.cache) + flat_cache = jax.tree.map(flatten_beam_dim, state.cache) # Call fast-decoder model on current tokens to get next-position logits. # --> [batch * beam, vocab] @@ -262,7 +262,7 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState: logits = unflatten_beam_dim(flat_logits, batch_size, beam_size) # Unflatten beam dimension in attention cache arrays # {[batch * beam, ...], ...} --> {[batch, beam, ...], ...} - new_cache = jax.tree_map( + new_cache = jax.tree.map( lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache) # Gather log probabilities from logits diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py index 0ba49c2f6..9c1c21e93 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py @@ -347,7 +347,7 @@ def _normalize_eval_metrics( dist.all_reduce(metric) total_metrics = {k: v.item() for k, v in total_metrics.items()} eval_denominator = total_metrics.pop('denominator') - return jax.tree_map(lambda x: float(x / eval_denominator), total_metrics) + return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics) class WmtWorkloadPostLN(WmtWorkload): diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py index 36e7e5607..30f9068d1 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py @@ -111,8 +111,8 @@ def scale_by_nadam(b1: float = 0.9, raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) def init_fn(params): - mu = jax.tree_map(jnp.zeros_like, params) # First moment - nu = jax.tree_map(jnp.zeros_like, params) # Second moment + mu = jax.tree.map(jnp.zeros_like, params) # First moment + nu = jax.tree.map(jnp.zeros_like, params) # Second moment return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) def update_fn(updates, state, params=None): @@ -123,7 +123,7 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( + updates = jax.tree.map( lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) @@ -139,14 +139,14 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( + return jax.tree.map( lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): """Perform bias correction. This becomes a no-op as count goes to infinity.""" beta = 1 - decay**count - return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + return jax.tree.map(lambda t: t / beta.astype(t.dtype), moment) def scale_by_learning_rate(learning_rate, flip_sign=True): @@ -188,7 +188,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): b2=hyperparameters.beta2, eps=1e-8, weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -236,7 +236,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -244,7 +244,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py index 07281f540..71b1c5e1e 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py @@ -111,8 +111,8 @@ def scale_by_nadam(b1: float = 0.9, raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) def init_fn(params): - mu = jax.tree_map(jnp.zeros_like, params) # First moment - nu = jax.tree_map(jnp.zeros_like, params) # Second moment + mu = jax.tree.map(jnp.zeros_like, params) # First moment + nu = jax.tree.map(jnp.zeros_like, params) # Second moment return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) def update_fn(updates, state, params=None): @@ -123,7 +123,7 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( + updates = jax.tree.map( lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) @@ -139,14 +139,14 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( + return jax.tree.map( lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): """Perform bias correction. This becomes a no-op as count goes to infinity.""" beta = 1 - decay**count - return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + return jax.tree.map(lambda t: t / beta.astype(t.dtype), moment) def scale_by_learning_rate(learning_rate, flip_sign=True): @@ -188,7 +188,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): b2=hyperparameters.beta2, eps=1e-8, weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -236,7 +236,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -244,7 +244,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py index 0d194ef7a..127e660d0 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py @@ -120,8 +120,8 @@ def scale_by_nadam(b1: float = 0.9, raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) def init_fn(params): - mu = jax.tree_map(jnp.zeros_like, params) # First moment - nu = jax.tree_map(jnp.zeros_like, params) # Second moment + mu = jax.tree.map(jnp.zeros_like, params) # First moment + nu = jax.tree.map(jnp.zeros_like, params) # Second moment return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) def update_fn(updates, state, params=None): @@ -132,7 +132,7 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( + updates = jax.tree.map( lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) @@ -148,14 +148,14 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( + return jax.tree.map( lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): """Perform bias correction. This becomes a no-op as count goes to infinity.""" beta = 1 - decay**count - return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + return jax.tree.map(lambda t: t / beta.astype(t.dtype), moment) def scale_by_learning_rate(learning_rate, flip_sign=True): @@ -200,7 +200,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): b2=hyperparameters['beta2'], eps=1e-8, weight_decay=hyperparameters['weight_decay']) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -248,7 +248,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -256,7 +256,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py index 60fc25ec4..92c0f599c 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -120,8 +120,8 @@ def scale_by_nadam(b1: float = 0.9, raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) def init_fn(params): - mu = jax.tree_map(jnp.zeros_like, params) # First moment - nu = jax.tree_map(jnp.zeros_like, params) # Second moment + mu = jax.tree.map(jnp.zeros_like, params) # First moment + nu = jax.tree.map(jnp.zeros_like, params) # Second moment return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) def update_fn(updates, state, params=None): @@ -132,7 +132,7 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( + updates = jax.tree.map( lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) @@ -148,14 +148,14 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( + return jax.tree.map( lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): """Perform bias correction. This becomes a no-op as count goes to infinity.""" beta = 1 - decay**count - return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + return jax.tree.map(lambda t: t / beta.astype(t.dtype), moment) def scale_by_learning_rate(learning_rate, flip_sign=True): @@ -200,7 +200,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): b2=hyperparameters['beta2'], eps=1e-8, weight_decay=hyperparameters['weight_decay']) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -248,7 +248,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -256,7 +256,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) diff --git a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py index e8e0bf4ac..055de8569 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py @@ -60,7 +60,7 @@ def init_optimizer_state(workload: spec.Workload, del model_params del model_state del rng - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) opt_init_fn, opt_update_fn = optimizer(hyperparameters, workload.num_train_examples) diff --git a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py index b33c0285b..b7c4dd2f2 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py @@ -26,7 +26,7 @@ def init_optimizer_state(workload: spec.Workload, del model_params del model_state del rng - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) opt_init_fn, opt_update_fn = optax.chain( optax.scale_by_adam( diff --git a/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py b/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py index 9f4da9132..ff98464ae 100644 --- a/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py +++ b/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py @@ -316,11 +316,11 @@ def to_state(self, count, result_tree): """Maps from a tree of (factored) values to separate trees of values.""" return ShardedAdafactorState( count=count, - m=jax.tree_map(lambda o: o.m, result_tree), - m_scale=jax.tree_map(lambda o: o.m_scale, result_tree), - vr=jax.tree_map(lambda o: o.vr, result_tree), - vc=jax.tree_map(lambda o: o.vc, result_tree), - v=jax.tree_map(lambda o: o.v, result_tree)) + m=jax.tree.map(lambda o: o.m, result_tree), + m_scale=jax.tree.map(lambda o: o.m_scale, result_tree), + vr=jax.tree.map(lambda o: o.vr, result_tree), + vc=jax.tree.map(lambda o: o.vc, result_tree), + v=jax.tree.map(lambda o: o.v, result_tree)) def init(self, param): """Initializes the optimizer state for a given param.""" @@ -667,7 +667,7 @@ def init_fn(params): """Initializes the optimizer's state.""" return sharded_adafactor_helper.to_state( jnp.zeros([], jnp.int32), - jax.tree_map(sharded_adafactor_helper.init, params)) + jax.tree.map(sharded_adafactor_helper.init, params)) def update_fn(updates, state, params=None): if params is None: @@ -677,7 +677,7 @@ def update_fn(updates, state, params=None): compute_var_and_slot_update_fn = functools.partial( sharded_adafactor_helper.compute_var_and_slot_update, state.count) - output = jax.tree_map(compute_var_and_slot_update_fn, + output = jax.tree.map(compute_var_and_slot_update_fn, updates, state.m, state.m_scale, @@ -685,7 +685,7 @@ def update_fn(updates, state, params=None): state.vc, state.v, params) - updates = jax.tree_map(lambda o: o.update, output) + updates = jax.tree.map(lambda o: o.update, output) count_plus_one = state.count + jnp.array(1, jnp.int32) updated_states = sharded_adafactor_helper.to_state(count_plus_one, output) return updates, updated_states diff --git a/reference_algorithms/paper_baselines/adafactor/jax/submission.py b/reference_algorithms/paper_baselines/adafactor/jax/submission.py index 0fcb9da0f..133468aea 100644 --- a/reference_algorithms/paper_baselines/adafactor/jax/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/jax/submission.py @@ -46,7 +46,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): learning_rate=lr_schedule_fn, beta1=1.0 - hyperparameters.one_minus_beta1, weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -94,7 +94,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -102,7 +102,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index e80a29693..60a336250 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -46,7 +46,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): b2=hyperparameters.beta2, eps=1e-8, weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -94,7 +94,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -102,7 +102,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) diff --git a/reference_algorithms/paper_baselines/lamb/jax/submission.py b/reference_algorithms/paper_baselines/lamb/jax/submission.py index ebcdc9914..7a3e1289c 100644 --- a/reference_algorithms/paper_baselines/lamb/jax/submission.py +++ b/reference_algorithms/paper_baselines/lamb/jax/submission.py @@ -53,7 +53,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): b2=hyperparameters.beta2, eps=1e-8, weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -102,7 +102,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -110,7 +110,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) diff --git a/reference_algorithms/paper_baselines/momentum/jax/submission.py b/reference_algorithms/paper_baselines/momentum/jax/submission.py index 271ef860b..182fbe644 100644 --- a/reference_algorithms/paper_baselines/momentum/jax/submission.py +++ b/reference_algorithms/paper_baselines/momentum/jax/submission.py @@ -28,7 +28,7 @@ def init_optimizer_state(workload: spec.Workload, lr_schedule_fn = create_lr_schedule_fn(workload.step_hint, hyperparameters) # Create optimizer. - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) opt_init_fn, opt_update_fn = sgd( learning_rate=lr_schedule_fn, @@ -128,7 +128,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -136,7 +136,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) diff --git a/reference_algorithms/paper_baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py index 36e7e5607..30f9068d1 100644 --- a/reference_algorithms/paper_baselines/nadamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/jax/submission.py @@ -111,8 +111,8 @@ def scale_by_nadam(b1: float = 0.9, raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) def init_fn(params): - mu = jax.tree_map(jnp.zeros_like, params) # First moment - nu = jax.tree_map(jnp.zeros_like, params) # Second moment + mu = jax.tree.map(jnp.zeros_like, params) # First moment + nu = jax.tree.map(jnp.zeros_like, params) # Second moment return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) def update_fn(updates, state, params=None): @@ -123,7 +123,7 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( + updates = jax.tree.map( lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) @@ -139,14 +139,14 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( + return jax.tree.map( lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): """Perform bias correction. This becomes a no-op as count goes to infinity.""" beta = 1 - decay**count - return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + return jax.tree.map(lambda t: t / beta.astype(t.dtype), moment) def scale_by_learning_rate(learning_rate, flip_sign=True): @@ -188,7 +188,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): b2=hyperparameters.beta2, eps=1e-8, weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -236,7 +236,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -244,7 +244,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index a435643e4..e45d8a854 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -28,7 +28,7 @@ def init_optimizer_state(workload: spec.Workload, lr_schedule_fn = create_lr_schedule_fn(workload.step_hint, hyperparameters) # Create optimizer. - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) opt_init_fn, opt_update_fn = sgd( learning_rate=lr_schedule_fn, @@ -128,7 +128,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -136,7 +136,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) diff --git a/reference_algorithms/paper_baselines/sam/jax/submission.py b/reference_algorithms/paper_baselines/sam/jax/submission.py index 5f45901dd..3f029fbfd 100644 --- a/reference_algorithms/paper_baselines/sam/jax/submission.py +++ b/reference_algorithms/paper_baselines/sam/jax/submission.py @@ -24,7 +24,7 @@ def dual_vector(y: jnp.ndarray) -> jnp.ndarray: """ gradient_norm = jnp.sqrt( sum(jnp.sum(jnp.square(e)) for e in jax.tree_util.tree_leaves(y))) - normalized_gradient = jax.tree_map(lambda x: x / gradient_norm, y) + normalized_gradient = jax.tree.map(lambda x: x / gradient_norm, y) return normalized_gradient @@ -73,12 +73,12 @@ def update_fn(updates, state, grad_fn_params_tuple): # Get correct global mean grad. (n_valid_examples, updates) = lax.psum((n_valid_examples, updates), axis_name=batch_axis_name) - updates = jax.tree_map(lambda x: x / n_valid_examples, updates) + updates = jax.tree.map(lambda x: x / n_valid_examples, updates) if grad_clip: updates_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(updates))) - scaled_updates = jax.tree_map( + scaled_updates = jax.tree.map( lambda x: x / (updates_norm + _GRAD_CLIP_EPS) * grad_clip, updates) updates = jax.lax.cond(updates_norm > grad_clip, lambda _: scaled_updates, @@ -136,7 +136,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): base_opt_update_fn=opt_update_fn) # Initialize optimizer state. - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -186,7 +186,7 @@ def _loss_fn(params, update_batch_norm=True): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) diff --git a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py index 725529cae..a5c2732ac 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py @@ -342,7 +342,7 @@ def init_training_metrics( """Initialize TrainingMetrics, masked if disabled.""" if not generate_training_metrics: return optax.MaskedNode() - return jax.tree_map( + return jax.tree.map( functools.partial(jnp.repeat, repeats=num_statistics), default_training_metrics()) @@ -356,14 +356,14 @@ def init_training_metrics_shapes( num_statistics, generate_training_metrics, ) - return jax.tree_map(lambda arr: [list(arr.shape), arr.dtype], seed) + return jax.tree.map(lambda arr: [list(arr.shape), arr.dtype], seed) def init_training_metrics_pspec(generate_training_metrics,): """Initialize training metrics partition specification.""" if not generate_training_metrics: return optax.MaskedNode() - return jax.tree_map(lambda _: jax.sharding.PartitionSpec(), + return jax.tree.map(lambda _: jax.sharding.PartitionSpec(), default_training_metrics()) @@ -1253,7 +1253,7 @@ def _add_metrics_into_local_stats(local_stats, metrics, keep_old): index_start = int(local_stat.index_start) index_end = int(len(local_stat.sizes)) + index_start # pylint:disable=cell-var-from-loop Used immediately. - per_stat_metrics = jax.tree_map(lambda x: x[index_start:index_end], metrics) + per_stat_metrics = jax.tree.map(lambda x: x[index_start:index_end], metrics) # We don't want to update the metrics if we didn't do a new inverse p-th # root calculation to find a new preconditioner, so that TensorBoard curves # look consistent (otherwise they'd oscillate between NaN and measured @@ -1808,7 +1808,7 @@ def sharded_update_fn(grads, state, params): local_stat, )) - new_stats_flat = jax.tree_map( + new_stats_flat = jax.tree.map( lambda g, s, p: _compute_stats(g, s, p, state.count), @@ -1816,7 +1816,7 @@ def sharded_update_fn(grads, state, params): stats_flat, params_flat) - outputs = jax.tree_map( + outputs = jax.tree.map( lambda g, s, p: _transform_grad(g, s, p, state.count), @@ -1981,7 +1981,7 @@ def _init(param): )) return ShampooState( - count=jnp.zeros([], jnp.int32), stats=jax.tree_map(_init, params)) + count=jnp.zeros([], jnp.int32), stats=jax.tree.map(_init, params)) def _skip_preconditioning(param): return len(param.shape) < skip_preconditioning_rank_lt or any( @@ -2140,7 +2140,7 @@ def _internal_inverse_pth_root_all(): preconditioners = jax.lax.all_gather(preconditioners, batch_axis_name) metrics = jax.lax.all_gather(metrics, batch_axis_name) preconditioners_flat = unbatch(preconditioners) - metrics_flat = jax.tree_map(unbatch, metrics) + metrics_flat = jax.tree.map(unbatch, metrics) else: preconditioners, metrics = _matrix_inverse_pth_root_vmap( all_statistics[0], @@ -2149,9 +2149,9 @@ def _internal_inverse_pth_root_all(): _maybe_ix(all_preconditioners, 0), ) preconditioners_flat = unbatch(jnp.stack([preconditioners])) - metrics = jax.tree_map( + metrics = jax.tree.map( functools.partial(jnp.expand_dims, axis=0), metrics) - metrics_flat = jax.tree_map(unbatch, metrics) + metrics_flat = jax.tree.map(unbatch, metrics) return preconditioners_flat, metrics_flat @@ -2166,7 +2166,7 @@ def _internal_inverse_pth_root_all(): s[:, :precond_dim(s.shape[0])] for s in packed_statistics ] n = len(packed_statistics) - metrics_init = jax.tree_map( + metrics_init = jax.tree.map( lambda x: [x] * n, default_training_metrics().replace( inverse_pth_root_errors=inverse_failure_threshold)) @@ -2215,12 +2215,12 @@ def _select_preconditioner(error, new_p, old_p): if generate_training_metrics: # pylint:disable=cell-var-from-loop Used immediately. - metrics_for_state = jax.tree_map( + metrics_for_state = jax.tree.map( lambda x: jnp.stack(x[idx:idx + num_statistics]), metrics_flat, is_leaf=lambda x: isinstance(x, list)) assert jax.tree_util.tree_all( - jax.tree_map(lambda x: len(state.statistics) == len(x), + jax.tree.map(lambda x: len(state.statistics) == len(x), metrics_for_state)) # If we skipped preconditioner computation, record old metrics. metrics_for_state = efficient_cond(perform_step, @@ -2441,7 +2441,7 @@ def update_fn(grads, state, params): if custom_preconditioner and grads_custom is not None: stats_grads = treedef.flatten_up_to(grads_custom) - new_stats_flat = jax.tree_map( + new_stats_flat = jax.tree.map( lambda g, s, p: _compute_stats(g, s, p, state.count), @@ -2452,7 +2452,7 @@ def update_fn(grads, state, params): new_stats_flat = _compute_preconditioners(new_stats_flat, params_flat, state.count) - outputs = jax.tree_map( + outputs = jax.tree.map( lambda g, s, p: _transform_grad(g, s, p, state.count), diff --git a/reference_algorithms/paper_baselines/shampoo/jax/submission.py b/reference_algorithms/paper_baselines/shampoo/jax/submission.py index 294ad2706..4a257d17b 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/submission.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/submission.py @@ -49,7 +49,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): weight_decay=hyperparameters.weight_decay, batch_axis_name='batch', eigh=False) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -97,7 +97,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -105,7 +105,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) diff --git a/reference_algorithms/target_setting_algorithms/jax_adamw.py b/reference_algorithms/target_setting_algorithms/jax_adamw.py index 6d2cfe245..bb85ecf05 100644 --- a/reference_algorithms/target_setting_algorithms/jax_adamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_adamw.py @@ -29,7 +29,7 @@ def init_optimizer_state(workload: spec.Workload, hyperparameters) # Create optimizer. - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) epsilon = ( hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) diff --git a/reference_algorithms/target_setting_algorithms/jax_momentum.py b/reference_algorithms/target_setting_algorithms/jax_momentum.py index 08a0f7e9d..c5fc2a0c6 100644 --- a/reference_algorithms/target_setting_algorithms/jax_momentum.py +++ b/reference_algorithms/target_setting_algorithms/jax_momentum.py @@ -32,7 +32,7 @@ def init_optimizer_state(workload: spec.Workload, hyperparameters) # Create optimizer. - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) opt_init_fn, opt_update_fn = sgd( learning_rate=lr_schedule_fn, diff --git a/reference_algorithms/target_setting_algorithms/jax_nadamw.py b/reference_algorithms/target_setting_algorithms/jax_nadamw.py index 21f2a7b2b..1e6b691fc 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nadamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_nadamw.py @@ -96,8 +96,8 @@ def scale_by_nadam(b1: float = 0.9, raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) def init_fn(params): - mu = jax.tree_map(jnp.zeros_like, params) # First moment - nu = jax.tree_map(jnp.zeros_like, params) # Second moment + mu = jax.tree.map(jnp.zeros_like, params) # First moment + nu = jax.tree.map(jnp.zeros_like, params) # Second moment return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) def update_fn(updates, state, params=None): @@ -108,7 +108,7 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( + updates = jax.tree.map( lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) @@ -124,14 +124,14 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( + return jax.tree.map( lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): """Perform bias correction. This becomes a no-op as count goes to infinity.""" beta = 1 - decay**count - return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + return jax.tree.map(lambda t: t / beta.astype(t.dtype), moment) def scale_by_learning_rate(learning_rate, flip_sign=True): @@ -156,7 +156,7 @@ def init_optimizer_state(workload: spec.Workload, hyperparameters) # Create optimizer. - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) epsilon = ( hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) diff --git a/reference_algorithms/target_setting_algorithms/jax_nesterov.py b/reference_algorithms/target_setting_algorithms/jax_nesterov.py index 6b27e0e2a..e5abde50b 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nesterov.py +++ b/reference_algorithms/target_setting_algorithms/jax_nesterov.py @@ -32,7 +32,7 @@ def init_optimizer_state(workload: spec.Workload, hyperparameters) # Create optimizer. - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) opt_init_fn, opt_update_fn = sgd( learning_rate=lr_schedule_fn, diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py b/reference_algorithms/target_setting_algorithms/jax_submission_base.py index 7a16c07cb..703310df4 100644 --- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py @@ -53,7 +53,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -61,7 +61,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) diff --git a/tests/modeldiffs/vanilla_sgd_jax.py b/tests/modeldiffs/vanilla_sgd_jax.py index d45694bcb..18dce968a 100644 --- a/tests/modeldiffs/vanilla_sgd_jax.py +++ b/tests/modeldiffs/vanilla_sgd_jax.py @@ -21,7 +21,7 @@ def init_optimizer_state(workload: spec.Workload, del rng # Create optimizer. - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) opt_init_fn, opt_update_fn = optax.sgd(learning_rate=0.001) diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py index f107be8d7..6afea8a8e 100644 --- a/tests/reference_algorithm_tests.py +++ b/tests/reference_algorithm_tests.py @@ -97,9 +97,9 @@ def _make_fake_image_batch(batch_shape, data_shape, num_classes): def _pytorch_map(inputs): if USE_PYTORCH_DDP: - return jax.tree_map( + return jax.tree.map( lambda a: torch.as_tensor(a[RANK], device=PYTORCH_DEVICE), inputs) - return jax.tree_map( + return jax.tree.map( lambda a: torch.as_tensor(a, device=PYTORCH_DEVICE).view(-1, a.shape[-1]) if len(a.shape) == 3 else torch.as_tensor(a, device=PYTORCH_DEVICE).view( -1), diff --git a/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py b/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py index 6a85c2196..49fd85fef 100644 --- a/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py +++ b/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py @@ -10,7 +10,7 @@ def _pytree_total_diff(pytree_a, pytree_b): - pytree_diff = jax.tree_map(lambda a, b: jnp.sum(a - b), pytree_a, pytree_b) + pytree_diff = jax.tree.map(lambda a, b: jnp.sum(a - b), pytree_a, pytree_b) pytree_diff = jax.tree_util.tree_leaves(pytree_diff) return jnp.sum(jnp.array(pytree_diff))