Skip to content

Commit

Permalink
JAX version of autodiff (#1092)
Browse files Browse the repository at this point in the history
  • Loading branch information
kinnala authored Jan 29, 2024
1 parent d098c6b commit 580df1a
Show file tree
Hide file tree
Showing 5 changed files with 284 additions and 72 deletions.
4 changes: 2 additions & 2 deletions docs/examples/ex45.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
from skfem.experimental.autodiff import NonlinearForm
from skfem.experimental.autodiff.helpers import grad, dot
import numpy as np
import autograd.numpy as anp
import jax.numpy as jnp

m = MeshTri().refined(5)


@NonlinearForm(hessian=True)
def energy(u, _):
    # Minimal-surface energy density: sqrt(1 + |grad u|^2), written with
    # jax.numpy so that NonlinearForm can differentiate it.
    # (The duplicate autograd-based return line was diff residue; `anp`
    # is no longer imported.)
    return jnp.sqrt(1. + dot(grad(u), grad(u)))


basis = Basis(m, ElementTriP1())
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pyamg
mypy
flake8
sphinx~=6.1
autograd
jax
jaxlib
pep517
shapely
132 changes: 113 additions & 19 deletions skfem/experimental/autodiff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,89 @@
from skfem.assembly.form.coo_data import COOData
from numpy import ndarray
import numpy as np
from jax import jvp, linearize, config
from jax.tree_util import register_pytree_node
import jax.numpy as jnp

from autograd import make_jvp

config.update("jax_enable_x64", True)


class JaxDiscreteField:
    """Lightweight pytree-friendly container for a field and its derivatives.

    Mirrors the slots of ``DiscreteField`` (value, grad, div, curl, hess,
    grad3..grad6) so that interpolated fields can flow through JAX
    transformations.  Arithmetic operators unwrap to the plain ``value``
    array, so expressions inside forms behave like array expressions.
    """

    def __init__(self,
                 value,
                 grad=None,
                 div=None,
                 curl=None,
                 hess=None,
                 grad3=None,
                 grad4=None,
                 grad5=None,
                 grad6=None):
        # store all components in the same order as DiscreteField
        (self.value, self.grad, self.div, self.curl, self.hess,
         self.grad3, self.grad4, self.grad5, self.grad6) = (
            value, grad, div, curl, hess, grad3, grad4, grad5, grad6)

    @staticmethod
    def _strip(other):
        # unwrap a JaxDiscreteField operand to its raw value array
        return other.value if isinstance(other, JaxDiscreteField) else other

    def __add__(self, other):
        return self.value + JaxDiscreteField._strip(other)

    def __sub__(self, other):
        return self.value - JaxDiscreteField._strip(other)

    def __mul__(self, other):
        return self.value * JaxDiscreteField._strip(other)

    def __rmul__(self, other):
        # multiplication is written value-first, matching __mul__
        return self.value * JaxDiscreteField._strip(other)

    def __pow__(self, ix):
        return self.value ** ix

    def __array__(self):
        # allow np.asarray(field) to see the underlying value
        return self.value

    def __getitem__(self, index):
        return self.value[index]

    @property
    def shape(self):
        # shape of the value array only (not the derivatives)
        return self.value.shape

    @property
    def astuple(self):
        # components in constructor order; used for pytree (un)flattening
        return (
            self.value,
            self.grad,
            self.div,
            self.curl,
            self.hess,
            self.grad3,
            self.grad4,
            self.grad5,
            self.grad6,
        )


def _flatten_field(field):
    # children are the nine stored components; no static aux data
    return field.astuple, None


def _unflatten_field(aux, children):
    # rebuild in constructor order; aux is unused (always None)
    return JaxDiscreteField(*children)


# Register JaxDiscreteField as a pytree so jvp/linearize can traverse it.
register_pytree_node(JaxDiscreteField, _flatten_field, _unflatten_field)


class NonlinearForm(Form):
Expand All @@ -19,21 +100,31 @@ def assemble(self, basis, x=None, **kwargs):
Optional point at which the form is linearized, default is zero.
"""
# interpolate and cast to tuple
# make x compatible with u in forms
if x is None:
x = basis.zeros()
if isinstance(x, ndarray):
x = basis.interpolate(x)
if isinstance(x, tuple):
x = tuple(c.astuple for c in x)
x = tuple(JaxDiscreteField(*c.astuple) for c in x)
else:
x = (x.astuple,)
x = (JaxDiscreteField(*x.astuple),)

nt = basis.nelems
dx = basis.dx

defaults = basis.default_parameters()
# turn defaults into JaxDiscreteField to avoid np.ndarray
# to jnp.ndarray promotion issues
w = FormExtraParams({
**basis.default_parameters(),
**{
k: JaxDiscreteField(*tuple(
jnp.asarray(x)
if x is not None else None
for x in defaults[k].astuple
))
for k in defaults
},
**self._normalize_asm_kwargs(kwargs, basis),
})

Expand All @@ -47,26 +138,29 @@ def assemble(self, basis, x=None, **kwargs):
data1 = np.zeros(sz1, dtype=self.dtype)
rows1 = np.zeros(sz1, dtype=np.int64)

def _make_jacobian(V):
if 'hessian' in self.params:
F = make_jvp(lambda U: self.form(*U, w))
return make_jvp(lambda W: F(W)(V)[1])(x)
return make_jvp(lambda U: self.form(*U, *V, w))(x)

# # JAX version
# # autograd version
# def _make_jacobian(V):
# if 'hessian' in self.params:
# return linearize(
# lambda W: jvp(lambda U: self.form(*U, w), (W,), (V,))[1],
# x
# )
# return linearize(lambda U: self.form(*U, *V, w), x)
# F = make_jvp(lambda U: self.form(*U, w))
# return make_jvp(lambda W: F(W)(V)[1])(x)
# return make_jvp(lambda U: self.form(*U, *V, w))(x)

# JAX version
def _make_jacobian(V):
if 'hessian' in self.params:
return linearize(
lambda W: jvp(lambda U: self.form(*U, w), (W,), (V,))[1],
x
)
return linearize(lambda U: self.form(*U, *V, w), x)

# loop over the indices of local stiffness matrix
for i in range(basis.Nbfun):
DF = _make_jacobian(tuple(c.astuple for c in basis.basis[i]))
y, DF = _make_jacobian(tuple(JaxDiscreteField(*c.astuple)
for c in basis.basis[i]))
for j in range(basis.Nbfun):
y, DFU = DF(tuple(c.astuple for c in basis.basis[j]))
DFU = DF(tuple(JaxDiscreteField(*c.astuple)
for c in basis.basis[j]))
# Jacobian
ixs = slice(nt * (basis.Nbfun * j + i),
nt * (basis.Nbfun * j + i + 1))
Expand Down
94 changes: 44 additions & 50 deletions skfem/experimental/autodiff/helpers.py
Original file line number Diff line number Diff line change
@@ -1,93 +1,87 @@
import autograd.numpy as np
from autograd.builtins import isinstance
import jax.numpy as jnp

from skfem import DiscreteField
from . import JaxDiscreteField


def dot(u, v):
    """Dot product over the leading (vector component) axis.

    Accepts raw arrays or ``JaxDiscreteField`` wrappers; trailing axes
    (quadrature points, elements) are contracted elementwise.
    (Dead autograd-era lines after the first return were diff residue.)
    """
    if isinstance(u, JaxDiscreteField):
        u = u.value
    if isinstance(v, JaxDiscreteField):
        v = v.value
    return jnp.einsum('i...,i...', u, v)


def ddot(u, v):
    """Double-dot (Frobenius) product over the two leading matrix axes.

    Accepts raw arrays or ``JaxDiscreteField`` wrappers.
    (Dead autograd-era lines after the first return were diff residue.)
    """
    if isinstance(u, JaxDiscreteField):
        u = u.value
    if isinstance(v, JaxDiscreteField):
        v = v.value
    return jnp.einsum('ij...,ij...', u, v)


def dddot(u, v):
    """Triple contraction over the three leading tensor axes.

    Accepts raw arrays or ``JaxDiscreteField`` wrappers.
    (Dead autograd-era lines after the first return were diff residue.)
    """
    if isinstance(u, JaxDiscreteField):
        u = u.value
    if isinstance(v, JaxDiscreteField):
        v = v.value
    return jnp.einsum('ijk...,ijk...', u, v)


def grad(u):
    """Return the gradient component of a field.

    Both ``DiscreteField`` and ``JaxDiscreteField`` expose ``.grad``, so
    the old isinstance/tuple-index branches (diff residue) are gone.
    """
    return u.grad


def sym_grad(u):
    """Return the symmetrized gradient 0.5 * (grad u + (grad u)^T).

    Works for both ``DiscreteField`` and ``JaxDiscreteField`` since both
    expose ``.grad``; the old tuple-index branch was diff residue.
    """
    return .5 * (u.grad + transpose(u.grad))


def div(u):
    """Return the divergence of a field.

    For a vector field, ``u.grad`` has 4 axes (component, component,
    quadrature points, elements) and the divergence is the trace of the
    gradient; otherwise fall back to the precomputed ``u.div``.
    (The old tuple-index variant above was diff residue.)
    """
    if len(u.grad.shape) == 4:
        return jnp.einsum('ii...', u.grad)
    return u.div


def dd(u):
    """Return the Hessian component of a field.

    Both ``DiscreteField`` and ``JaxDiscreteField`` expose ``.hess``;
    the old isinstance/tuple-index branches were diff residue.
    """
    return u.hess


def transpose(T):
    """Swap the two leading (matrix) axes of T.

    Accepts a raw array or a ``JaxDiscreteField`` wrapper.
    (Dead autograd-era lines after the first return were diff residue.)
    """
    if isinstance(T, JaxDiscreteField):
        T = T.value
    return jnp.einsum('ij...->ji...', T)


def mul(A, B):
    """Matrix-matrix or matrix-vector product over the leading axes.

    If A and B have the same rank, contract as matrix @ matrix;
    otherwise treat B as a vector field.  Accepts raw arrays or
    ``JaxDiscreteField`` wrappers.
    (Dead autograd-era lines after the first return were diff residue.)
    """
    if isinstance(A, JaxDiscreteField):
        A = A.value
    if isinstance(B, JaxDiscreteField):
        B = B.value
    if len(A.shape) == len(B.shape):
        return jnp.einsum('ij...,jk...->ik...', A, B)
    return jnp.einsum('ij...,j...->i...', A, B)


def trace(T):
    """Trace over the two leading (matrix) axes of T.

    Accepts a raw array or a ``JaxDiscreteField`` wrapper.
    (Dead autograd-era lines after the first return were diff residue.)
    """
    if isinstance(T, JaxDiscreteField):
        T = T.value
    return jnp.einsum('ii...', T)


def eye(w, size):
    """Return a (size x size) identity-patterned matrix with w on the
    diagonal and 0 * w off-diagonal (so off-diagonal entries broadcast
    to w's shape).  (The duplicate np.array line was diff residue.)
    """
    return jnp.array([[w if i == j else 0. * w
                       for i in range(size)]
                      for j in range(size)])


def det(A):
    """Return the determinant of a 2x2 or 3x3 matrix (leading axes; any
    trailing quadrature/element axes broadcast elementwise).

    Bug fix: the 3x3 cofactor expansion had a double negative
    (``... * A[2, 2] -`` followed by ``- A[1, 2] * ...``), which turned
    the subtraction inside the second cofactor into an addition and made
    the 3x3 determinant wrong.  Dead pre-rewrite lines (diff residue)
    are also removed.
    """
    detA = jnp.zeros_like(A[0, 0])
    if A.shape[0] == 3:
        detA = (A[0, 0] * (A[1, 1] * A[2, 2]
                           - A[1, 2] * A[2, 1])
                - A[0, 1] * (A[1, 0] * A[2, 2]
                             - A[1, 2] * A[2, 0])
                + A[0, 2] * (A[1, 0] * A[2, 1]
                             - A[1, 1] * A[2, 0]))
    elif A.shape[0] == 2:
        detA = A[0, 0] * A[1, 1] - A[1, 0] * A[0, 1]
    return detA
Loading

0 comments on commit 580df1a

Please sign in to comment.