From 3fcc7235afeab6efcacab99a5d49df2db0504259 Mon Sep 17 00:00:00 2001
From: Nicholas Krämer
Date: Mon, 20 Feb 2023 15:08:02 +0100
Subject: [PATCH] Delete jit (#61)

* Removed jax.jit from src

* Removed jit

* Add jit back to toplevel api

* Fixed test
---
 .github/workflows/ci.yaml           |  2 +-
 src/probfindiff/_toplevel_api.py    |  2 +-
 src/probfindiff/collocation.py      |  5 -----
 src/probfindiff/stencil.py          | 14 --------------
 src/probfindiff/utils/autodiff.py   |  4 ++--
 src/probfindiff/utils/kernel.py     |  4 ++--
 src/probfindiff/utils/kernel_zoo.py |  3 ---
 7 files changed, 6 insertions(+), 28 deletions(-)

diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index c3658a2..4c1e871 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -62,7 +62,7 @@ jobs:
         pip install .[ci]
     - name: Run tests with pytest through tox
       run: |
-        tox -e pytest
+        tox -e py3
   byexample:
     runs-on: ubuntu-latest
     steps:
diff --git a/src/probfindiff/_toplevel_api.py b/src/probfindiff/_toplevel_api.py
index c495d1b..49bd00c 100644
--- a/src/probfindiff/_toplevel_api.py
+++ b/src/probfindiff/_toplevel_api.py
@@ -53,6 +53,7 @@ def differentiate(
     return dfx, unc_base


+@functools.partial(jax.jit, static_argnames=["axis"])
 def differentiate_along_axis(
     fx: ArrayLike, *, axis: int, scheme: FiniteDifferenceScheme
 ) -> Array:
@@ -209,7 +210,6 @@ def central(
     return scheme, grid


-@functools.partial(jax.jit, static_argnames=("order_derivative", "kernel"))
 def from_grid(
     *,
     xs: ArrayLike,
diff --git a/src/probfindiff/collocation.py b/src/probfindiff/collocation.py
index 84a80c6..9f28333 100644
--- a/src/probfindiff/collocation.py
+++ b/src/probfindiff/collocation.py
@@ -1,15 +1,12 @@
 """Finite differences and collocation with Gaussian processes."""

-import functools
 from typing import Tuple

-import jax
 import jax.numpy as jnp

 from probfindiff.typing import Array, ArrayLike, KernelFunctionLike


-@functools.partial(jax.jit, static_argnames=("ks",))
 def non_uniform_nd(
     *,
     x: ArrayLike,
@@ -75,7 +72,6 @@ def prepare_gram(
     return K, LK, LLK


-@jax.jit
 def unsymmetric(
     *,
     K: ArrayLike,
@@ -125,7 +121,6 @@ def _transpose(LK0: ArrayLike) -> Array:
     return LKt


-@jax.jit
 def symmetric(
     *, K: ArrayLike, LK1: ArrayLike, LLK: ArrayLike, noise_variance: float
 ) -> Tuple[Array, Array]:
diff --git a/src/probfindiff/stencil.py b/src/probfindiff/stencil.py
index 0c902f4..fb18b5f 100644
--- a/src/probfindiff/stencil.py
+++ b/src/probfindiff/stencil.py
@@ -1,17 +1,14 @@
 """Stencil functionality."""

-import functools
 from typing import Tuple, Union

-import jax
 import jax.numpy as jnp

 from probfindiff import defaults
 from probfindiff.typing import Array, ArrayLike


-@functools.partial(jax.jit, static_argnames=("shape_input", "shape_output"))
 def multivariate(
     *,
     xs_1d: ArrayLike,
@@ -62,7 +59,6 @@ def multivariate(
     return jnp.broadcast_to(coeffs, shape=shape_output + coeffs.shape)


-@functools.partial(jax.jit, static_argnames=("shape_input",))
 def _stencils_for_all_partial_derivatives(
     *, stencil_1d: ArrayLike, shape_input: Tuple[int]
 ) -> Array:
@@ -131,13 +127,6 @@ def _stencils_for_all_partial_derivatives(
     )


-@functools.partial(
-    jax.jit,
-    static_argnames=(
-        "dimension",
-        "i",
-    ),
-)
 def _stencil_for_ith_partial_derivative(
     *, stencil_1d_as_row_matrix: ArrayLike, i: int, dimension: int
 ) -> Array:
@@ -186,7 +175,6 @@ def _stencil_for_ith_partial_derivative(
     return jnp.pad(stencil_1d_as_row_matrix, pad_width=((i, dimension - i - 1), (0, 0)))


-@functools.partial(jax.jit, static_argnames=("order_derivative", "order_method"))
 def backward(
     *,
     dx: float,
@@ -199,7 +187,6 @@ def backward(
     return grid


-@functools.partial(jax.jit, static_argnames=("order_derivative", "order_method"))
 def forward(
     *,
     dx: float,
@@ -212,7 +199,6 @@ def forward(
     return grid


-@functools.partial(jax.jit, static_argnames=("order_derivative", "order_method"))
 def central(
     *,
     dx: float,
diff --git a/src/probfindiff/utils/autodiff.py b/src/probfindiff/utils/autodiff.py
index c4639d6..99eca16 100644
--- a/src/probfindiff/utils/autodiff.py
+++ b/src/probfindiff/utils/autodiff.py
@@ -25,7 +25,7 @@ def derivative(fun: Callable[[Any], Any], **kwargs: Any) -> Callable[[Any], Any]
     """
     grad = jax.grad(fun, **kwargs)
-    return jax.jit(lambda *args: grad(*args)[0])
+    return lambda *args: grad(*args)[0]


 def div(fun: Callable[[Any], Any], **kwargs: Any) -> Callable[[Any], Any]:
@@ -45,7 +45,7 @@ def div(fun: Callable[[Any], Any], **kwargs: Any) -> Callable[[Any], Any]:
     """
     jac = jax.jacrev(fun, **kwargs)
-    return jax.jit(lambda *args: jnp.trace(jac(*args)))
+    return lambda *args: jnp.trace(jac(*args))


 def laplace(fun: Callable[[Any], Any], **kwargs: Any) -> Callable[[Any], Any]:
diff --git a/src/probfindiff/utils/kernel.py b/src/probfindiff/utils/kernel.py
index 4389f25..f352200 100644
--- a/src/probfindiff/utils/kernel.py
+++ b/src/probfindiff/utils/kernel.py
@@ -27,7 +27,7 @@ def differentiate(
     k_batch, _ = batch_gram(k)
     lk_batch, lk = batch_gram(L(k, argnums=0))
     llk_batch, _ = batch_gram(L(lk, argnums=1))
-    return jax.jit(k_batch), jax.jit(lk_batch), jax.jit(llk_batch)
+    return k_batch, lk_batch, llk_batch


 def batch_gram(k: KernelFunctionLike) -> Tuple[KernelFunctionLike, KernelFunctionLike]:
@@ -50,4 +50,4 @@ def batch_gram(k: KernelFunctionLike) -> Tuple[KernelFunctionLike, KernelFunctio
     Tuple :math:`(\tilde k, k)` of the batched kernel function and the original kernel function.
     """
     k_vmapped_x = jax.vmap(k, in_axes=(0, None), out_axes=-1)
-    return jax.jit(jax.vmap(k_vmapped_x, in_axes=(None, -1), out_axes=-1)), jax.jit(k)
+    return jax.vmap(k_vmapped_x, in_axes=(None, -1), out_axes=-1), k
diff --git a/src/probfindiff/utils/kernel_zoo.py b/src/probfindiff/utils/kernel_zoo.py
index e23760c..4e5d089 100644
--- a/src/probfindiff/utils/kernel_zoo.py
+++ b/src/probfindiff/utils/kernel_zoo.py
@@ -1,13 +1,11 @@
 """Kernel zoo."""

-import jax
 import jax.numpy as jnp

 from probfindiff.typing import Array, ArrayLike


-@jax.jit
 def exponentiated_quadratic(
     x: ArrayLike, y: ArrayLike, input_scale: float = 1.0, output_scale: float = 1.0
 ) -> ArrayLike:
@@ -41,7 +39,6 @@ def exponentiated_quadratic(
     return output_scale * jnp.exp(-input_scale * jnp.dot(difference, difference) / 2.0)


-@jax.jit
 def polynomial(
     x: ArrayLike,
     y: ArrayLike,