Skip to content

Commit

Permalink
Add extrapolation outside of mesh
Browse files Browse the repository at this point in the history
  • Loading branch information
Hadrien committed Jun 28, 2024
1 parent d1956d1 commit 608260a
Show file tree
Hide file tree
Showing 3 changed files with 273 additions and 8 deletions.
7 changes: 3 additions & 4 deletions folie/analysis/overdamped_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@ def free_energy_profile_1d(model, x):

def grad_V(x, _):
x = np.asarray(x).reshape(-1, 1)
diff_prime_val = model.diffusion.grad_x(x).ravel()
drift_val = model.drift(x).ravel()
diff_val = model.diffusion(x).ravel()
# print(np.min(diff_val), x[np.argmin(diff_val)])
diff_prime_val = model.diffusion.grad_x(x).flatten()
drift_val = model.drift(x).flatten()
diff_val = model.diffusion(x).flatten()
return (-drift_val + diff_prime_val) / diff_val

sol = solve_ivp(grad_V, [x.min() - 1e-10, x.max() + 1e-10], np.array([0.0]), t_eval=x) # Add some epsilon to range to ensure inclusion of x
Expand Down
8 changes: 4 additions & 4 deletions folie/domains/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from ..data import stats_from_input_data
from .._numpy import np
import skfem
from .element_finder import get_element_finder

# TODO: Add gestion of the periodicity

Expand Down Expand Up @@ -53,8 +54,8 @@ def Rd(cls, dim):
if dim < 1:
dim = 1
range = np.empty((2, dim))
range[0, :] = -np.inf
range[1, :] = np.inf
range[0, :] = -np.infty
range[1, :] = np.infty
return cls(range)

@classmethod
Expand Down Expand Up @@ -93,8 +94,7 @@ def localize_data(self, X, mapping=None):
"""
if mapping is None:
mapping = self.mesh.mapping()
cells = self.mesh.element_finder(mapping=mapping)(*(X.T)) # Change the element finder
# Find a way to exclude out of mesh elements, we can define an outside elements that is a constant
cells = get_element_finder(self.mesh, mapping=mapping)(*(X.T))
loc_x = mapping.invF(X.T[:, :, np.newaxis], tind=cells)
return cells, loc_x[..., 0].T

Expand Down
266 changes: 266 additions & 0 deletions folie/domains/element_finder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
"""
Contain utilities function to find element corresponding to positions.
Allow for extrapolation of FEM by using the closest element when outside of the mesh
Adapted from scikit-fem
"""

import numpy as np
import skfem


def get_element_finder(mesh, mapping=None):
"""
Get correct element finder depending of the type of the mesh
"""
if isinstance(mesh, skfem.MeshLine):
return element_finder_line(mesh, mapping)
elif isinstance(mesh, skfem.MeshTri):
return element_finder_tri(mesh, mapping)
elif isinstance(mesh, skfem.MeshQuad):
return element_finder_quad(mesh, mapping)
elif isinstance(mesh, skfem.MeshTet):
return element_finder_tet(mesh, mapping)
elif isinstance(mesh, skfem.MeshHex):
return element_finder_hex(mesh, mapping)
elif isinstance(mesh, skfem.MeshWedge1):
return element_finder_wedge(mesh, mapping)
else:
return mesh.element_finder(mapping)


def points_to_segment_dist(points, seg_start, seg_end):
"""Calculate the distance between points and segments."""
seg_vectors = seg_end - seg_start
points_vectors = points[:, :, None] - seg_start[:, None, :]

seg_length_sq = np.sum(seg_vectors**2, axis=0)
t = np.einsum("dij,dj->ij", points_vectors, seg_vectors) / seg_length_sq[None, :]
t = np.clip(t, 0, 1)
projections = seg_start[:, None, :] + t[None, :, :] * seg_vectors[:, None, :]
distances_sq = np.sum((points[:, :, None] - projections) ** 2, axis=0)
return distances_sq


def dist_tri_2d(x, y, tri):
"""
Compute distance to edge
tri is of shape (dim, 3, ntri)
Shoud return array of shape (len(x), ntri)
"""
# Juste pour tester on retourne la distance au centre
points = np.array([x, y])
distances = np.empty((x.shape[0], *tri.shape[1:]))
for i in range(3):
distances[:, i, :] = points_to_segment_dist(points, tri[:, i, :], tri[:, (i + 1) % 3, :])
return np.min(distances, axis=1)

# Using distances to center of triangle
center_tri = np.mean(tri, axis=1)
# print(center_tri.shape)
dists = (x[:, None] - center_tri[0:1, :]) ** 2 + (y[:, None] - center_tri[1:2, :]) ** 2
return dists


def points_to_facets_dist(points, triangles):
"""Calculate the distance between points and triangles in 3D in a vectorized manner."""

v0 = triangles[:, 2] - triangles[:, 0]
v1 = triangles[:, 1] - triangles[:, 0]
normals = np.cross(v0, v1)
normals = normals / np.linalg.norm(normals, axis=1)[:, None]

points_vectors = points[:, None, :] - triangles[:, None, 0, :]
dist_to_planes = np.einsum("ijk,ik->ij", points_vectors, normals)

projections = points[:, None, :] - dist_to_planes[:, :, None] * normals[None, :, :]

v2_proj = projections - triangles[:, None, 0, :]

dot00 = np.einsum("ij,ij->i", v0, v0)
dot01 = np.einsum("ij,ij->i", v0, v1)
dot11 = np.einsum("ij,ij->i", v1, v1)
dot02 = np.einsum("ijk,ik->ij", v2_proj, v0)
dot12 = np.einsum("ijk,ik->ij", v2_proj, v1)

inv_denom = 1 / (dot00 * dot11 - dot01 * dot01)
u = (dot11 * dot02 - dot01 * dot12) * inv_denom[:, None]
v = (dot00 * dot12 - dot01 * dot02) * inv_denom[:, None]

inside = (u >= 0) & (v >= 0) & (u + v <= 1)
dist_to_planes_abs = np.abs(dist_to_planes)

distances = np.full((points.shape[0], triangles.shape[0]), np.inf)

distances[inside] = dist_to_planes_abs[inside]

for i in range(3):
seg_start = triangles[:, i]
seg_end = triangles[:, (i + 1) % 3]
dist_to_segments = points_to_segment_dist(points.T, seg_start.T, seg_end.T)
distances = np.minimum(distances, dist_to_segments)

min_distances = np.min(distances, axis=1)
return min_distances


def dist_tet_3d(x, y, z, tet):
"""
Compute distance to edge
tri is of shape (dim, 4, ntet)
Shoud return array of shape (len(x), ntet)
"""
# Juste pour tester on retourne la distance au centre
points = np.array([x, y, z])
distances = np.empty((x.shape[0], *tet.shape[1:]))
facets = [(0, 1, 2), (0, 1, 3), (0, 2, 3), (1, 2, 3)]
for i in range(4):
distances[:, i, :] = points_to_facets_dist(points.T, tet[:, facets[i], :].T)
return np.min(distances, axis=1) # Distances to tetrahedral is shorst distance to facet

# Using distances to center of tetrahedrals
center_tri = np.mean(tet, axis=1)
# print(center_tri.shape)
dists = (x[:, None] - center_tri[0:1, :]) ** 2 + (y[:, None] - center_tri[1:2, :]) ** 2 + (z[:, None] - center_tri[2:3, :]) ** 2
return dists


def element_finder_line(mesh, mapping=None):

ix = np.argsort(mesh.p[0])
maxt = mesh.t[np.argmax(mesh.p[0, mesh.t], 0), np.arange(mesh.t.shape[1])]

def finder(x):
xin = x.copy() # bring endpoint inside for np.digitize
xin[x == mesh.p[0, ix[-1]]] = mesh.p[0, ix[-2:]].mean()
bins = np.digitize(xin, mesh.p[0, ix])
bins[bins == 0] = 1
bins[bins == len(mesh.p[0, ix])] = len(mesh.p[0, ix]) - 1
elems = np.nonzero(ix[bins][:, None] == maxt)[1]
if len(elems) < len(x):
raise ValueError("Point is outside of the mesh.")
return elems

return finder


def element_finder_tri(mesh, mapping=None):

if mapping is None:
mapping = mesh._mapping()

if not hasattr(mesh, "_cached_tree"):
from scipy.spatial import cKDTree

mesh._cached_tree = cKDTree(np.mean(mesh.p[:, mesh.t], axis=1).T)
tree = mesh._cached_tree
nelems = mesh.t.shape[1]

def finder(x, y, _search_all=False):

if not _search_all:
ix = tree.query(np.array([x, y]).T, min(5, nelems))[1].flatten()
_, ix_ind = np.unique(ix, return_index=True)
ix = ix[np.sort(ix_ind)]
else:
ix = np.arange(nelems, dtype=np.int64)

X = mapping.invF(np.array([x, y])[:, None], ix)
eps = np.finfo(X.dtype).eps
inside = (X[0] >= -eps) * (X[1] >= -eps) * (1 - X[0] - X[1] >= -eps)
out_elems = ix[inside.argmax(axis=0)].flatten()

if not inside.max(axis=0).all():
if _search_all:
outside = np.nonzero(~inside.max(axis=0))[0]
x_out, y_out = x[outside], y[outside]
# Not necessary to loop for all elements then
ix_out = tree.query(np.array([x_out, y_out]).T, min(10, nelems))[1].flatten()
_, ix_ind = np.unique(ix_out, return_index=True)
ix_out = ix_out[np.sort(ix_ind)]

dists = dist_tri_2d(x_out, y_out, mesh.p[:, mesh.t[:, ix_out]])
out_elems[outside] = ix_out[np.argmin(dists, axis=1)]
# raise ValueError("Point is outside of the mesh.")
else:
return finder(x, y, _search_all=True)

return out_elems

return finder


def element_finder_quad(mesh, mapping=None):
"""Transform to :class:`skfem.MeshTri` and return its finder."""
tri_finder = mesh.to_meshtri().element_finder()

def finder(*args):
return tri_finder(*args) % mesh.t.shape[1]

return finder


def element_finder_tet(mesh, mapping=None):

if mapping is None:
mapping = mesh._mapping()

if not hasattr(mesh, "_cached_tree"):
from scipy.spatial import cKDTree

mesh._cached_tree = cKDTree(np.mean(mesh.p[:, mesh.t], axis=1).T)

tree = mesh._cached_tree
nelems = mesh.t.shape[1]

def finder(x, y, z, _search_all=False):

if not _search_all:
ix = tree.query(np.array([x, y, z]).T, min(10, nelems))[1].flatten()
_, ix_ind = np.unique(ix, return_index=True)
ix = ix[np.sort(ix_ind)]
else:
ix = np.arange(nelems, dtype=np.int64)

X = mapping.invF(np.array([x, y, z])[:, None], ix)
eps = np.finfo(X.dtype).eps
inside = (X[0] >= -eps) * (X[1] >= -eps) * (X[2] >= -eps) * (1 - X[0] - X[1] - X[2] >= -eps)
out_elems = ix[inside.argmax(axis=0)].flatten()
if not inside.max(axis=0).all():
if _search_all:
outside = np.nonzero(~inside.max(axis=0))[0]
x_out, y_out, z_out = x[outside], y[outside], z[outside]
# Not necessary to loop for all elements then
ix_out = tree.query(np.array([x_out, y_out]).T, min(20, nelems))[1].flatten()
_, ix_ind = np.unique(ix_out, return_index=True)
ix_out = ix_out[np.sort(ix_ind)]

dists = dist_tet_3d(x_out, y_out, z_out, mesh.p[:, mesh.t[:, ix_out]])
out_elems[outside] = ix_out[np.argmin(dists, axis=1)]
# raise ValueError("Point is outside of the mesh.")
else:
return finder(x, y, z, _search_all=True)

return out_elems

return finder


def element_finder_wedge(mesh, mapping=None):
"""Transform to :class:`skfem.MeshTet` and return its finder."""
tet_finder = element_finder_tet(mesh.to_meshtet())

def finder(*args):
return tet_finder(*args) % mesh.t.shape[1]

return finder


def element_finder_hex(mesh, mapping=None):
"""Transform to :class:`skfem.MeshTet` and return its finder."""
tet_finder = element_finder_tet(mesh.to_meshtet())

def finder(*args):
return tet_finder(*args) % mesh.t.shape[1]

return finder

0 comments on commit 608260a

Please sign in to comment.