From ee4acc32d3f49ff25eef05e48a196899fb7ae89c Mon Sep 17 00:00:00 2001 From: Tom Gustafsson Date: Sun, 14 Jan 2024 22:45:12 +0200 Subject: [PATCH] allow indexing and add nonlinear elasticity test --- skfem/experimental/autodiff/__init__.py | 3 ++ tests/test_autodiff.py | 37 +++++++++++++++++++++++-- 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/skfem/experimental/autodiff/__init__.py b/skfem/experimental/autodiff/__init__.py index 78874b53..4664f6b8 100644 --- a/skfem/experimental/autodiff/__init__.py +++ b/skfem/experimental/autodiff/__init__.py @@ -54,6 +54,9 @@ def __rmul__(self, other): def __array__(self): return self.value + def __getitem__(self, index): + return self.value[index] + @property def shape(self): return self.value.shape diff --git a/tests/test_autodiff.py b/tests/test_autodiff.py index eb98af7f..898b3dc9 100644 --- a/tests/test_autodiff.py +++ b/tests/test_autodiff.py @@ -6,9 +6,11 @@ from skfem.experimental.autodiff import NonlinearForm from skfem.experimental.autodiff.helpers import (grad, dot, ddot, mul, - div, sym_grad) + div, sym_grad, + transpose, + eye, trace) from skfem.assembly import Basis -from skfem.mesh import MeshTri +from skfem.mesh import MeshTri, MeshQuad from skfem.element import (ElementTriP1, ElementTriP2, ElementVector) from skfem.utils import solve, condense @@ -88,3 +90,34 @@ def navierstokes(u, p, v, q, w): (u, ubasis), (p, pbasis) = basis.split(x) assert_almost_equal(np.max(p), 5212.45466, decimal=5) + + +def test_nonlin_elast(): + + m = MeshQuad.init_tensor(np.linspace(0, 5, 20), + np.linspace(0, 0.5, 5)).to_meshtri(style='x') + e = ElementVector(ElementTriP1()) + basis = Basis(m, e) + x = basis.zeros() + + @NonlinearForm + def elast(u, v, w): + epsu = .5 * (grad(u) + transpose(grad(u)) + + mul(transpose(grad(u)), grad(u))) + epsv = .5 * (grad(v) + transpose(grad(v))) + sigu = 2 * 10 * epsu + 1. * eye(trace(epsu), 2) + return ddot(sigu, epsv) - w.t * 2e-2 * v[1] + + + for itr in range(50): + xp = x.copy() + x += solve(*condense(*elast.assemble(basis, + x=x, + t=np.minimum((itr + 1) / 5, 1)), + D=basis.get_dofs({'left'}).all())) + res = jnp.linalg.norm(x - xp) + print(res) + if res < 1e-8: + break + + assert_almost_equal(np.max(x), 2.83411524813795)