
Commit

fix: running yapf again with 0.32, earlier using 0.43
init-22 committed Dec 3, 2024
1 parent c65d93e commit 3afd1df
Showing 8 changed files with 46 additions and 50 deletions.
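The diff below is purely mechanical: re-running yapf 0.32 (instead of 0.43) changes how long `jax.tree_map` / `jax.tree_util.tree_map` calls are wrapped. As a rough illustration only (the snippet and its inputs are made up, and whether a given call gets rewrapped depends on the style, column limit, and nesting depth; this is not the repo's actual tooling), yapf can be run programmatically like this:

# Sketch: reformat a snippet with yapf's Python API (pip install yapf==0.32.0).
# FormatCode only parses the text, so names inside SOURCE need not exist.
from yapf.yapflib.yapf_api import FormatCode

SOURCE = '''\
def update_fn(updates, state, params=None):
  updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps),
                         mu_hat,
                         nu_hat)
  return updates
'''

formatted, changed = FormatCode(SOURCE, style_config='yapf')
print(changed)    # whether yapf rewrote the text
print(formatted)  # the source as yapf 0.32 lays it out under the 'yapf' style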
(changed file, path not shown)
@@ -123,9 +123,8 @@ def update_fn(updates, state, params=None):
     mu_hat = _update_moment(updates, mu, b1, 1)
     mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count)
     nu_hat = nu if not debias else _bias_correction(nu, b2, count)
-    updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps),
-                           mu_hat,
-                           nu_hat)
+    updates = jax.tree_map(
+        lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat)
     return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)
 
   return optax.GradientTransformation(init_fn, update_fn)
@@ -140,9 +139,8 @@ class ScaleByAdamState(NamedTuple):
 
 def _update_moment(updates, moments, decay, order):
   """Compute the exponential moving average of the `order-th` moment."""
-  return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t,
-                      updates,
-                      moments)
+  return jax.tree_map(
+      lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)
 
 
 def _bias_correction(moment, decay, count):
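The same two rewraps recur in every NAdamW file touched by this commit. For orientation, here is a self-contained sketch of what the reformatted lines compute; the toy inputs are invented, and `_bias_correction` is reconstructed rather than quoted from the hunks, so treat it as illustrative:

import jax
import jax.numpy as jnp


def _update_moment(updates, moments, decay, order):
  """Exponential moving average of the `order`-th moment of the updates."""
  return jax.tree_util.tree_map(
      lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)


def _bias_correction(moment, decay, count):
  """Adam-style debiasing: divide the running moment by (1 - decay**count)."""
  bias_correction = 1 - decay**count
  return jax.tree_util.tree_map(
      lambda t: t / bias_correction.astype(t.dtype), moment)


# Toy usage with made-up values.
grads = {'w': jnp.array([0.1, -0.2, 0.3])}
mu = {'w': jnp.zeros(3)}
count = jnp.array(1, dtype=jnp.int32)
mu = _update_moment(grads, mu, 0.9, 1)       # first-moment EMA
mu_hat = _bias_correction(mu, 0.9, count)    # debiased estimate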
(changed file, path not shown)
@@ -123,9 +123,8 @@ def update_fn(updates, state, params=None):
     mu_hat = _update_moment(updates, mu, b1, 1)
     mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count)
     nu_hat = nu if not debias else _bias_correction(nu, b2, count)
-    updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps),
-                           mu_hat,
-                           nu_hat)
+    updates = jax.tree_map(
+        lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat)
     return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)
 
   return optax.GradientTransformation(init_fn, update_fn)
@@ -140,9 +139,8 @@ class ScaleByAdamState(NamedTuple):
 
 def _update_moment(updates, moments, decay, order):
   """Compute the exponential moving average of the `order-th` moment."""
-  return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t,
-                      updates,
-                      moments)
+  return jax.tree_map(
+      lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)
 
 
 def _bias_correction(moment, decay, count):
(changed file, path not shown)
@@ -132,9 +132,8 @@ def update_fn(updates, state, params=None):
     mu_hat = _update_moment(updates, mu, b1, 1)
     mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count)
     nu_hat = nu if not debias else _bias_correction(nu, b2, count)
-    updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps),
-                           mu_hat,
-                           nu_hat)
+    updates = jax.tree_map(
+        lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat)
     return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)
 
   return optax.GradientTransformation(init_fn, update_fn)
@@ -149,9 +148,8 @@ class ScaleByAdamState(NamedTuple):
 
 def _update_moment(updates, moments, decay, order):
  """Compute the exponential moving average of the `order-th` moment."""
-  return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t,
-                      updates,
-                      moments)
+  return jax.tree_map(
+      lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)
 
 
 def _bias_correction(moment, decay, count):
(changed file, path not shown)
@@ -132,9 +132,8 @@ def update_fn(updates, state, params=None):
     mu_hat = _update_moment(updates, mu, b1, 1)
     mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count)
     nu_hat = nu if not debias else _bias_correction(nu, b2, count)
-    updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps),
-                           mu_hat,
-                           nu_hat)
+    updates = jax.tree_map(
+        lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat)
     return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)
 
   return optax.GradientTransformation(init_fn, update_fn)
@@ -149,9 +148,8 @@ class ScaleByAdamState(NamedTuple):
 
 def _update_moment(updates, moments, decay, order):
   """Compute the exponential moving average of the `order-th` moment."""
-  return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t,
-                      updates,
-                      moments)
+  return jax.tree_map(
+      lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)
 
 
 def _bias_correction(moment, decay, count):
10 changes: 4 additions & 6 deletions reference_algorithms/paper_baselines/nadamw/jax/submission.py
@@ -123,9 +123,8 @@ def update_fn(updates, state, params=None):
     mu_hat = _update_moment(updates, mu, b1, 1)
     mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count)
     nu_hat = nu if not debias else _bias_correction(nu, b2, count)
-    updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps),
-                           mu_hat,
-                           nu_hat)
+    updates = jax.tree_map(
+        lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat)
     return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)
 
   return optax.GradientTransformation(init_fn, update_fn)
@@ -140,9 +139,8 @@ class ScaleByAdamState(NamedTuple):
 
 def _update_moment(updates, moments, decay, order):
   """Compute the exponential moving average of the `order-th` moment."""
-  return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t,
-                      updates,
-                      moments)
+  return jax.tree_map(
+      lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)
 
 
 def _bias_correction(moment, decay, count):
8 changes: 4 additions & 4 deletions reference_algorithms/paper_baselines/sam/jax/submission.py
@@ -67,9 +67,8 @@ def update_fn(updates, state, grad_fn_params_tuple):
     # the noised parameters in the same order as on the original gradients and
     # with the same 1e-6 epsilon that is used when clipping the gradients.
     updates = dual_vector(updates)
-    noised_params = jax.tree_util.tree_map(lambda p, u: p + rho * u,
-                                           params,
-                                           updates)
+    noised_params = jax.tree_util.tree_map(
+        lambda p, u: p + rho * u, params, updates)
     (_, (n_valid_examples, _)), updates = grad_fn(noised_params)
     # Get correct global mean grad.
     (n_valid_examples, updates) = lax.psum((n_valid_examples, updates),
@@ -81,7 +80,8 @@ def update_fn(updates, state, grad_fn_params_tuple):
           sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(updates)))
       scaled_updates = jax.tree_map(
           lambda x: x / (updates_norm + _GRAD_CLIP_EPS) * grad_clip, updates)
-      updates = jax.lax.cond(updates_norm > grad_clip, lambda _: scaled_updates,
+      updates = jax.lax.cond(updates_norm > grad_clip,
+                             lambda _: scaled_updates,
                              lambda _: updates,
                              None)
       updates, state = base_opt_update_fn(updates, state, params)
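For context, a minimal sketch of the two SAM steps these hunks rewrap: perturb the parameters along the unit-norm gradient direction with radius `rho`, then conditionally rescale the gradients when their global norm exceeds `grad_clip`. The `dual_vector` below is a plausible unit-norm version and the constants and pytrees are invented; this is not the submission's exact code:

import jax
import jax.numpy as jnp
from jax import lax

_GRAD_CLIP_EPS = 1e-6  # mirrors the epsilon mentioned in the comment above
RHO = 0.05             # assumed perturbation radius


def dual_vector(y):
  # Normalize a gradient pytree to (approximately) unit L2 norm.
  norm = jnp.sqrt(sum(jnp.sum(x**2) for x in jax.tree_util.tree_leaves(y)))
  return jax.tree_util.tree_map(lambda x: x / (norm + _GRAD_CLIP_EPS), y)


def sam_perturb(params, grads, rho=RHO):
  # Step to the locally "sharpest" nearby point: params + rho * unit gradient.
  unit = dual_vector(grads)
  return jax.tree_util.tree_map(lambda p, u: p + rho * u, params, unit)


def clip_by_global_norm(updates, grad_clip):
  # Rescale the whole pytree only when its global L2 norm exceeds grad_clip.
  updates_norm = jnp.sqrt(
      sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(updates)))
  scaled = jax.tree_util.tree_map(
      lambda x: x / (updates_norm + _GRAD_CLIP_EPS) * grad_clip, updates)
  return lax.cond(updates_norm > grad_clip,
                  lambda _: scaled,
                  lambda _: updates,
                  None)


params = {'w': jnp.ones((3,)), 'b': jnp.zeros(())}
grads = {'w': jnp.array([0.1, -0.2, 0.3]), 'b': jnp.array(0.5)}
noised_params = sam_perturb(params, grads)
clipped = clip_by_global_norm(grads, grad_clip=0.1)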
(changed file, path not shown)
@@ -595,8 +595,8 @@ def matrix_inverse_pth_root(
 
   if padding_start is not None:
     # Zero out padding in identity as well for convergence checks.
-    ix = (jnp.arange(matrix_size, dtype=jnp.int32)
-          < padding_start).astype(matrix.dtype)
+    ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype(
+        matrix.dtype)
     matrix *= ix[jnp.newaxis, :]
     matrix *= ix[:, jnp.newaxis]
     identity *= ix
@@ -815,8 +815,8 @@ def matrix_inverse_pth_root_eigh(
   alpha = jnp.asarray(-1.0 / p, _MAT_INV_PTH_ROOT_DTYPE)
   identity = jnp.eye(matrix_size, dtype=_MAT_INV_PTH_ROOT_DTYPE)
   if padding_start is not None:
-    ix = (jnp.arange(matrix_size, dtype=jnp.int32)
-          < padding_start).astype(matrix.dtype)
+    ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype(
+        matrix.dtype)
     matrix *= ix[jnp.newaxis, :]
     matrix *= ix[:, jnp.newaxis]
     identity *= ix
@@ -1809,13 +1809,17 @@ def sharded_update_fn(grads, state, params):
     ))
 
     new_stats_flat = jax.tree_map(
-        lambda g, s, p: _compute_stats(g, s, p, state.count),
+        lambda g,
+        s,
+        p: _compute_stats(g, s, p, state.count),
         grads_flat,
         stats_flat,
         params_flat)
 
     outputs = jax.tree_map(
-        lambda g, s, p: _transform_grad(g, s, p, state.count),
+        lambda g,
+        s,
+        p: _transform_grad(g, s, p, state.count),
         grads_flat,
         new_stats_flat,
         params_flat)
@@ -1919,8 +1923,8 @@ def _internal_inverse_pth_root_all():
       errors = metrics.inverse_pth_root_errors
       errors = errors.reshape((-1, 1, 1))
       predicate = jnp.logical_or(
-          jnp.isnan(errors), errors
-          >= inverse_failure_threshold).astype(new_preconditioners.dtype)
+          jnp.isnan(errors),
+          errors >= inverse_failure_threshold).astype(new_preconditioners.dtype)
       # TODO(rohananil): Check for numerical instabilities.
       new_conditional_preconditioners = (
           predicate * global_stats.preconditioners +
@@ -2438,7 +2442,9 @@ def update_fn(grads, state, params):
     stats_grads = treedef.flatten_up_to(grads_custom)
 
     new_stats_flat = jax.tree_map(
-        lambda g, s, p: _compute_stats(g, s, p, state.count),
+        lambda g,
+        s,
+        p: _compute_stats(g, s, p, state.count),
         stats_grads,
         stats_flat,
         params_flat)
@@ -2447,7 +2453,9 @@ def update_fn(grads, state, params):
                                               params_flat,
                                               state.count)
     outputs = jax.tree_map(
-        lambda g, s, p: _transform_grad(g, s, p, state.count),
+        lambda g,
+        s,
+        p: _transform_grad(g, s, p, state.count),
         grads_flat,
         new_stats_flat,
         params_flat)
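The first two hunks in this file only rewrap the padding mask, and the later ones only split the `lambda g, s, p: ...` arguments across lines, but the masking idiom is worth spelling out: entries at or beyond `padding_start` are zeroed in both the matrix and the identity so that padded rows and columns cannot affect the inverse-pth-root convergence check. A toy, self-contained illustration (the sizes are made up):

import jax.numpy as jnp

matrix_size = 4
padding_start = 2
matrix = jnp.ones((matrix_size, matrix_size))
identity = jnp.eye(matrix_size)

# 1.0 for real rows/columns, 0.0 for padded ones.
ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype(
    matrix.dtype)
matrix = matrix * ix[jnp.newaxis, :]   # zero padded columns
matrix = matrix * ix[:, jnp.newaxis]   # zero padded rows
identity = identity * ix               # diag(1, 1, 0, 0)

print(matrix)    # only the top-left 2x2 block is nonzero
print(identity)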
10 changes: 4 additions & 6 deletions reference_algorithms/target_setting_algorithms/jax_nadamw.py
@@ -108,9 +108,8 @@ def update_fn(updates, state, params=None):
     mu_hat = _update_moment(updates, mu, b1, 1)
     mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count)
     nu_hat = nu if not debias else _bias_correction(nu, b2, count)
-    updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps),
-                           mu_hat,
-                           nu_hat)
+    updates = jax.tree_map(
+        lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat)
     return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)
 
   return optax.GradientTransformation(init_fn, update_fn)
@@ -125,9 +124,8 @@ class ScaleByAdamState(NamedTuple):
 
 def _update_moment(updates, moments, decay, order):
   """Compute the exponential moving average of the `order-th` moment."""
-  return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t,
-                      updates,
-                      moments)
+  return jax.tree_map(
+      lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)
 
 
 def _bias_correction(moment, decay, count):
