Skip to content

Commit

Permalink
fix: changes jax.tree_map to jax.tree.map
Browse files Browse the repository at this point in the history
  • Loading branch information
init-22 committed Dec 22, 2024
1 parent 53eff1d commit d21d820
Show file tree
Hide file tree
Showing 35 changed files with 124 additions and 124 deletions.
2 changes: 1 addition & 1 deletion algorithmic_efficiency/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _prepare(x):
# Assumes that `global_batch_size % local_device_count == 0`.
return x.reshape((local_device_count, -1, *x.shape[1:]))

return jax.tree_map(_prepare, batch)
return jax.tree.map(_prepare, batch)


def pad(tensor: np.ndarray,
Expand Down
2 changes: 1 addition & 1 deletion algorithmic_efficiency/param_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def pytorch_param_types(

def jax_param_shapes(
params: spec.ParameterContainer) -> spec.ParameterShapeTree:
return jax.tree_map(lambda x: spec.ShapeTuple(x.shape), params)
return jax.tree.map(lambda x: spec.ShapeTuple(x.shape), params)


def jax_param_types(param_shapes: spec.ParameterShapeTree,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,4 +207,4 @@ def _normalize_eval_metrics(
self, num_examples: int, total_metrics: Dict[str,
Any]) -> Dict[str, float]:
"""Normalize eval metrics."""
return jax.tree_map(lambda x: float(x[0] / num_examples), total_metrics)
return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics)
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def _eval_model_on_split(self,
eval_metrics[metric_name] = 0.0
eval_metrics[metric_name] += metric_value

eval_metrics = jax.tree_map(lambda x: float(x[0] / num_examples),
eval_metrics = jax.tree.map(lambda x: float(x[0] / num_examples),
eval_metrics)
return eval_metrics

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,4 +132,4 @@ def _normalize_eval_metrics(
self, num_examples: int, total_metrics: Dict[str,
Any]) -> Dict[str, float]:
"""Normalize eval metrics."""
return jax.tree_map(lambda x: float(x[0] / num_examples), total_metrics)
return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics)
4 changes: 2 additions & 2 deletions algorithmic_efficiency/workloads/ogbg/input_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _load_dataset(split, should_shuffle, data_rng, data_dir):

def _to_jraph(example):
"""Converts an example graph to jraph.GraphsTuple."""
example = jax.tree_map(lambda x: x._numpy(), example) # pylint: disable=protected-access
example = jax.tree.map(lambda x: x._numpy(), example) # pylint: disable=protected-access
edge_feat = example['edge_feat']
node_feat = example['node_feat']
edge_index = example['edge_index']
Expand Down Expand Up @@ -150,7 +150,7 @@ def _get_batch_iterator(dataset_iter, global_batch_size, num_shards=None):
if count == num_shards:

def f(x):
return jax.tree_map(lambda *vals: np.stack(vals, axis=0), x[0], *x[1:])
return jax.tree.map(lambda *vals: np.stack(vals, axis=0), x[0], *x[1:])

graphs_shards = f(graphs_shards)
labels_shards = f(labels_shards)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@

def _pytorch_map(inputs: Any) -> Any:
if USE_PYTORCH_DDP:
return jax.tree_map(lambda a: torch.as_tensor(a, device=DEVICE), inputs)
return jax.tree_map(
return jax.tree.map(lambda a: torch.as_tensor(a, device=DEVICE), inputs)
return jax.tree.map(
lambda a: torch.as_tensor(a, device=DEVICE).view(-1, a.shape[-1])
if len(a.shape) == 3 else torch.as_tensor(a, device=DEVICE).view(-1),
inputs)
Expand All @@ -30,7 +30,7 @@ def _pytorch_map(inputs: Any) -> Any:
def _shard(inputs: Any) -> Any:
if not USE_PYTORCH_DDP:
return inputs
return jax.tree_map(lambda tensor: tensor[RANK], inputs)
return jax.tree.map(lambda tensor: tensor[RANK], inputs)


def _graph_map(function: Callable, graph: GraphsTuple) -> GraphsTuple:
Expand Down
8 changes: 4 additions & 4 deletions algorithmic_efficiency/workloads/wmt/wmt_jax/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def gather_fn(x):
return x
return x[batch_indices, beam_indices]

return jax.tree_map(gather_fn, nested)
return jax.tree.map(gather_fn, nested)


def gather_topk_beams(nested, score_or_log_prob, batch_size, new_beam_size):
Expand Down Expand Up @@ -139,7 +139,7 @@ def beam_init(batch_size, beam_size, max_decode_len, cache):
finished_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32)
finished_flags0 = jnp.zeros((batch_size, beam_size), jnp.bool_)
# add beam dimension to attention cache pytree elements
beam_cache0 = jax.tree_map(lambda x: add_beam_dim(x, beam_size), cache)
beam_cache0 = jax.tree.map(lambda x: add_beam_dim(x, beam_size), cache)
return BeamState(
cur_index=cur_index0,
live_logprobs=live_logprobs0,
Expand Down Expand Up @@ -225,7 +225,7 @@ def beam_search_loop_body_fn(state):
(batch_size, beam_size, 1)))
# Flatten beam dimension into batch to be compatible with model.
# {[batch, beam, ...], ...} --> {[batch * beam, ...], ...}
flat_cache = jax.tree_map(flatten_beam_dim, state.cache)
flat_cache = jax.tree.map(flatten_beam_dim, state.cache)

# Call fast-decoder model on current tokens to get next-position logits.
# --> [batch * beam, vocab]
Expand All @@ -236,7 +236,7 @@ def beam_search_loop_body_fn(state):
logits = unflatten_beam_dim(flat_logits, batch_size, beam_size)
# Unflatten beam dimension in attention cache arrays
# {[batch * beam, ...], ...} --> {[batch, beam, ...], ...}
new_cache = jax.tree_map(
new_cache = jax.tree.map(
lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache)

# Gather log probabilities from logits
Expand Down
4 changes: 2 additions & 2 deletions algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def eval_step(self,
params: spec.ParameterContainer,
batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]:
replicated_eval_metrics = self.eval_step_pmapped(params, batch)
return jax.tree_map(lambda x: jnp.sum(x, axis=0), replicated_eval_metrics)
return jax.tree.map(lambda x: jnp.sum(x, axis=0), replicated_eval_metrics)

@functools.partial(
jax.pmap, axis_name='batch', static_broadcasted_argnums=(0,))
Expand Down Expand Up @@ -291,7 +291,7 @@ def _normalize_eval_metrics(
"""Normalize eval metrics."""
del num_examples
eval_denominator = total_metrics.pop('denominator')
return jax.tree_map(lambda x: float(x / eval_denominator), total_metrics)
return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics)


class WmtWorkloadPostLN(WmtWorkload):
Expand Down
8 changes: 4 additions & 4 deletions algorithmic_efficiency/workloads/wmt/wmt_pytorch/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def gather_fn(x):
return x
return x[batch_indices, beam_indices]

return jax.tree_map(gather_fn, nested)
return jax.tree.map(gather_fn, nested)


def gather_topk_beams(nested: Dict[str, Any],
Expand Down Expand Up @@ -164,7 +164,7 @@ def beam_init(batch_size: int,
dtype=torch.bool,
device=DEVICE)
# add beam dimension to attention cache pytree elements
beam_cache0 = jax.tree_map(lambda x: add_beam_dim(x, beam_size), cache)
beam_cache0 = jax.tree.map(lambda x: add_beam_dim(x, beam_size), cache)
return BeamState(
cur_index=cur_index0,
live_logprobs=live_logprobs0,
Expand Down Expand Up @@ -251,7 +251,7 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState:
state.live_seqs[:batch_size, :beam_size, cur_index:cur_index + 1])
# Flatten beam dimension into batch to be compatible with model.
# {[batch, beam, ...], ...} --> {[batch * beam, ...], ...}
flat_cache = jax.tree_map(flatten_beam_dim, state.cache)
flat_cache = jax.tree.map(flatten_beam_dim, state.cache)

# Call fast-decoder model on current tokens to get next-position logits.
# --> [batch * beam, vocab]
Expand All @@ -262,7 +262,7 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState:
logits = unflatten_beam_dim(flat_logits, batch_size, beam_size)
# Unflatten beam dimension in attention cache arrays
# {[batch * beam, ...], ...} --> {[batch, beam, ...], ...}
new_cache = jax.tree_map(
new_cache = jax.tree.map(
lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache)

# Gather log probabilities from logits
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def _normalize_eval_metrics(
dist.all_reduce(metric)
total_metrics = {k: v.item() for k, v in total_metrics.items()}
eval_denominator = total_metrics.pop('denominator')
return jax.tree_map(lambda x: float(x / eval_denominator), total_metrics)
return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics)


class WmtWorkloadPostLN(WmtWorkload):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ def scale_by_nadam(b1: float = 0.9,
raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power)

def init_fn(params):
mu = jax.tree_map(jnp.zeros_like, params) # First moment
nu = jax.tree_map(jnp.zeros_like, params) # Second moment
mu = jax.tree.map(jnp.zeros_like, params) # First moment
nu = jax.tree.map(jnp.zeros_like, params) # Second moment
return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu)

def update_fn(updates, state, params=None):
Expand All @@ -123,7 +123,7 @@ def update_fn(updates, state, params=None):
mu_hat = _update_moment(updates, mu, b1, 1)
mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count)
nu_hat = nu if not debias else _bias_correction(nu, b2, count)
updates = jax.tree_map(
updates = jax.tree.map(
lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat)
return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)

Expand All @@ -139,14 +139,14 @@ class ScaleByAdamState(NamedTuple):

def _update_moment(updates, moments, decay, order):
"""Compute the exponential moving average of the `order-th` moment."""
return jax.tree_map(
return jax.tree.map(
lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)


def _bias_correction(moment, decay, count):
"""Perform bias correction. This becomes a no-op as count goes to infinity."""
beta = 1 - decay**count
return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment)
return jax.tree.map(lambda t: t / beta.astype(t.dtype), moment)


def scale_by_learning_rate(learning_rate, flip_sign=True):
Expand Down Expand Up @@ -188,7 +188,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters):
b2=hyperparameters.beta2,
eps=1e-8,
weight_decay=hyperparameters.weight_decay)
params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple),
params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
workload.param_shapes)
optimizer_state = opt_init_fn(params_zeros_like)

Expand Down Expand Up @@ -236,15 +236,15 @@ def _loss_fn(params):
(summed_loss, n_valid_examples, grad) = lax.psum(
(summed_loss, n_valid_examples, grad), axis_name='batch')
loss = summed_loss / n_valid_examples
grad = jax.tree_map(lambda x: x / n_valid_examples, grad)
grad = jax.tree.map(lambda x: x / n_valid_examples, grad)

grad_norm = jnp.sqrt(
sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)))

if grad_clip is not None:
grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS)
grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0)
grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad)
grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad)

updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
current_param_container)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ def scale_by_nadam(b1: float = 0.9,
raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power)

def init_fn(params):
mu = jax.tree_map(jnp.zeros_like, params) # First moment
nu = jax.tree_map(jnp.zeros_like, params) # Second moment
mu = jax.tree.map(jnp.zeros_like, params) # First moment
nu = jax.tree.map(jnp.zeros_like, params) # Second moment
return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu)

def update_fn(updates, state, params=None):
Expand All @@ -123,7 +123,7 @@ def update_fn(updates, state, params=None):
mu_hat = _update_moment(updates, mu, b1, 1)
mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count)
nu_hat = nu if not debias else _bias_correction(nu, b2, count)
updates = jax.tree_map(
updates = jax.tree.map(
lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat)
return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)

Expand All @@ -139,14 +139,14 @@ class ScaleByAdamState(NamedTuple):

def _update_moment(updates, moments, decay, order):
"""Compute the exponential moving average of the `order-th` moment."""
return jax.tree_map(
return jax.tree.map(
lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)


def _bias_correction(moment, decay, count):
"""Perform bias correction. This becomes a no-op as count goes to infinity."""
beta = 1 - decay**count
return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment)
return jax.tree.map(lambda t: t / beta.astype(t.dtype), moment)


def scale_by_learning_rate(learning_rate, flip_sign=True):
Expand Down Expand Up @@ -188,7 +188,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters):
b2=hyperparameters.beta2,
eps=1e-8,
weight_decay=hyperparameters.weight_decay)
params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple),
params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
workload.param_shapes)
optimizer_state = opt_init_fn(params_zeros_like)

Expand Down Expand Up @@ -236,15 +236,15 @@ def _loss_fn(params):
(summed_loss, n_valid_examples, grad) = lax.psum(
(summed_loss, n_valid_examples, grad), axis_name='batch')
loss = summed_loss / n_valid_examples
grad = jax.tree_map(lambda x: x / n_valid_examples, grad)
grad = jax.tree.map(lambda x: x / n_valid_examples, grad)

grad_norm = jnp.sqrt(
sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)))

if grad_clip is not None:
grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS)
grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0)
grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad)
grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad)

updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
current_param_container)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ def scale_by_nadam(b1: float = 0.9,
raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power)

def init_fn(params):
mu = jax.tree_map(jnp.zeros_like, params) # First moment
nu = jax.tree_map(jnp.zeros_like, params) # Second moment
mu = jax.tree.map(jnp.zeros_like, params) # First moment
nu = jax.tree.map(jnp.zeros_like, params) # Second moment
return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu)

def update_fn(updates, state, params=None):
Expand All @@ -132,7 +132,7 @@ def update_fn(updates, state, params=None):
mu_hat = _update_moment(updates, mu, b1, 1)
mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count)
nu_hat = nu if not debias else _bias_correction(nu, b2, count)
updates = jax.tree_map(
updates = jax.tree.map(
lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat)
return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)

Expand All @@ -148,14 +148,14 @@ class ScaleByAdamState(NamedTuple):

def _update_moment(updates, moments, decay, order):
"""Compute the exponential moving average of the `order-th` moment."""
return jax.tree_map(
return jax.tree.map(
lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)


def _bias_correction(moment, decay, count):
"""Perform bias correction. This becomes a no-op as count goes to infinity."""
beta = 1 - decay**count
return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment)
return jax.tree.map(lambda t: t / beta.astype(t.dtype), moment)


def scale_by_learning_rate(learning_rate, flip_sign=True):
Expand Down Expand Up @@ -200,7 +200,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters):
b2=hyperparameters['beta2'],
eps=1e-8,
weight_decay=hyperparameters['weight_decay'])
params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple),
params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
workload.param_shapes)
optimizer_state = opt_init_fn(params_zeros_like)

Expand Down Expand Up @@ -248,15 +248,15 @@ def _loss_fn(params):
(summed_loss, n_valid_examples, grad) = lax.psum(
(summed_loss, n_valid_examples, grad), axis_name='batch')
loss = summed_loss / n_valid_examples
grad = jax.tree_map(lambda x: x / n_valid_examples, grad)
grad = jax.tree.map(lambda x: x / n_valid_examples, grad)

grad_norm = jnp.sqrt(
sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)))

if grad_clip is not None:
grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS)
grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0)
grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad)
grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad)

updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
current_param_container)
Expand Down
Loading

0 comments on commit d21d820

Please sign in to comment.