Skip to content

Commit

Permalink
remove ref to
Browse files Browse the repository at this point in the history
  • Loading branch information
gboehl committed Oct 23, 2023
1 parent fee6e92 commit 76db18f
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions grgrjax/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def amax(x, return_arg=False):
return absx.max()


def jvp_vmap(fun: Callable, argnums: int | Sequence[int] = 0):
def jvp_vmap(fun: Callable, argnums=0):
"""Vectorized (forward-mode) jacobian-vector product of ``fun``. This is by large adopted from the implementation of jacfwd in jax._src.api.
Args:
Expand All @@ -49,7 +49,7 @@ def jvpfun(args, tangents, **kwargs):
return jvpfun


def vjp_vmap(fun: Callable, argnums: int | Sequence[int] = 0):
def vjp_vmap(fun: Callable, argnums=0):
"""Vectorized (reverse-mode) vector-jacobian product of ``fun``. This is by large adopted from the implementation of jacrev in jax._src.api.
Args:
Expand All @@ -74,8 +74,7 @@ def vjpfun(args, tangents, **kwargs):
return vjpfun


def val_and_jacfwd(fun: Callable, argnums: int | Sequence[int] = 0,
has_aux: bool = False, holomorphic: bool = False) -> Callable:
def val_and_jacfwd(fun: Callable, argnums=0, has_aux: bool = False, holomorphic: bool = False) -> Callable:
"""Value and Jacobian of ``fun`` evaluated column-by-column using forward-mode AD. Apart from returning the function value, this is one-to-one adopted from jax._src.api.
Args:
Expand Down Expand Up @@ -122,8 +121,7 @@ def jacfun(*args, **kwargs):
return jacfun


def val_and_jacrev(fun: Callable, argnums: int | Sequence[int] = 0,
has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False) -> Callable:
def val_and_jacrev(fun: Callable, argnums=0, has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False) -> Callable:
"""Value and Jacobian of ``fun`` evaluated row-by-row using reverse-mode AD. Apart from returning the function value, this is one-to-one adopted from jax._src.api.
Args:
Expand Down

0 comments on commit 76db18f

Please sign in to comment.