Skip to content

Commit

Permalink
making number of reflected points flexible
Browse files Browse the repository at this point in the history
  • Loading branch information
jgostick committed Sep 15, 2023
1 parent e296032 commit cdb122b
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 7 deletions.
10 changes: 8 additions & 2 deletions openpnm/_skgraph/generators/_voronoi_delaunay_dual.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,17 @@ def voronoi_delaunay_dual(points, shape, trim=True, reflect=True, relaxation=0,
vor : Voronoi object
The Voronoi tessellation object produced by ``scipy.spatial.Voronoi``
tri : Delaunay object
The Delaunay triangulation object produced ``scipy.spatial.Delaunay``
The Delaunay triangulation object produced by ``scipy.spatial.Delaunay``
"""
# Generate a set of base points if scalar was given
points = tools.parse_points(points=points, shape=shape, reflect=reflect)
points = tools.parse_points(
points=points,
shape=shape,
reflect=reflect,
f=0.2,
)

# Generate mask to remove any dims with all 0's
mask = ~np.all(points == 0, axis=0)

Expand Down
29 changes: 24 additions & 5 deletions openpnm/_skgraph/generators/tools/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def center_of_mass(simplices, points):
return CoM


def parse_points(shape, points, reflect=False):
def parse_points(shape, points, reflect=False, f=2):
r"""
Converts given points argument to consistent format
Expand All @@ -144,6 +144,7 @@ def parse_points(shape, points, reflect=False):
"""
# Deal with input arguments
shape = np.array(shape, dtype=int)
if isinstance(points, int):
points = generate_base_points(num_points=points,
domain_size=shape,
Expand All @@ -162,17 +163,17 @@ def parse_points(shape, points, reflect=False):
raise Exception('Some points lie outside the domain, '
+ 'cannot safely apply reflection')
if len(shape) == 3:
points = reflect_base_points(points=points, domain_size=shape)
points = reflect_base_points(points=points, domain_size=shape, f=f)
elif len(shape) == 2:
# Convert xyz to cylindrical, and back
R, Q, Z = tools.cart2cyl(*points.T)
R, Q, Z = reflect_base_points(np.vstack((R, Q, Z)), domain_size=shape)
R, Q, Z = reflect_base_points(np.vstack((R, Q, Z)), domain_size=shape, f=f)
# Convert back to cartesean coordinates
points = np.vstack(tools.cyl2cart(R, Q, Z)).T
elif len(shape) == 1:
# Convert to spherical coordinates
R, Q, P = tools.cart2sph(*points.T)
R, Q, P = reflect_base_points(np.vstack((R, Q, P)), domain_size=shape)
R, Q, P = reflect_base_points(np.vstack((R, Q, P)), domain_size=shape, f=f)
# Convert to back to cartesean coordinates
points = np.vstack(tools.sph2cart(R, Q, P)).T
return points
Expand Down Expand Up @@ -346,7 +347,7 @@ def template_cylinder_annulus(z, r_outer, r_inner=0):
return img


def reflect_base_points(points, domain_size):
def reflect_base_points(points, domain_size, f=2):
r"""
Relects a set of points about the faces of a given domain
Expand Down Expand Up @@ -382,6 +383,10 @@ def reflect_base_points(points, domain_size):
theta = np.hstack([theta, theta])
phi = np.hstack([phi, phi])
points = np.vstack((r, theta, phi))
# Trim excess points outside radius
hi = domain_size[0]*(1+f)
keep = (points[0, :] <= hi)
points = points[:, keep]
if len(domain_size) == 2:
r, theta, z = points
new_r = 2*domain_size[0] - r
Expand All @@ -393,6 +398,14 @@ def reflect_base_points(points, domain_size):
theta = np.hstack([theta, theta, theta])
z = np.hstack([z, -z, 2*domain_size[1]-z])
points = np.vstack((r, theta, z))
# Trim excess basepoints above and below cylinder
hi = domain_size[1]*(1+f)
lo = domain_size[1]*(-f)
keep = (points[2, :] <= hi)*(points[2, :] >= lo)
# Trim excess points outside radius
hi = domain_size[0]*(1+f)
keep *= (points[0, :] <= hi)
points = points[:, keep]
elif len(domain_size) == 3:
Nx, Ny, Nz = domain_size
# Reflect base points about all 6 faces
Expand All @@ -407,6 +420,12 @@ def reflect_base_points(points, domain_size):
points = np.vstack((points,
[1, 1, -1] * orig_pts + [0, 0, 2.0 * Nz]))
points = np.vstack((points, [1, 1, -1] * orig_pts))
# Trim excess basepoints
hi = domain_size*(1+f)
lo = domain_size*(-f)
keep = np.all(points <= hi, axis=1)
keep *= np.all(points >= lo, axis=1)
points = points[keep, :]
return points


Expand Down

0 comments on commit cdb122b

Please sign in to comment.