Skip to content

Commit

Permalink
extend amax
Browse files Browse the repository at this point in the history
  • Loading branch information
gboehl committed Mar 14, 2023
1 parent f4ad846 commit 1ade245
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
6 changes: 0 additions & 6 deletions grgrjax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,3 @@ def jax_print(w):
"""Print in jax compiled functions. Wrapper around `jax.experimental.host_callback.id_print`.
"""
return jax.experimental.host_callback.id_print(w)


def amax(x):
"""Return the maximum absolute value.
"""
return jnp.abs(x).max()
10 changes: 9 additions & 1 deletion grgrjax/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,15 @@
from jax._src.api_util import check_callable as _check_callable


amax = jax.jit(lambda x: jnp.abs(x).max())
def amax(x, return_arg=False):
"""Return the maximum absolute value.
"""
absx = jnp.abs(x)
if return_arg:
arg = jnp.argmax(absx)
return absx[arg], arg
else:
return absx.max()


def jvp_vmap(fun: Callable, argnums: Union[int, Sequence[int]] = 0):
Expand Down

0 comments on commit 1ade245

Please sign in to comment.