Delete jit (#61)
* Removed jax.jit from src

* Removed jit

* Add jit back to toplevel api

* Fixed test
pnkraemer authored Feb 20, 2023
1 parent d6735f3 commit 3fcc723
Showing 7 changed files with 6 additions and 28 deletions.
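The net effect of the commit: `jax.jit` disappears from the library internals and survives only on the top-level `differentiate_along_axis`. Below is a minimal sketch of that pattern, assuming nothing beyond `jax` itself; `internal_helper` is a hypothetical stand-in, not a probfindiff function. Compiling once at the boundary (or in user code) covers the whole call tree, so inner decorators buy little.

    import jax
    import jax.numpy as jnp

    def internal_helper(x):
        # Library-internal function: no @jax.jit decorator anymore.
        return jnp.sum(x ** 2)

    # Callers who want compilation apply jit themselves, at the boundary.
    fast = jax.jit(internal_helper)
    print(fast(jnp.arange(3.0)))  # 5.0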
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
@@ -62,7 +62,7 @@ jobs:
           pip install .[ci]
       - name: Run tests with pytest through tox
         run: |
-          tox -e pytest
+          tox -e py3
   byexample:
     runs-on: ubuntu-latest
     steps:
2 changes: 1 addition & 1 deletion src/probfindiff/_toplevel_api.py
@@ -53,6 +53,7 @@ def differentiate(
     return dfx, unc_base
 
 
+@functools.partial(jax.jit, static_argnames=["axis"])
 def differentiate_along_axis(
     fx: ArrayLike, *, axis: int, scheme: FiniteDifferenceScheme
 ) -> Array:
@@ -209,7 +210,6 @@ def central(
     return scheme, grid
 
 
-@functools.partial(jax.jit, static_argnames=("order_derivative", "kernel"))
 def from_grid(
     *,
     xs: ArrayLike,
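The decorator added back here marks `axis` as static. A hedged sketch of why, using a stand-in function rather than the real `differentiate_along_axis`: the traced computation's structure depends on the axis value, so jit must specialize (and recompile) per distinct axis.

    import functools
    import jax
    import jax.numpy as jnp

    @functools.partial(jax.jit, static_argnames=["axis"])
    def diff_along_axis(fx, *, axis):
        # `axis` is static: each distinct value triggers its own compilation.
        return jnp.diff(fx, axis=axis)

    fx = jnp.arange(12.0).reshape(3, 4)
    print(diff_along_axis(fx, axis=0).shape)  # (2, 4)
    print(diff_along_axis(fx, axis=1).shape)  # (3, 3)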
5 changes: 0 additions & 5 deletions src/probfindiff/collocation.py
@@ -1,15 +1,12 @@
 """Finite differences and collocation with Gaussian processes."""
 
-import functools
 from typing import Tuple
 
-import jax
 import jax.numpy as jnp
 
 from probfindiff.typing import Array, ArrayLike, KernelFunctionLike
 
 
-@functools.partial(jax.jit, static_argnames=("ks",))
 def non_uniform_nd(
     *,
     x: ArrayLike,
@@ -75,7 +72,6 @@ def prepare_gram(
     return K, LK, LLK
 
 
-@jax.jit
 def unsymmetric(
     *,
     K: ArrayLike,
@@ -125,7 +121,6 @@ def _transpose(LK0: ArrayLike) -> Array:
     return LKt
 
 
-@jax.jit
 def symmetric(
     *, K: ArrayLike, LK1: ArrayLike, LLK: ArrayLike, noise_variance: float
 ) -> Tuple[Array, Array]:
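Dropping `@jax.jit` from `non_uniform_nd`, `unsymmetric`, and `symmetric` is safe because jit composes: a single jit at the call boundary compiles everything it calls. A minimal sketch under that assumption, with stand-in functions rather than the real collocation routines:

    import jax
    import jax.numpy as jnp

    def gram(xs):
        # Internal helper, now un-jitted.
        return jnp.exp(-((xs[:, None] - xs[None, :]) ** 2))

    @jax.jit
    def weights(xs):
        # One jit at the boundary traces and compiles gram() too.
        K = gram(xs)
        return jnp.linalg.solve(K + 1e-6 * jnp.eye(xs.shape[0]), jnp.ones(xs.shape[0]))

    print(weights(jnp.linspace(0.0, 1.0, 5)).shape)  # (5,)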
14 changes: 0 additions & 14 deletions src/probfindiff/stencil.py
@@ -1,17 +1,14 @@
 """Stencil functionality."""
 
 
-import functools
 from typing import Tuple, Union
 
-import jax
 import jax.numpy as jnp
 
 from probfindiff import defaults
 from probfindiff.typing import Array, ArrayLike
 
 
-@functools.partial(jax.jit, static_argnames=("shape_input", "shape_output"))
 def multivariate(
     *,
     xs_1d: ArrayLike,
@@ -62,7 +59,6 @@ def multivariate(
     return jnp.broadcast_to(coeffs, shape=shape_output + coeffs.shape)
 
 
-@functools.partial(jax.jit, static_argnames=("shape_input",))
 def _stencils_for_all_partial_derivatives(
     *, stencil_1d: ArrayLike, shape_input: Tuple[int]
 ) -> Array:
@@ -131,13 +127,6 @@ def _stencils_for_all_partial_derivatives(
     )
 
 
-@functools.partial(
-    jax.jit,
-    static_argnames=(
-        "dimension",
-        "i",
-    ),
-)
 def _stencil_for_ith_partial_derivative(
     *, stencil_1d_as_row_matrix: ArrayLike, i: int, dimension: int
 ) -> Array:
@@ -186,7 +175,6 @@ def _stencil_for_ith_partial_derivative(
     return jnp.pad(stencil_1d_as_row_matrix, pad_width=((i, dimension - i - 1), (0, 0)))
 
 
-@functools.partial(jax.jit, static_argnames=("order_derivative", "order_method"))
 def backward(
     *,
     dx: float,
@@ -199,7 +187,6 @@ def backward(
     return grid
 
 
-@functools.partial(jax.jit, static_argnames=("order_derivative", "order_method"))
 def forward(
     *,
     dx: float,
@@ -212,7 +199,6 @@ def forward(
     return grid
 
 
-@functools.partial(jax.jit, static_argnames=("order_derivative", "order_method"))
 def central(
     *,
     dx: float,
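The one body line visible for `_stencil_for_ith_partial_derivative` shows what it does even without its decorator: zero-pad a 1-d stencil row into position `i` of a `(dimension, n)` matrix. A small worked example of that `jnp.pad` call, with made-up inputs:

    import jax.numpy as jnp

    stencil_1d_as_row_matrix = jnp.array([[1.0, -2.0, 1.0]])  # shape (1, 3)
    i, dimension = 1, 3
    padded = jnp.pad(stencil_1d_as_row_matrix, pad_width=((i, dimension - i - 1), (0, 0)))
    print(padded)
    # [[ 0.  0.  0.]
    #  [ 1. -2.  1.]
    #  [ 0.  0.  0.]]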
4 changes: 2 additions & 2 deletions src/probfindiff/utils/autodiff.py
@@ -25,7 +25,7 @@ def derivative(fun: Callable[[Any], Any], **kwargs: Any) -> Callable[[Any], Any]:
     """
 
     grad = jax.grad(fun, **kwargs)
-    return jax.jit(lambda *args: grad(*args)[0])
+    return lambda *args: grad(*args)[0]
 
 
 def div(fun: Callable[[Any], Any], **kwargs: Any) -> Callable[[Any], Any]:
@@ -45,7 +45,7 @@ def div(fun: Callable[[Any], Any], **kwargs: Any) -> Callable[[Any], Any]:
     """
 
     jac = jax.jacrev(fun, **kwargs)
-    return jax.jit(lambda *args: jnp.trace(jac(*args)))
+    return lambda *args: jnp.trace(jac(*args))
 
 
 def laplace(fun: Callable[[Any], Any], **kwargs: Any) -> Callable[[Any], Any]:
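The closures returned by `derivative` and `div` lose their built-in jit but remain pure traceable functions, so a caller can recover compilation in one line. A sketch replicating the pattern from this diff:

    import jax
    import jax.numpy as jnp

    def derivative(fun, **kwargs):
        grad = jax.grad(fun, **kwargs)
        return lambda *args: grad(*args)[0]

    df = derivative(lambda x: jnp.sum(jnp.sin(x)))
    df_fast = jax.jit(df)  # jit re-applied at the call site, if wanted
    print(df_fast(jnp.array([0.0, 1.0])))  # 1.0 == cos(0.0)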
4 changes: 2 additions & 2 deletions src/probfindiff/utils/kernel.py
@@ -27,7 +27,7 @@ def differentiate(
     k_batch, _ = batch_gram(k)
     lk_batch, lk = batch_gram(L(k, argnums=0))
     llk_batch, _ = batch_gram(L(lk, argnums=1))
-    return jax.jit(k_batch), jax.jit(lk_batch), jax.jit(llk_batch)
+    return k_batch, lk_batch, llk_batch
 
 
 def batch_gram(k: KernelFunctionLike) -> Tuple[KernelFunctionLike, KernelFunctionLike]:
@@ -50,4 +50,4 @@ def batch_gram(k: KernelFunctionLike) -> Tuple[KernelFunctionLike, KernelFunctionLike]:
         Tuple :math:`(\tilde k, k)` of the batched kernel function and the original kernel function.
     """
     k_vmapped_x = jax.vmap(k, in_axes=(0, None), out_axes=-1)
-    return jax.jit(jax.vmap(k_vmapped_x, in_axes=(None, -1), out_axes=-1)), jax.jit(k)
+    return jax.vmap(k_vmapped_x, in_axes=(None, -1), out_axes=-1), k
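How the double `vmap` in `batch_gram` lifts a scalar kernel `k(x, y)` to a Gram-matrix function, per the axes in the diff: the inner map runs over rows of the first argument, the outer over columns of the second. Sketch with an RBF stand-in kernel:

    import jax
    import jax.numpy as jnp

    k = lambda x, y: jnp.exp(-jnp.dot(x - y, x - y) / 2.0)
    k_vmapped_x = jax.vmap(k, in_axes=(0, None), out_axes=-1)
    k_batch = jax.vmap(k_vmapped_x, in_axes=(None, -1), out_axes=-1)

    xs = jnp.linspace(0.0, 1.0, 4).reshape(4, 1)  # 4 points in 1-d
    print(k_batch(xs, xs.T).shape)  # (4, 4) Gram matrix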
3 changes: 0 additions & 3 deletions src/probfindiff/utils/kernel_zoo.py
@@ -1,13 +1,11 @@
 """Kernel zoo."""
 
 
-import jax
 import jax.numpy as jnp
 
 from probfindiff.typing import Array, ArrayLike
 
 
-@jax.jit
 def exponentiated_quadratic(
     x: ArrayLike, y: ArrayLike, input_scale: float = 1.0, output_scale: float = 1.0
 ) -> ArrayLike:
@@ -41,7 +39,6 @@ def exponentiated_quadratic(
     return output_scale * jnp.exp(-input_scale * jnp.dot(difference, difference) / 2.0)
 
 
-@jax.jit
 def polynomial(
     x: ArrayLike,
     y: ArrayLike,
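For reference, the exponentiated-quadratic kernel as it stands after this commit, reconstructed from the signature and return line shown above; the `difference = x - y` line is an assumption, since it falls outside the displayed hunk:

    import jax.numpy as jnp

    def exponentiated_quadratic(x, y, input_scale=1.0, output_scale=1.0):
        difference = x - y  # assumed; defined outside the visible hunk
        return output_scale * jnp.exp(-input_scale * jnp.dot(difference, difference) / 2.0)

    print(exponentiated_quadratic(jnp.array([0.0]), jnp.array([1.0])))  # ~0.6065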
