Skip to content

Commit

Permalink
fixed index checks for cylinders
Browse files Browse the repository at this point in the history
Signed-off-by: Nick Papior <[email protected]>
  • Loading branch information
zerothi committed Aug 11, 2023
1 parent 5b4c872 commit fe0bdf4
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 2 deletions.
60 changes: 60 additions & 0 deletions src/sisl/_indices.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,66 @@ cdef void _indices_sorted_arrays(const Py_ssize_t n_element, const int[::1] elem
for j in range(j, n_test_element):
idx[j] = -1

@cython.boundscheck(False)
@cython.wraparound(False)
def indices_in_cylinder(np.ndarray[np.float64_t, ndim=2, mode='c'] dxyz, const double R, const double h):
""" Indices for all coordinates that are within a cylinde radius `R` and height `h`
Parameters
----------
dxyz : ndarray(np.float64)
coordinates centered around the cylinder
R : float
radius of cylinder to check
h : float
height of cylinder to check
Returns
-------
index : np.ndarray(np.int32)
indices of all dxyz coordinates that are within the cylinder
"""
cdef double[:, ::1] dXYZ = dxyz
cdef Py_ssize_t n = dXYZ.shape[0]
cdef np.ndarray[np.int32_t, ndim=1] idx = np.empty([n], dtype=np.int32)
cdef int[::1] IDX = idx

n = _indices_in_cylinder(dXYZ, R, h, IDX)

if n == 0:
return np.empty([0], dtype=np.int32)
return idx[:n].copy()


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.initializedcheck(False)
cdef Py_ssize_t _indices_in_cylinder(const double[:, ::1] dxyz, const double R, const double h, int[::1] idx) nogil:
cdef Py_ssize_t N = dxyz.shape[0]
cdef Py_ssize_t xyz = dxyz.shape[1]
cdef double R2 = R * R
cdef double L2
cdef Py_ssize_t i, j, n
cdef int skip

# Reset number of elements
n = 0

for i in range(N):
skip = 0
for j in range(xyz-1):
skip |= dxyz[i, j] > R
if skip or dxyz[i, -1] > h: continue

L2 = 0.
for j in range(xyz-1):
L2 += dxyz[i, j] * dxyz[i, j]
if L2 > R2: continue
idx[n] = i
n += 1

return n


@cython.boundscheck(False)
@cython.wraparound(False)
Expand Down
5 changes: 3 additions & 2 deletions src/sisl/shape/_cylinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np

import sisl._array as _a
from sisl._indices import indices_in_sphere
from sisl._indices import indices_in_cylinder
from sisl._internal import set_module
from sisl.messages import warn
from sisl.utils.mathematics import expand, fnorm, fnorm2, orthogonalize
Expand Down Expand Up @@ -152,7 +152,7 @@ def within_index(self, other, tol=1.e-8):
# Get indices where we should do the more
# expensive exact check of being inside shape
# I.e. this reduces the search space to the box
return indices_in_sphere(tmp, 1. + tol)
return indices_in_cylinder(tmp, 1. + tol, 1. + tol)

@property
def height(self):
Expand All @@ -177,6 +177,7 @@ def height_vector(self):
def toSphere(self):
""" Convert to a sphere """
from .ellipsoid import Sphere

# figure out the distance from the center to the edge (along longest radius)
h = self.height / 2
r = self.radius.max()
Expand Down
8 changes: 8 additions & 0 deletions src/sisl/shape/tests/test_cylinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ def test_create_ellipticalcylinder():
str(el)


def test_ellipticalcylinder_within():
el = EllipticalCylinder(1., 1.)
# center of cylinder
assert el.within_index([0, 0, 0])[0] == 0
# should not be in a circle
assert el.within_index([0.2, 0.2, 0.9])[0] == 0


def test_tosphere():
el = EllipticalCylinder([1., 1.], 1.)
el.to.Sphere()
Expand Down

0 comments on commit fe0bdf4

Please sign in to comment.