diff --git a/grgrjax/__init__.py b/grgrjax/__init__.py index f3f9155..24e46bc 100644 --- a/grgrjax/__init__.py +++ b/grgrjax/__init__.py @@ -15,9 +15,3 @@ def jax_print(w): """Print in jax compiled functions. Wrapper around `jax.experimental.host_callback.id_print`. """ return jax.experimental.host_callback.id_print(w) - - -def amax(x): - """Return the maximum absolute value. - """ - return jnp.abs(x).max() diff --git a/grgrjax/helpers.py b/grgrjax/helpers.py index f56590b..1cc2169 100644 --- a/grgrjax/helpers.py +++ b/grgrjax/helpers.py @@ -12,7 +12,15 @@ from jax._src.api_util import check_callable as _check_callable -amax = jax.jit(lambda x: jnp.abs(x).max()) +def amax(x, return_arg=False): + """Return the maximum absolute value. + """ + absx = jnp.abs(x) + if return_arg: + arg = jnp.argmax(absx) + return absx[arg], arg + else: + return absx.max() def jvp_vmap(fun: Callable, argnums: Union[int, Sequence[int]] = 0):