diff --git a/grgrjax/helpers.py b/grgrjax/helpers.py index 1cc2169..3867b18 100644 --- a/grgrjax/helpers.py +++ b/grgrjax/helpers.py @@ -4,7 +4,7 @@ import time import jax.numpy as jnp from jax._src.api import (_check_input_dtype_jacfwd, _check_input_dtype_jacrev, _check_output_dtype_jacfwd, _check_output_dtype_jacrev, _ensure_index, _jvp, - _vjp, _std_basis, _jacfwd_unravel, _jacrev_unravel, lu, argnums_partial, tree_map, tree_structure, tree_transpose, partial, Callable, Sequence, Union, vmap) + _vjp, _std_basis, _jacfwd_unravel, _jacrev_unravel, lu, argnums_partial, tree_map, tree_structure, tree_transpose, partial, Callable, Sequence, vmap) # fix import location for jax 0.4.1 try: from jax._src.api import _check_callable @@ -23,7 +23,7 @@ def amax(x, return_arg=False): return absx.max() -def jvp_vmap(fun: Callable, argnums: Union[int, Sequence[int]] = 0): +def jvp_vmap(fun: Callable, argnums: int | Sequence[int] = 0): """Vectorized (forward-mode) jacobian-vector product of ``fun``. This is by large adopted from the implementation of jacfwd in jax._src.api. Args: @@ -49,7 +49,7 @@ def jvpfun(args, tangents, **kwargs): return jvpfun -def vjp_vmap(fun: Callable, argnums: Union[int, Sequence[int]] = 0): +def vjp_vmap(fun: Callable, argnums: int | Sequence[int] = 0): """Vectorized (reverse-mode) vector-jacobian product of ``fun``. This is by large adopted from the implementation of jacrev in jax._src.api. Args: @@ -74,7 +74,7 @@ def vjpfun(args, tangents, **kwargs): return vjpfun -def val_and_jacfwd(fun: Callable, argnums: Union[int, Sequence[int]] = 0, +def val_and_jacfwd(fun: Callable, argnums: int | Sequence[int] = 0, has_aux: bool = False, holomorphic: bool = False) -> Callable: """Value and Jacobian of ``fun`` evaluated column-by-column using forward-mode AD. Apart from returning the function value, this is one-to-one adopted from jax._src.api. @@ -122,7 +122,7 @@ def jacfun(*args, **kwargs): return jacfun -def val_and_jacrev(fun: Callable, argnums: Union[int, Sequence[int]] = 0, +def val_and_jacrev(fun: Callable, argnums: int | Sequence[int] = 0, has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False) -> Callable: """Value and Jacobian of ``fun`` evaluated row-by-row using reverse-mode AD. Apart from returning the function value, this is one-to-one adopted from jax._src.api.