diff --git a/grgrjax/helpers.py b/grgrjax/helpers.py index 3867b18..89c8d7a 100644 --- a/grgrjax/helpers.py +++ b/grgrjax/helpers.py @@ -23,7 +23,7 @@ def amax(x, return_arg=False): return absx.max() -def jvp_vmap(fun: Callable, argnums: int | Sequence[int] = 0): +def jvp_vmap(fun: Callable, argnums=0): """Vectorized (forward-mode) jacobian-vector product of ``fun``. This is by large adopted from the implementation of jacfwd in jax._src.api. Args: @@ -49,7 +49,7 @@ def jvpfun(args, tangents, **kwargs): return jvpfun -def vjp_vmap(fun: Callable, argnums: int | Sequence[int] = 0): +def vjp_vmap(fun: Callable, argnums=0): """Vectorized (reverse-mode) vector-jacobian product of ``fun``. This is by large adopted from the implementation of jacrev in jax._src.api. Args: @@ -74,8 +74,7 @@ def vjpfun(args, tangents, **kwargs): return vjpfun -def val_and_jacfwd(fun: Callable, argnums: int | Sequence[int] = 0, - has_aux: bool = False, holomorphic: bool = False) -> Callable: +def val_and_jacfwd(fun: Callable, argnums=0, has_aux: bool = False, holomorphic: bool = False) -> Callable: """Value and Jacobian of ``fun`` evaluated column-by-column using forward-mode AD. Apart from returning the function value, this is one-to-one adopted from jax._src.api. Args: @@ -122,8 +121,7 @@ def jacfun(*args, **kwargs): return jacfun -def val_and_jacrev(fun: Callable, argnums: int | Sequence[int] = 0, - has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False) -> Callable: +def val_and_jacrev(fun: Callable, argnums=0, has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False) -> Callable: """Value and Jacobian of ``fun`` evaluated row-by-row using reverse-mode AD. Apart from returning the function value, this is one-to-one adopted from jax._src.api. Args: