Skip to content

Commit

Permalink
compatibility with jax 0.4.19
Browse files Browse the repository at this point in the history
  • Loading branch information
gboehl committed Oct 23, 2023
1 parent 6db3bf8 commit f1de7e0
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions grgrjax/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import time
import jax.numpy as jnp
from jax._src.api import (_check_input_dtype_jacfwd, _check_input_dtype_jacrev, _check_output_dtype_jacfwd, _check_output_dtype_jacrev, _ensure_index, _jvp,
_vjp, _std_basis, _jacfwd_unravel, _jacrev_unravel, lu, argnums_partial, tree_map, tree_structure, tree_transpose, partial, Callable, Sequence, Union, vmap)
_vjp, _std_basis, _jacfwd_unravel, _jacrev_unravel, lu, argnums_partial, tree_map, tree_structure, tree_transpose, partial, Callable, Sequence, vmap)
# fix import location for jax 0.4.1
try:
from jax._src.api import _check_callable
Expand All @@ -23,7 +23,7 @@ def amax(x, return_arg=False):
return absx.max()


def jvp_vmap(fun: Callable, argnums: Union[int, Sequence[int]] = 0):
def jvp_vmap(fun: Callable, argnums: int | Sequence[int] = 0):
"""Vectorized (forward-mode) jacobian-vector product of ``fun``. This is by large adopted from the implementation of jacfwd in jax._src.api.
Args:
Expand All @@ -49,7 +49,7 @@ def jvpfun(args, tangents, **kwargs):
return jvpfun


def vjp_vmap(fun: Callable, argnums: Union[int, Sequence[int]] = 0):
def vjp_vmap(fun: Callable, argnums: int | Sequence[int] = 0):
"""Vectorized (reverse-mode) vector-jacobian product of ``fun``. This is by large adopted from the implementation of jacrev in jax._src.api.
Args:
Expand All @@ -74,7 +74,7 @@ def vjpfun(args, tangents, **kwargs):
return vjpfun


def val_and_jacfwd(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
def val_and_jacfwd(fun: Callable, argnums: int | Sequence[int] = 0,
has_aux: bool = False, holomorphic: bool = False) -> Callable:
"""Value and Jacobian of ``fun`` evaluated column-by-column using forward-mode AD. Apart from returning the function value, this is one-to-one adopted from jax._src.api.
Expand Down Expand Up @@ -122,7 +122,7 @@ def jacfun(*args, **kwargs):
return jacfun


def val_and_jacrev(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
def val_and_jacrev(fun: Callable, argnums: int | Sequence[int] = 0,
has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False) -> Callable:
"""Value and Jacobian of ``fun`` evaluated row-by-row using reverse-mode AD. Apart from returning the function value, this is one-to-one adopted from jax._src.api.
Expand Down

0 comments on commit f1de7e0

Please sign in to comment.