From 10fc8a682b30a5d0d766267706c7469974d0e78a Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Wed, 12 Feb 2025 17:06:43 -0500 Subject: [PATCH] Refactoring `Basis`/`Transform`/`Grid` stuff (#1405) - `Transform(method="fft")` now works with grids that have an even number of toroidal nodes - Adds attributes to `Basis` classes for unique mode numbers, inverse indices etc, similar to grids - Removes `unique` option `Basis.evaluate` in favor of always using precomputed unique modes/nodes (#1530) - Adds attributes to grids denoting whether they work with ffts in poloidal/toroidal directions (#1243) --------- Co-authored-by: Yigit Gunsur Elmacioglu <102380275+YigitElma@users.noreply.github.com> Co-authored-by: Dario Panici <37969854+dpanici@users.noreply.github.com> Co-authored-by: Dario Panici --- CHANGELOG.md | 4 + desc/basis.py | 508 ++++++++++++++++++++++++------------- desc/equilibrium/coords.py | 8 +- desc/grid.py | 75 +++++- desc/plotting.py | 164 ++++++------ desc/profiles.py | 2 +- desc/transform.py | 256 ++++++------------- desc/vmec_utils.py | 6 +- tests/test_plotting.py | 15 +- tests/test_transform.py | 110 +++----- 10 files changed, 614 insertions(+), 534 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 30468b88fa..1352fec262 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,10 @@ for compatibility with other codes which expect such files from the Booz_Xform c - Renames compute quantity ``sqrt(g)_B`` to ``sqrt(g)_Boozer_DESC`` to more accurately reflect what the quantiy is (the jacobian from (rho,theta_B,zeta_B) to (rho,theta,zeta)), and adds a new function to compute ``sqrt(g)_Boozer`` which is the jacobian from (rho,theta_B,zeta_B) to (R,phi,Z). - Allows specification of Nyquist spectrum maximum modenumbers when using ``VMECIO.save`` to save a DESC .h5 file as a VMEC-format wout file +Speed Improvements + +- A number of minor improvements to basis function evaluation and spectral transforms to improve speed. These will also enable future improvements for larger gains. + Bug Fixes - Small bug fix to use the correct normalization length ``a`` in the BallooningStability objective. diff --git a/desc/basis.py b/desc/basis.py index aff3a4a3c0..eb882c39bd 100644 --- a/desc/basis.py +++ b/desc/basis.py @@ -8,6 +8,7 @@ import numpy as np from desc.backend import custom_jvp, fori_loop, jit, jnp, sign +from desc.grid import Grid, _Grid from desc.io import IOAble from desc.utils import check_nonnegint, check_posint, flatten_list @@ -45,6 +46,16 @@ def __init__(self): self._N = int(self._N) self._NFP = int(self._NFP) self._modes = self._modes.astype(int) + ( + self._unique_L_idx, + self._inverse_L_idx, + self._unique_M_idx, + self._inverse_M_idx, + self._unique_N_idx, + self._inverse_N_idx, + self._unique_LM_idx, + self._inverse_LM_idx, + ) = self._find_unique_inverse_modes() def _set_up(self): """Do things after loading or changing resolution.""" @@ -59,6 +70,16 @@ def _set_up(self): self._N = int(self._N) self._NFP = int(self._NFP) self._modes = self._modes.astype(int) + ( + self._unique_L_idx, + self._inverse_L_idx, + self._unique_M_idx, + self._inverse_M_idx, + self._unique_N_idx, + self._inverse_N_idx, + self._unique_LM_idx, + self._inverse_LM_idx, + ) = self._find_unique_inverse_modes() def _enforce_symmetry(self): """Enforce stellarator symmetry.""" @@ -87,6 +108,32 @@ def _enforce_symmetry(self): elif self.sym is None: self._sym = False + def _find_unique_inverse_modes(self): + """Find unique values of modes and their indices.""" + __, unique_L_idx, inverse_L_idx = np.unique( + self.modes[:, 0], return_index=True, return_inverse=True + ) + __, unique_M_idx, inverse_M_idx = np.unique( + self.modes[:, 1], return_index=True, return_inverse=True + ) + __, unique_N_idx, inverse_N_idx = np.unique( + self.modes[:, 2], return_index=True, return_inverse=True + ) + __, unique_LM_idx, inverse_LM_idx = np.unique( + self.modes[:, :2], axis=0, return_index=True, return_inverse=True + ) + + return ( + unique_L_idx, + inverse_L_idx, + unique_M_idx, + inverse_M_idx, + unique_N_idx, + inverse_N_idx, + unique_LM_idx, + inverse_LM_idx, + ) + def _sort_modes(self): """Sorts modes for use with FFT.""" sort_idx = np.lexsort((self.modes[:, 1], self.modes[:, 0], self.modes[:, 2])) @@ -138,22 +185,17 @@ def _get_modes(self): """ndarray: Mode numbers for the basis.""" @abstractmethod - def evaluate( - self, nodes, derivatives=np.array([0, 0, 0]), modes=None, unique=False - ): + def evaluate(self, grid, derivatives=np.array([0, 0, 0]), modes=None): """Evaluate basis functions at specified nodes. Parameters ---------- - nodes : ndarray of float, size(num_nodes,3) - node coordinates, in (rho,theta,zeta) + grid : Grid or ndarray of float, size(num_nodes,3) + Node coordinates, in (rho,theta,zeta). derivatives : ndarray of int, shape(3,) order of derivatives to compute in (rho,theta,zeta) modes : ndarray of in, shape(num_modes,3), optional basis modes to evaluate (if None, full basis is used) - unique : bool, optional - whether to reduce workload by only calculating for unique values of nodes, - modes can be faster, but doesn't work with jit or autodiff Returns ------- @@ -207,6 +249,60 @@ def spectral_indexing(self): """str: Type of indexing used for the spectral basis.""" return self.__dict__.setdefault("_spectral_indexing", "linear") + @property + def fft_poloidal(self): + """bool: whether this basis is compatible with fft in the poloidal direction.""" + if not hasattr(self, "_fft_poloidal"): + self._fft_poloidal = False + return self._fft_poloidal + + @property + def fft_toroidal(self): + """bool: whether this basis is compatible with fft in the toroidal direction.""" + if not hasattr(self, "_fft_toroidal"): + self._fft_toroidal = False + return self._fft_toroidal + + @property + def unique_L_idx(self): + """ndarray: Indices of unique radial modes.""" + return self._unique_L_idx + + @property + def unique_M_idx(self): + """ndarray: Indices of unique poloidal modes.""" + return self._unique_M_idx + + @property + def unique_N_idx(self): + """ndarray: Indices of unique toroidal modes.""" + return self._unique_N_idx + + @property + def unique_LM_idx(self): + """ndarray: Indices of unique radial/poloidal mode pairs.""" + return self._unique_LM_idx + + @property + def inverse_L_idx(self): + """ndarray: Indices of unique_L_idx that recover the radial modes.""" + return self._inverse_L_idx + + @property + def inverse_M_idx(self): + """ndarray: Indices of unique_M_idx that recover the poloidal modes.""" + return self._inverse_M_idx + + @property + def inverse_N_idx(self): + """ndarray: Indices of unique_N_idx that recover the toroidal modes.""" + return self._inverse_N_idx + + @property + def inverse_LM_idx(self): + """ndarray: Indices of unique_LM_idx that recover the LM mode pairs.""" + return self._inverse_LM_idx + def __repr__(self): """Get the string form of the object.""" return ( @@ -234,6 +330,9 @@ class PowerSeries(_Basis): """ + _fft_poloidal = True # trivially true + _fft_toroidal = True + def __init__(self, L, sym="even"): self._L = check_nonnegint(L, "L", False) self._M = 0 @@ -265,22 +364,17 @@ def _get_modes(self, L): z = np.zeros_like(l) return np.array([l, z, z]).T - def evaluate( - self, nodes, derivatives=np.array([0, 0, 0]), modes=None, unique=False - ): + def evaluate(self, grid, derivatives=np.array([0, 0, 0]), modes=None): """Evaluate basis functions at specified nodes. Parameters ---------- - nodes : ndarray of float, size(num_nodes,3) + grid : Grid or ndarray of float, size(num_nodes,3) Node coordinates, in (rho,theta,zeta). derivatives : ndarray of int, shape(num_derivatives,3) Order of derivatives to compute in (rho,theta,zeta). modes : ndarray of in, shape(num_modes,3), optional Basis modes to evaluate (if None, full basis is used) - unique : bool, optional - whether to reduce workload by only calculating for unique values of nodes, - modes can be faster, but doesn't work with jit or autodiff Returns ------- @@ -288,29 +382,30 @@ def evaluate( basis functions evaluated at nodes """ + if not isinstance(grid, _Grid): + grid = Grid(grid, sort=False, jitable=True) if modes is None: modes = self.modes + lidx = self.unique_L_idx + loutidx = self.inverse_L_idx + else: + lidx = loutidx = np.arange(len(modes)) if (derivatives[1] != 0) or (derivatives[2] != 0): - return jnp.zeros((nodes.shape[0], modes.shape[0])) + return jnp.zeros((grid.num_nodes, modes.shape[0])) if not len(modes): - return np.array([]).reshape((len(nodes), 0)) + return np.array([]).reshape((grid.num_nodes, 0)) - r, t, z = nodes.T - l, m, n = modes.T + try: + ridx = grid.unique_rho_idx + routidx = grid.inverse_rho_idx + except AttributeError: + ridx = routidx = np.arange(grid.num_nodes) - if unique: - _, ridx, routidx = np.unique( - r, return_index=True, return_inverse=True, axis=0 - ) - _, lidx, loutidx = np.unique( - l, return_index=True, return_inverse=True, axis=0 - ) - r = r[ridx] - l = l[lidx] + r = grid.nodes[ridx, 0] + l = modes[lidx, 0] radial = powers(r, l, dr=derivatives[0]) - if unique: - radial = radial[routidx][:, loutidx] + radial = radial[routidx, :][:, loutidx] return radial @@ -347,6 +442,9 @@ class FourierSeries(_Basis): """ + _fft_poloidal = True + _fft_toroidal = True + def __init__(self, N, NFP=1, sym=False): self._L = 0 self._M = 0 @@ -378,22 +476,17 @@ def _get_modes(self, N): z = np.zeros_like(n) return np.array([z, z, n]).T - def evaluate( - self, nodes, derivatives=np.array([0, 0, 0]), modes=None, unique=False - ): + def evaluate(self, grid, derivatives=np.array([0, 0, 0]), modes=None): """Evaluate basis functions at specified nodes. Parameters ---------- - nodes : ndarray of float, size(num_nodes,3) + grid : Grid or ndarray of float, size(num_nodes,3) Node coordinates, in (rho,theta,zeta). derivatives : ndarray of int, shape(num_derivatives,3) Order of derivatives to compute in (rho,theta,zeta). modes : ndarray of in, shape(num_modes,3), optional Basis modes to evaluate (if None, full basis is used). - unique : bool, optional - Whether to workload by only calculating for unique values of nodes, modes - can be faster, but doesn't work with jit or autodiff. Returns ------- @@ -404,29 +497,30 @@ def evaluate( [sin(N𝛇), ..., sin(𝛇), 1, cos(𝛇), ..., cos(N𝛇)]. """ + if not isinstance(grid, _Grid): + grid = Grid(grid, sort=False, jitable=True) if modes is None: modes = self.modes + nidx = self.unique_N_idx + noutidx = self.inverse_N_idx + else: + nidx = noutidx = np.arange(len(modes)) if (derivatives[0] != 0) or (derivatives[1] != 0): - return jnp.zeros((nodes.shape[0], modes.shape[0])) + return jnp.zeros((grid.num_nodes, modes.shape[0])) if not len(modes): - return np.array([]).reshape((len(nodes), 0)) + return np.array([]).reshape((grid.num_nodes, 0)) - r, t, z = nodes.T - l, m, n = modes.T + try: + zidx = grid.unique_zeta_idx + zoutidx = grid.inverse_zeta_idx + except AttributeError: + zidx = zoutidx = np.arange(grid.num_nodes) - if unique: - _, zidx, zoutidx = np.unique( - z, return_index=True, return_inverse=True, axis=0 - ) - _, nidx, noutidx = np.unique( - n, return_index=True, return_inverse=True, axis=0 - ) - z = z[zidx] - n = n[nidx] + z = grid.nodes[zidx, 2] + n = modes[nidx, 2] toroidal = fourier(z[:, np.newaxis], n, self.NFP, derivatives[2]) - if unique: - toroidal = toroidal[zoutidx][:, noutidx] + toroidal = toroidal[zoutidx, :][:, noutidx] return toroidal @@ -472,6 +566,9 @@ class DoubleFourierSeries(_Basis): """ + _fft_poloidal = True + _fft_toroidal = True + def __init__(self, M, N, NFP=1, sym=False): self._L = 0 self._M = check_nonnegint(M, "M", False) @@ -507,22 +604,17 @@ def _get_modes(self, M, N): z = np.zeros_like(m) return np.array([z, m, n]).T - def evaluate( - self, nodes, derivatives=np.array([0, 0, 0]), modes=None, unique=False - ): + def evaluate(self, grid, derivatives=np.array([0, 0, 0]), modes=None): """Evaluate basis functions at specified nodes. Parameters ---------- - nodes : ndarray of float, size(num_nodes,3) + grid : Grid or ndarray of float, size(num_nodes,3) Node coordinates, in (rho,theta,zeta). derivatives : ndarray of int, shape(num_derivatives,3) Order of derivatives to compute in (rho,theta,zeta). modes : ndarray of in, shape(num_modes,3), optional Basis modes to evaluate (if None, full basis is used). - unique : bool, optional - Whether to workload by only calculating for unique values of nodes, modes - can be faster, but doesn't work with jit or autodiff. Returns ------- @@ -535,39 +627,45 @@ def evaluate( βŠ— [sin(N𝛇), ..., sin(𝛇), 1, cos(𝛇), ..., cos(N𝛇)]. """ + if not isinstance(grid, _Grid): + grid = Grid(grid, sort=False, jitable=True) if modes is None: modes = self.modes + midx = self.unique_M_idx + nidx = self.unique_N_idx + moutidx = self.inverse_M_idx + noutidx = self.inverse_N_idx + else: + midx = moutidx = np.arange(len(modes)) + nidx = noutidx = np.arange(len(modes)) if derivatives[0] != 0: - return jnp.zeros((nodes.shape[0], modes.shape[0])) + return jnp.zeros((grid.num_nodes, modes.shape[0])) if not len(modes): - return np.array([]).reshape((len(nodes), 0)) + return np.array([]).reshape((grid.num_nodes, 0)) - r, t, z = nodes.T - l, m, n = modes.T + try: + zidx = grid.unique_zeta_idx + zoutidx = grid.inverse_zeta_idx + except AttributeError: + zidx = zoutidx = np.arange(grid.num_nodes) + try: + tidx = grid.unique_poloidal_idx + toutidx = grid.inverse_poloidal_idx + except AttributeError: + tidx = toutidx = np.arange(grid.num_nodes) - if unique: - _, tidx, toutidx = np.unique( - t, return_index=True, return_inverse=True, axis=0 - ) - _, zidx, zoutidx = np.unique( - z, return_index=True, return_inverse=True, axis=0 - ) - _, midx, moutidx = np.unique( - m, return_index=True, return_inverse=True, axis=0 - ) - _, nidx, noutidx = np.unique( - n, return_index=True, return_inverse=True, axis=0 - ) - t = t[tidx] - z = z[zidx] - m = m[midx] - n = n[nidx] + _, t, z = grid.nodes.T + _, m, n = modes.T + + t = t[tidx] + z = z[zidx] + m = m[midx] + n = n[nidx] poloidal = fourier(t[:, np.newaxis], m, 1, derivatives[1]) toroidal = fourier(z[:, np.newaxis], n, self.NFP, derivatives[2]) - if unique: - poloidal = poloidal[toutidx][:, moutidx] - toroidal = toroidal[zoutidx][:, noutidx] + poloidal = poloidal[toutidx][:, moutidx] + toroidal = toroidal[zoutidx][:, noutidx] return poloidal * toroidal @@ -633,6 +731,9 @@ class ZernikePolynomial(_Basis): """ + _fft_poloidal = False + _fft_toroidal = True + def __init__(self, L, M, sym=False, spectral_indexing="ansi"): self._L = check_nonnegint(L, "L", False) self._M = check_nonnegint(M, "M", False) @@ -719,22 +820,17 @@ def _get_modes(self, L, M, spectral_indexing="ansi"): return np.hstack([pol, tor]) - def evaluate( - self, nodes, derivatives=np.array([0, 0, 0]), modes=None, unique=False - ): + def evaluate(self, grid, derivatives=np.array([0, 0, 0]), modes=None): """Evaluate basis functions at specified nodes. Parameters ---------- - nodes : ndarray of float, size(num_nodes,3) + grid : Grid or ndarray of float, size(num_nodes,3) Node coordinates, in (rho,theta,zeta). derivatives : ndarray of int, shape(num_derivatives,3) Order of derivatives to compute in (rho,theta,zeta). modes : ndarray of int, shape(num_modes,3), optional Basis modes to evaluate (if None, full basis is used). - unique : bool, optional - Whether to workload by only calculating for unique values of nodes, modes - can be faster, but doesn't work with jit or autodiff. Returns ------- @@ -742,41 +838,46 @@ def evaluate( Basis functions evaluated at nodes. """ + if not isinstance(grid, _Grid): + grid = Grid(grid, sort=False, jitable=True) if modes is None: modes = self.modes + lmidx = self.unique_LM_idx + lmoutidx = self.inverse_LM_idx + midx = self.unique_M_idx + moutidx = self.inverse_M_idx + else: + lmidx = lmoutidx = np.arange(len(modes)) + midx = moutidx = np.arange(len(modes)) if derivatives[2] != 0: - return jnp.zeros((nodes.shape[0], modes.shape[0])) + return jnp.zeros((grid.num_nodes, modes.shape[0])) if not len(modes): - return np.array([]).reshape((len(nodes), 0)) + return np.array([]).reshape((grid.num_nodes, 0)) - r, t, z = nodes.T - l, m, n = modes.T + r, t, _ = grid.nodes.T lm = modes[:, :2] + m = modes[:, 1] - if unique: - _, ridx, routidx = np.unique( - r, return_index=True, return_inverse=True, axis=0 - ) - _, tidx, toutidx = np.unique( - t, return_index=True, return_inverse=True, axis=0 - ) - _, lmidx, lmoutidx = np.unique( - lm, return_index=True, return_inverse=True, axis=0 - ) - _, midx, moutidx = np.unique( - m, return_index=True, return_inverse=True, axis=0 - ) - r = r[ridx] - t = t[tidx] - lm = lm[lmidx] - m = m[midx] + try: + ridx = grid.unique_rho_idx + routidx = grid.inverse_rho_idx + except AttributeError: + ridx = routidx = np.arange(grid.num_nodes) + try: + tidx = grid.unique_theta_idx + toutidx = grid.inverse_theta_idx + except AttributeError: + tidx = toutidx = np.arange(grid.num_nodes) + + r = r[ridx] + t = t[tidx] + lm = lm[lmidx] + m = m[midx] radial = zernike_radial(r[:, np.newaxis], lm[:, 0], lm[:, 1], dr=derivatives[0]) poloidal = fourier(t[:, np.newaxis], m, 1, derivatives[1]) - - if unique: - radial = radial[routidx][:, lmoutidx] - poloidal = poloidal[toutidx][:, moutidx] + radial = radial[routidx][:, lmoutidx] + poloidal = poloidal[toutidx][:, moutidx] return radial * poloidal @@ -829,6 +930,9 @@ class ChebyshevDoubleFourierBasis(_Basis): """ + _fft_poloidal = True + _fft_toroidal = True + def __init__(self, L, M, N, NFP=1, sym=False): self._L = check_nonnegint(L, "L", False) self._M = check_nonnegint(M, "M", False) @@ -869,22 +973,17 @@ def _get_modes(self, L, M, N): n = n.ravel() return np.array([l, m, n]).T - def evaluate( - self, nodes, derivatives=np.array([0, 0, 0]), modes=None, unique=False - ): + def evaluate(self, grid, derivatives=np.array([0, 0, 0]), modes=None): """Evaluate basis functions at specified nodes. Parameters ---------- - nodes : ndarray of float, size(num_nodes,3) + grid : Grid or ndarray of float, size(num_nodes,3) Node coordinates, in (rho,theta,zeta). derivatives : ndarray of int, shape(num_derivatives,3) Order of derivatives to compute in (rho,theta,zeta). modes : ndarray of in, shape(num_modes,3), optional Basis modes to evaluate (if None, full basis is used). - unique : bool, optional - whether to reduce workload by only calculating for unique values of nodes, - modes can be faster, but doesn't work with jit or autodiff Returns ------- @@ -898,18 +997,57 @@ def evaluate( βŠ— [sin(N𝛇), ..., sin(𝛇), 1, cos(𝛇), ..., cos(N𝛇)]. """ + if not isinstance(grid, _Grid): + grid = Grid(grid, sort=False, jitable=True) if modes is None: modes = self.modes + lidx = self.unique_L_idx + midx = self.unique_M_idx + nidx = self.unique_N_idx + loutidx = self.inverse_L_idx + moutidx = self.inverse_M_idx + noutidx = self.inverse_N_idx + else: + lidx = loutidx = np.arange(len(modes)) + midx = moutidx = np.arange(len(modes)) + nidx = noutidx = np.arange(len(modes)) if not len(modes): - return np.array([]).reshape((len(nodes), 0)) + return np.array([]).reshape((grid.num_nodes, 0)) - r, t, z = nodes.T + r, t, z = grid.nodes.T l, m, n = modes.T + try: + ridx = grid.unique_rho_idx + routidx = grid.inverse_rho_idx + except AttributeError: + ridx = routidx = np.arange(grid.num_nodes) + try: + tidx = grid.unique_theta_idx + toutidx = grid.inverse_theta_idx + except AttributeError: + tidx = toutidx = np.arange(grid.num_nodes) + try: + zidx = grid.unique_zeta_idx + zoutidx = grid.inverse_zeta_idx + except AttributeError: + zidx = zoutidx = np.arange(grid.num_nodes) + + r = r[ridx] + t = t[tidx] + z = z[zidx] + l = l[lidx] + m = m[midx] + n = n[nidx] + radial = chebyshev(r[:, np.newaxis], l, dr=derivatives[0]) poloidal = fourier(t[:, np.newaxis], m, 1, derivatives[1]) toroidal = fourier(z[:, np.newaxis], n, self.NFP, derivatives[2]) + radial = radial[routidx][:, loutidx] + poloidal = poloidal[toutidx][:, moutidx] + toroidal = toroidal[zoutidx][:, noutidx] + return radial * poloidal * toroidal def change_resolution(self, L, M, N, NFP=None, sym=None): @@ -984,6 +1122,9 @@ class FourierZernikeBasis(_Basis): """ + _fft_poloidal = False + _fft_toroidal = True + def __init__(self, L, M, N, NFP=1, sym=False, spectral_indexing="ansi"): self._L = check_nonnegint(L, "L", False) self._M = check_nonnegint(M, "M", False) @@ -1075,22 +1216,17 @@ def _get_modes(self, L, M, N, spectral_indexing="ansi"): ).T return np.unique(np.hstack([pol, tor]), axis=0) - def evaluate( - self, nodes, derivatives=np.array([0, 0, 0]), modes=None, unique=False - ): + def evaluate(self, grid, derivatives=np.array([0, 0, 0]), modes=None): """Evaluate basis functions at specified nodes. Parameters ---------- - nodes : ndarray of float, size(num_nodes,3) + grid : Grid or ndarray of float, size(num_nodes,3) Node coordinates, in (rho,theta,zeta). derivatives : ndarray of int, shape(num_derivatives,3) Order of derivatives to compute in (rho,theta,zeta). modes : ndarray of int, shape(num_modes,3), optional Basis modes to evaluate (if None, full basis is used). - unique : bool, optional - Whether to workload by only calculating for unique values of nodes, modes - can be faster, but doesn't work with jit or autodiff. Returns ------- @@ -1098,51 +1234,57 @@ def evaluate( Basis functions evaluated at nodes. """ + if not isinstance(grid, _Grid): + grid = Grid(grid, sort=False, jitable=True) if modes is None: modes = self.modes + lmidx = self.unique_LM_idx + midx = self.unique_M_idx + nidx = self.unique_N_idx + lmoutidx = self.inverse_LM_idx + moutidx = self.inverse_M_idx + noutidx = self.inverse_N_idx + else: + lmidx = lmoutidx = np.arange(len(modes)) + midx = moutidx = np.arange(len(modes)) + nidx = noutidx = np.arange(len(modes)) if not len(modes): - return np.array([]).reshape((len(nodes), 0)) + return np.array([]).reshape((grid.num_nodes, 0)) - # TODO(#1243): avoid duplicate calculations when mixing derivatives - r, t, z = nodes.T - l, m, n = modes.T + r, t, z = grid.nodes.T + _, m, n = modes.T lm = modes[:, :2] - if unique: - # TODO(#1243): can avoid this here by using grid.unique_idx etc - # and adding unique_modes attributes to basis - _, ridx, routidx = np.unique( - r, return_index=True, return_inverse=True, axis=0 - ) - _, tidx, toutidx = np.unique( - t, return_index=True, return_inverse=True, axis=0 - ) - _, zidx, zoutidx = np.unique( - z, return_index=True, return_inverse=True, axis=0 - ) - _, lmidx, lmoutidx = np.unique( - lm, return_index=True, return_inverse=True, axis=0 - ) - _, midx, moutidx = np.unique( - m, return_index=True, return_inverse=True, axis=0 - ) - _, nidx, noutidx = np.unique( - n, return_index=True, return_inverse=True, axis=0 - ) - r = r[ridx] - t = t[tidx] - z = z[zidx] - lm = lm[lmidx] - m = m[midx] - n = n[nidx] + try: + ridx = grid.unique_rho_idx + routidx = grid.inverse_rho_idx + except AttributeError: + ridx = routidx = np.arange(grid.num_nodes) + try: + tidx = grid.unique_theta_idx + toutidx = grid.inverse_theta_idx + except AttributeError: + tidx = toutidx = np.arange(grid.num_nodes) + try: + zidx = grid.unique_zeta_idx + zoutidx = grid.inverse_zeta_idx + except AttributeError: + zidx = zoutidx = np.arange(grid.num_nodes) + + r = r[ridx] + t = t[tidx] + z = z[zidx] + lm = lm[lmidx] + m = m[midx] + n = n[nidx] radial = zernike_radial(r[:, np.newaxis], lm[:, 0], lm[:, 1], dr=derivatives[0]) poloidal = fourier(t[:, np.newaxis], m, dt=derivatives[1]) toroidal = fourier(z[:, np.newaxis], n, NFP=self.NFP, dt=derivatives[2]) - if unique: - radial = radial[routidx][:, lmoutidx] - poloidal = poloidal[toutidx][:, moutidx] - toroidal = toroidal[zoutidx][:, noutidx] + + radial = radial[routidx][:, lmoutidx] + poloidal = poloidal[toutidx][:, moutidx] + toroidal = toroidal[zoutidx][:, noutidx] return radial * poloidal * toroidal @@ -1190,6 +1332,9 @@ class ChebyshevPolynomial(_Basis): """ + _fft_poloidal = True # trivially true + _fft_toroidal = True + def __init__(self, L): self._L = check_nonnegint(L, "L", False) self._M = 0 @@ -1221,14 +1366,12 @@ def _get_modes(self, L): z = np.zeros_like(l) return np.array([l, z, z]).T - def evaluate( - self, nodes, derivatives=np.array([0, 0, 0]), modes=None, unique=False - ): + def evaluate(self, grid, derivatives=np.array([0, 0, 0]), modes=None): """Evaluate basis functions at specified nodes. Parameters ---------- - nodes : ndarray of float, size(num_nodes,3) + grid : Grid or ndarray of float, size(num_nodes,3) Node coordinates, in (rho,theta,zeta). derivatives : ndarray of int, shape(num_derivatives,3) Order of derivatives to compute in (rho,theta,zeta). @@ -1244,17 +1387,34 @@ def evaluate( basis functions evaluated at nodes """ + if not isinstance(grid, _Grid): + grid = Grid(grid, sort=False, jitable=True) if modes is None: modes = self.modes + lidx = self.unique_L_idx + loutidx = self.inverse_L_idx + else: + lidx = loutidx = np.arange(len(modes)) if (derivatives[1] != 0) or (derivatives[2] != 0): - return jnp.zeros((nodes.shape[0], modes.shape[0])) + return jnp.zeros((grid.num_nodes, modes.shape[0])) if not len(modes): - return np.array([]).reshape((len(nodes), 0)) + return np.array([]).reshape((grid.num_nodes, 0)) - r, t, z = nodes.T - l, m, n = modes.T + r = grid.nodes[:, 0] + l = modes[:, 0] + + try: + ridx = grid.unique_rho_idx + routidx = grid.inverse_rho_idx + except AttributeError: + ridx = routidx = np.arange(grid.num_nodes) + + r = r[ridx] + l = l[lidx] radial = chebyshev(r[:, np.newaxis], l, dr=derivatives[0]) + radial = radial[routidx, :][:, loutidx] + return radial def change_resolution(self, L): diff --git a/desc/equilibrium/coords.py b/desc/equilibrium/coords.py index 018104a5de..7ed6158a00 100644 --- a/desc/equilibrium/coords.py +++ b/desc/equilibrium/coords.py @@ -631,14 +631,14 @@ def to_sfl( toroidal_coords = eq.compute(["R", "Z", "lambda"], grid=grid) theta = grid.nodes[:, 1] vartheta = theta + toroidal_coords["lambda"] - sfl_grid = grid - sfl_grid.nodes[:, 1] = vartheta + sfl_grid = Grid(np.array([grid.nodes[:, 0], vartheta, grid.nodes[:, 2]]).T) bdry_coords = eq.compute(["R", "Z", "lambda"], grid=bdry_grid) bdry_theta = bdry_grid.nodes[:, 1] bdry_vartheta = bdry_theta + bdry_coords["lambda"] - bdry_sfl_grid = bdry_grid - bdry_sfl_grid.nodes[:, 1] = bdry_vartheta + bdry_sfl_grid = Grid( + np.array([bdry_grid.nodes[:, 0], bdry_vartheta, bdry_grid.nodes[:, 2]]).T + ) if copy: eq_sfl = eq.copy() diff --git a/desc/grid.py b/desc/grid.py index 9e82524c23..77b4adf058 100644 --- a/desc/grid.py +++ b/desc/grid.py @@ -44,6 +44,8 @@ class _Grid(IOAble, ABC): "_inverse_zeta_idx", "_is_meshgrid", "_can_fft2", + "_fft_poloidal", + "_fft_toroidal", ] @abstractmethod @@ -240,7 +242,9 @@ def can_fft2(self): (ΞΈ, ΞΆ) ∈ [0, 2Ο€) Γ— [0, 2Ο€/NFP). """ # TODO: GitHub issue 1243? - return self.__dict__.setdefault("_can_fft2", self.is_meshgrid and not self.sym) + return self.__dict__.setdefault( + "_can_fft2", self.is_meshgrid and self.fft_poloidal and self.fft_toroidal + ) @property def coordinates(self): @@ -416,6 +420,20 @@ def nodes(self): """ndarray: Node coordinates, in (rho,theta,zeta).""" return self.__dict__.setdefault("_nodes", np.array([]).reshape((0, 3))) + @property + def fft_poloidal(self): + """bool: whether this grid is compatible with fft in the poloidal direction.""" + if not hasattr(self, "_fft_poloidal"): + self._fft_poloidal = False + return self._fft_poloidal + + @property + def fft_toroidal(self): + """bool: whether this grid is compatible with fft in the toroidal direction.""" + if not hasattr(self, "_fft_toroidal"): + self._fft_toroidal = False + return self._fft_toroidal + @property def spacing(self): """Quadrature weights for integration over surfaces. @@ -720,6 +738,11 @@ class Grid(_Grid): symmetry etc. may be wrong if grid contains duplicate nodes. """ + # if you're using a custom grid it almost always isnt uniform, or is under jit + # where we can't properly check this anyways, so just set to false + _fft_poloidal = False + _fft_toroidal = False + def __init__( self, nodes, @@ -771,7 +794,7 @@ def __init__( # Don't do anything with symmetry since that changes # of nodes # avoid point at the axis, for now. r, t, z = self._nodes.T - r = jnp.where(r == 0, 1e-12, r) + r = jnp.where(r == 0, kwargs.pop("axis_shift", 1e-12), r) self._nodes = jnp.column_stack([r, t, z]) self._axis = np.array([], dtype=int) # allow for user supplied indices/inverse indices for special cases @@ -976,6 +999,11 @@ class LinearGrid(_Grid): Note that if supplied the values may be reordered in the resulting grid. """ + _io_attrs_ = _Grid._io_attrs_ + [ + "_toroidal_endpoint", + "_poloidal_endpoint", + ] + def __init__( self, L=None, @@ -985,18 +1013,25 @@ def __init__( sym=False, axis=True, endpoint=False, - rho=np.array(1.0), - theta=np.array(0.0), - zeta=np.array(0.0), + rho=None, + theta=None, + zeta=None, ): + assert (L is None) or (rho is None), "cannot specify both L and rho" + assert (M is None) or (theta is None), "cannot specify both M and theta" + assert (N is None) or (zeta is None), "cannot specify both N and zeta" self._L = check_nonnegint(L, "L") self._M = check_nonnegint(M, "M") self._N = check_nonnegint(N, "N") self._NFP = check_posint(NFP, "NFP", False) self._sym = sym self._endpoint = bool(endpoint) + # these are just default values that may get overwritten in _create_nodes self._poloidal_endpoint = False self._toroidal_endpoint = False + self._fft_poloidal = False + self._fft_toroidal = False + self._node_pattern = "linear" self._coordinates = "rtz" self._is_meshgrid = True @@ -1093,9 +1128,12 @@ def _create_nodes( # noqa: C901 r = np.flipud(np.linspace(1, 0, int(rho), endpoint=axis)) # choose dr such that each node has the same weight dr = np.ones_like(r) / r.size - else: + elif rho is not None: r = np.sort(np.atleast_1d(rho)) dr = _midpoint_spacing(r, jnp=np) + else: + r = np.array(1.0, ndmin=1) + dr = np.ones_like(r) # theta if M is not None: @@ -1123,7 +1161,9 @@ def _create_nodes( # noqa: C901 dt *= t.size / (t.size - 1) # scale_weights() will reduce endpoint (dt[0] and dt[-1]) # duplicate node weight - else: + # if custom theta used usually safe to assume its non-uniform so no fft + self._fft_poloidal = (not endpoint) and (not self.sym) + elif theta is not None: t = np.atleast_1d(theta).astype(float) # enforce periodicity t[t != theta_period] %= theta_period @@ -1174,6 +1214,10 @@ def _create_nodes( # noqa: C901 # The scale_weights() function will handle this. else: dt = np.array([theta_period]) + else: + t = np.array(0.0, ndmin=1) + dt = theta_period * np.ones_like(t) + self._fft_poloidal = not self.sym # zeta # note: dz spacing should not depend on NFP @@ -1183,14 +1227,17 @@ def _create_nodes( # noqa: C901 self._N = check_nonnegint(N, "N") zeta = 2 * N + 1 if np.isscalar(zeta) and (int(zeta) == zeta) and zeta > 0: - z = np.linspace(0, zeta_period, int(zeta), endpoint=endpoint) + zeta = int(zeta) + z = np.linspace(0, zeta_period, zeta, endpoint=endpoint) dz = 2 * np.pi / z.size * np.ones_like(z) if endpoint and z.size > 1: # increase node weight to account for duplicate node dz *= z.size / (z.size - 1) # scale_weights() will reduce endpoint (dz[0] and dz[-1]) # duplicate node weight - else: + # if custom zeta used usually safe to assume its non-uniform so no fft + self._fft_toroidal = not endpoint + elif zeta is not None: z, dz = _periodic_spacing(zeta, zeta_period, sort=True, jnp=np) dz = dz * NFP if z[0] == 0 and z[-1] == zeta_period: @@ -1200,6 +1247,10 @@ def _create_nodes( # noqa: C901 # counteract the reduction that will be done there. dz[0] += dz[-1] dz[-1] = dz[0] + else: + z = np.array(0.0, ndmin=1) + dz = zeta_period * np.ones_like(z) * NFP + self._fft_toroidal = True # trivially true self._poloidal_endpoint = ( t.size > 0 @@ -1288,6 +1339,9 @@ class QuadratureGrid(_Grid): """ + _fft_poloidal = True + _fft_toroidal = True + def __init__(self, L, M, N, NFP=1): self._L = check_nonnegint(L, "L", False) self._M = check_nonnegint(M, "N", False) @@ -1433,6 +1487,9 @@ class ConcentricGrid(_Grid): """ + _fft_poloidal = False + _fft_toroidal = True + def __init__(self, L, M, N, NFP=1, sym=False, axis=False, node_pattern="jacobi"): self._L = check_nonnegint(L, "L", False) self._M = check_nonnegint(M, "M", False) diff --git a/desc/plotting.py b/desc/plotting.py index 6a1e6043f3..3e0e2adf7b 100644 --- a/desc/plotting.py +++ b/desc/plotting.py @@ -23,7 +23,7 @@ from desc.grid import Grid, LinearGrid from desc.integrals import surface_averages_map from desc.magnetic_fields import field_line_integrate -from desc.utils import errorif, only1, parse_argname_change, setdefault +from desc.utils import errorif, islinspaced, only1, parse_argname_change, setdefault from desc.vmec_utils import ptolemy_linear_transform __all__ = [ @@ -219,20 +219,19 @@ def _get_grid(**kwargs): """ grid_args = { - "L": None, - "M": None, - "N": None, "NFP": 1, "sym": False, "axis": True, "endpoint": True, - "rho": np.array([1.0]), - "theta": np.array([0.0]), - "zeta": np.array([0.0]), } - for key in kwargs.keys(): - if key in grid_args.keys(): - grid_args[key] = kwargs[key] + grid_args.update(kwargs) + if ("L" not in grid_args) and ("rho" not in grid_args): + grid_args["rho"] = np.array([1.0]) + if ("M" not in grid_args) and ("theta" not in grid_args): + grid_args["theta"] = np.array([0.0]) + if ("N" not in grid_args) and ("zeta" not in grid_args): + grid_args["zeta"] = np.array([0.0]) + grid = LinearGrid(**grid_args) return grid @@ -324,6 +323,73 @@ def _compute(eq, name, grid, component=None, reshape=True): return data, label +def _compute_Bn(eq, field, plot_grid, field_grid): + """Compute normal field from virtual casing + coils, using correct grids.""" + errorif( + field is None, + ValueError, + "If B*n is entered as the variable to plot, a magnetic field" + " must be provided.", + ) + errorif( + not np.all(np.isclose(plot_grid.nodes[:, 0], 1)), + ValueError, + "If B*n is entered as the variable to plot, " + "the grid nodes must be at rho=1.", + ) + + theta_endpoint = zeta_endpoint = False + + if plot_grid.fft_poloidal and plot_grid.fft_toroidal: + source_grid = eval_grid = plot_grid + # often plot_grid is still linearly spaced but includes endpoints. In that case + # make a temp grid that just leaves out the endpoint so we can FFT + elif ( + isinstance(plot_grid, LinearGrid) + and not plot_grid.sym + and islinspaced(plot_grid.nodes[plot_grid.unique_theta_idx, 1]) + and islinspaced(plot_grid.nodes[plot_grid.unique_zeta_idx, 2]) + ): + if plot_grid._poloidal_endpoint: + theta_endpoint = True + theta = plot_grid.nodes[plot_grid.unique_theta_idx[0:-1], 1] + if plot_grid._toroidal_endpoint: + zeta_endpoint = True + zeta = plot_grid.nodes[plot_grid.unique_zeta_idx[0:-1], 2] + vc_grid = LinearGrid( + theta=theta, + zeta=zeta, + NFP=plot_grid.NFP, + endpoint=False, + ) + # override attr since we know fft is ok even with custom nodes + vc_grid._fft_poloidal = vc_grid._fft_toroidal = True + source_grid = eval_grid = vc_grid + else: + eval_grid = plot_grid + source_grid = LinearGrid( + M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP, sym=False, endpoint=False + ) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + data, _ = field.compute_Bnormal( + eq, eval_grid=eval_grid, source_grid=field_grid, vc_source_grid=source_grid + ) + data = data.reshape((eval_grid.num_theta, eval_grid.num_zeta), order="F") + if theta_endpoint: + data = np.vstack((data, data[0, :])) + if zeta_endpoint: + data = np.hstack((data, np.atleast_2d(data[:, 0]).T)) + data = data.reshape( + (plot_grid.num_theta, plot_grid.num_rho, plot_grid.num_zeta), order="F" + ) + + label = r"$\mathbf{B} \cdot \hat{n} ~(\mathrm{T})$" + return data, label + + def plot_coefficients(eq, L=True, M=True, N=True, ax=None, **kwargs): """Plot spectral coefficient magnitudes vs spectral mode number. @@ -658,45 +724,10 @@ def plot_2d( component=component, ) else: - field = kwargs.pop("field", None) - errorif( - field is None, - ValueError, - "If B*n is entered as the variable to plot, a magnetic field" - " must be provided.", - ) - errorif( - not np.all(np.isclose(grid.nodes[:, 0], 1)), - ValueError, - "If B*n is entered as the variable to plot, " - "the grid nodes must be at rho=1.", + data, label = _compute_Bn( + eq, kwargs.pop("field", None), grid, kwargs.pop("field_grid", None) ) - field_grid = kwargs.pop("field_grid", None) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - - if grid.endpoint: - # cannot use a grid with endpoint=True for FFT interpolator - vc_grid = LinearGrid( - theta=grid.nodes[grid.unique_theta_idx[0:-1], 1], - zeta=grid.nodes[grid.unique_zeta_idx[0:-1], 2], - NFP=grid.NFP, - endpoint=False, - ) - else: - vc_grid = grid - data, _ = field.compute_Bnormal( - eq, eval_grid=vc_grid, source_grid=field_grid, vc_source_grid=vc_grid - ) - data = data.reshape((vc_grid.num_theta, vc_grid.num_zeta), order="F") - if grid.endpoint: - data = np.hstack((data, np.atleast_2d(data[:, 0]).T)) - data = np.vstack((data, data[0, :])) - data = data.reshape((grid.num_theta, grid.num_rho, grid.num_zeta), order="F") - - label = r"$\mathbf{B} \cdot \hat{n} ~(\mathrm{T})$" - fig, ax = _format_ax(ax, figsize=kwargs.pop("figsize", None)) divider = make_axes_locatable(ax) @@ -950,44 +981,9 @@ def plot_3d( component=component, ) else: - field = kwargs.pop("field", None) - errorif( - field is None, - ValueError, - "If B*n is entered as the variable to plot, a magnetic field" - " must be provided.", + data, label = _compute_Bn( + eq, kwargs.pop("field", None), grid, kwargs.pop("field_grid", None) ) - errorif( - not np.all(np.isclose(grid.nodes[:, 0], 1)), - ValueError, - "If B*n is entered as the variable to plot, " - "the grid nodes must be at rho=1.", - ) - - field_grid = kwargs.pop("field_grid", None) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - - if grid.endpoint: - # cannot use a grid with endpoint=True for FFT interpolator - vc_grid = LinearGrid( - theta=grid.nodes[grid.unique_theta_idx[0:-1], 1], - zeta=grid.nodes[grid.unique_zeta_idx[0:-1], 2], - NFP=grid.NFP, - endpoint=False, - ) - else: - vc_grid = grid - data, _ = field.compute_Bnormal( - eq, eval_grid=vc_grid, source_grid=field_grid, vc_source_grid=vc_grid - ) - data = data.reshape((vc_grid.num_theta, vc_grid.num_zeta), order="F") - if grid.endpoint: - data = np.hstack((data, np.atleast_2d(data[:, 0]).T)) - data = np.vstack((data, data[0, :])) - data = data.reshape((grid.num_theta, grid.num_rho, grid.num_zeta), order="F") - - label = r"$\mathbf{B} \cdot \hat{n} ~(\mathrm{T})$" errorif( len(kwargs) != 0, diff --git a/desc/profiles.py b/desc/profiles.py index faaa7e3abd..ce7398da98 100644 --- a/desc/profiles.py +++ b/desc/profiles.py @@ -1498,7 +1498,7 @@ def compute(self, grid, params=None, dr=0, dt=0, dz=0): """ if params is None: params = self.params - A = self.basis.evaluate(grid.nodes, [dr, dt, dz]) + A = self.basis.evaluate(grid, [dr, dt, dz]) return A @ params @classmethod diff --git a/desc/transform.py b/desc/transform.py index c714c5def8..1660fe5e6e 100644 --- a/desc/transform.py +++ b/desc/transform.py @@ -6,14 +6,9 @@ from termcolor import colored from desc.backend import jnp, put +from desc.grid import Grid from desc.io import IOAble -from desc.utils import ( - combination_permutation, - isalmostequal, - islinspaced, - issorted, - warnif, -) +from desc.utils import combination_permutation, warnif class Transform(IOAble): @@ -158,76 +153,32 @@ def _check_inputs_fft(self, grid, basis): self._method = "direct1" return - zeta_vals, zeta_cts = np.unique(grid.nodes[:, 2], return_counts=True) - - if not isalmostequal(zeta_cts): - warnings.warn( - colored( - "fft method requires the same number of nodes on each zeta plane, " - + "falling back to direct1 method", - "yellow", - ) - ) - self.method = "direct1" - return - - if not isalmostequal( - grid.nodes[:, :2].T.reshape((2, zeta_cts[0], -1), order="F") - ): + if not grid.fft_toroidal: warnings.warn( colored( - "fft method requires that node pattern is the same on each zeta " - + "plane, falling back to direct1 method", - "yellow", - ) - ) - self.method = "direct1" - return - - id2 = np.lexsort((basis.modes[:, 1], basis.modes[:, 0], basis.modes[:, 2])) - if not issorted(id2): - warnings.warn( - colored( - "fft method requires zernike indices to be sorted by toroidal mode " - + "number, falling back to direct1 method", - "yellow", - ) - ) - self.method = "direct1" - return - - if ( - len(zeta_vals) > 1 - and not abs((zeta_vals[-1] + zeta_vals[1]) * basis.NFP - 2 * np.pi) < 1e-14 - ): - warnings.warn( - colored( - "fft method requires that nodes complete 1 full field period, " + "fft method requires compatible grid, got {}".format(grid) + "falling back to direct2 method", "yellow", ) ) self.method = "direct2" return - - n_vals, n_cts = np.unique(basis.modes[:, 2], return_counts=True) - if len(n_vals) > 1 and not islinspaced(n_vals): + if not basis.fft_toroidal: warnings.warn( colored( - "fft method requires the toroidal modes are equally spaced in n, " - + "falling back to direct1 method", + "fft method requires compatible basis, got {}".format(basis) + + "falling back to direct2 method", "yellow", ) ) - self.method = "direct1" + self.method = "direct2" return - - if len(zeta_vals) < len(n_vals): + if grid.num_zeta < 2 * basis.N + 1: warnings.warn( colored( "fft method can not undersample in zeta, " + "num_toroidal_modes={}, num_toroidal_angles={}, ".format( - len(n_vals), len(zeta_vals) + basis.N, grid.num_zeta ) + "falling back to direct2 method", "yellow", @@ -235,33 +186,11 @@ def _check_inputs_fft(self, grid, basis): ) self.method = "direct2" return - - if len(zeta_vals) % 2 == 0: - warnings.warn( - colored( - "fft method requires an odd number of toroidal nodes, " - + "falling back to direct2 method", - "yellow", - ) - ) - self.method = "direct2" - return - - if not issorted(grid.nodes[:, 2]): - warnings.warn( - colored( - "fft method requires nodes to be sorted by toroidal angle in " - + "ascending order, falling back to direct1 method", - "yellow", - ) - ) - self.method = "direct1" - return - - if len(zeta_vals) > 1 and not islinspaced(zeta_vals): + if (basis.N > 0) and (grid.NFP != basis.NFP): warnings.warn( colored( - "fft method requires nodes to be equally spaced in zeta, " + "fft method requires grid and basis to have the same NFP, got " + + f"grid.NFP={grid.NFP}, basis.NFP={basis.NFP}, " + "falling back to direct2 method", "yellow", ) @@ -270,25 +199,29 @@ def _check_inputs_fft(self, grid, basis): return self._method = "fft" - self.lm_modes = np.unique(basis.modes[:, :2], axis=0) + self.lm_modes = basis.modes[basis.unique_LM_idx, :2] self.num_lm_modes = self.lm_modes.shape[0] # number of radial/poloidal modes self.num_n_modes = 2 * basis.N + 1 # number of toroidal modes - self.num_z_nodes = len(zeta_vals) # number of zeta nodes - self.N = basis.N # toroidal resolution of basis - self.pad_dim = (self.num_z_nodes - 1) // 2 - self.N - self.dk = basis.NFP * np.arange(-self.N, self.N + 1).reshape((1, -1)) - self.fft_index = np.zeros((basis.num_modes,), dtype=int) + self.pad_dim = self.grid.num_zeta - self.num_n_modes + self.dk = basis.NFP * np.arange(-basis.N, basis.N + 1).reshape((1, -1)) offset = np.min(basis.modes[:, 2]) + basis.N # N for sym="cos", 0 otherwise - for k in range(basis.num_modes): - row = np.where((basis.modes[k, :2] == self.lm_modes).all(axis=1))[0] - col = np.where(basis.modes[k, 2] == n_vals)[0] - self.fft_index[k] = np.squeeze(self.num_n_modes * row + col + offset) - self.fft_nodes = np.hstack( + row = np.where( + (basis.modes[:, None, :2] == self.lm_modes[None, :, :]).all(axis=-1) + )[1] + col = np.where( + basis.modes[None, :, 2] == basis.modes[basis.unique_N_idx, None, 2] + )[0] + self.fft_index = np.atleast_1d( + np.squeeze(self.num_n_modes * row + col + offset) + ) + fft_nodes = np.hstack( [ - grid.nodes[:, :2][: grid.num_nodes // self.num_z_nodes], - np.zeros((grid.num_nodes // self.num_z_nodes, 1)), + grid.nodes[:, :2][: grid.num_nodes // self.grid.num_zeta], + np.zeros((grid.num_nodes // self.grid.num_zeta, 1)), ] ) + # temp grid only used for building transforms, don't need any indexing etc + self.fft_grid = Grid(fft_nodes, sort=False, jitable=True, axis_shift=0) def _check_inputs_direct2(self, grid, basis): """Check that inputs are formatted correctly for direct2 method.""" @@ -297,80 +230,54 @@ def _check_inputs_direct2(self, grid, basis): self._method = "direct1" return - zeta_vals, zeta_cts = np.unique(grid.nodes[:, 2], return_counts=True) - - if not issorted(grid.nodes[:, 2]): - warnings.warn( - colored( - "direct2 method requires nodes to be sorted by toroidal angle in " - + "ascending order, falling back to direct1 method", - "yellow", - ) - ) - self.method = "direct1" - return + from desc.grid import LinearGrid - if not isalmostequal(zeta_cts): + if not (grid.fft_toroidal or isinstance(grid, LinearGrid)): warnings.warn( colored( - "direct2 method requires the same number of nodes on each zeta " - + "plane, falling back to direct1 method", - "yellow", - ) - ) - self.method = "direct1" - return - - if len(zeta_vals) > 1 and not isalmostequal( - grid.nodes[:, :2].T.reshape((2, zeta_cts[0], -1), order="F") - ): - warnings.warn( - colored( - "direct2 method requires that node pattern is the same on each " - + "zeta plane, falling back to direct1 method", + "direct2 method requires compatible grid, got {}".format(grid) + + "falling back to direct1 method", "yellow", ) ) self.method = "direct1" return - - id2 = np.lexsort((basis.modes[:, 1], basis.modes[:, 0], basis.modes[:, 2])) - if not issorted(id2): + if not basis.fft_toroidal: # direct2 and fft have same basis requirements warnings.warn( colored( - "direct2 method requires zernike indices to be sorted by toroidal " - + "mode number, falling back to direct1 method", + "direct2 method requires compatible basis, got {}".format(basis) + + "falling back to direct1 method", "yellow", ) ) self.method = "direct1" return - n_vals, n_cts = np.unique(basis.modes[:, 2], return_counts=True) - self._method = "direct2" - self.lm_modes = np.unique(basis.modes[:, :2], axis=0) - self.n_modes = n_vals - self.zeta_nodes = zeta_vals + self.lm_modes = basis.modes[basis.unique_LM_idx, :2] + self.n_modes = basis.modes[basis.unique_N_idx, 2] + self.zeta_nodes = grid.nodes[grid.unique_zeta_idx, 2] self.num_lm_modes = self.lm_modes.shape[0] # number of radial/poloidal modes self.num_n_modes = self.n_modes.size # number of toroidal modes - self.num_z_nodes = len(zeta_vals) # number of zeta nodes - self.N = basis.N # toroidal resolution of basis - - self.fft_index = np.zeros((basis.num_modes,), dtype=int) - for k in range(basis.num_modes): - row = np.where((basis.modes[k, :2] == self.lm_modes).all(axis=1))[0] - col = np.where(basis.modes[k, 2] == n_vals)[0] - self.fft_index[k] = np.squeeze(self.num_n_modes * row + col) - self.fft_nodes = np.hstack( + + row = np.where( + (basis.modes[:, None, :2] == self.lm_modes[None, :, :]).all(axis=-1) + )[1] + col = np.where( + basis.modes[None, :, 2] == basis.modes[basis.unique_N_idx, None, 2] + )[0] + self.fft_index = np.atleast_1d(np.squeeze(self.num_n_modes * row + col)) + fft_nodes = np.hstack( [ - grid.nodes[:, :2][: grid.num_nodes // self.num_z_nodes], - np.zeros((grid.num_nodes // self.num_z_nodes, 1)), + grid.nodes[:, :2][: grid.num_nodes // grid.num_zeta], + np.zeros((grid.num_nodes // grid.num_zeta, 1)), ] ) - self.dft_nodes = np.hstack( + self.fft_grid = Grid(fft_nodes, sort=False, jitable=True, axis_shift=0) + dft_nodes = np.hstack( [np.zeros((self.zeta_nodes.size, 2)), self.zeta_nodes[:, np.newaxis]] ) + self.dft_grid = Grid(dft_nodes, sort=False, jitable=True, axis_shift=0) def build(self): """Build the transform matrices for each derivative order.""" @@ -381,16 +288,10 @@ def build(self): self._built = True return - if self.method == "direct1": - for d in self.derivatives: - self.matrices["direct1"][d[0]][d[1]][d[2]] = self.basis.evaluate( - self.grid.nodes, d, unique=True - ) - - if self.method == "jitable": + if self.method in ["direct1", "jitable"]: for d in self.derivatives: self.matrices["direct1"][d[0]][d[1]][d[2]] = self.basis.evaluate( - self.grid.nodes, d, unique=False + self.grid, d ) if self.method in ["fft", "direct2"]: @@ -400,7 +301,7 @@ def build(self): temp_modes = np.hstack([self.lm_modes, np.zeros((self.num_lm_modes, 1))]) for d in temp_d: self.matrices["fft"][d[0]][d[1]] = self.basis.evaluate( - self.fft_nodes, d, modes=temp_modes, unique=True + self.fft_grid, d, modes=temp_modes ) if self.method == "direct2": temp_d = np.hstack( @@ -411,7 +312,7 @@ def build(self): ) for d in temp_d: self.matrices["direct2"][d[2]] = self.basis.evaluate( - self.dft_nodes, d, modes=temp_modes, unique=True + self.dft_grid, d, modes=temp_modes ) self._built = True @@ -422,20 +323,20 @@ def build_pinv(self): return rcond = None if self.rcond == "auto" else self.rcond if self.method in ["direct1", "jitable"]: - A = self.basis.evaluate(self.grid.nodes, np.array([0, 0, 0])) + A = self.basis.evaluate(self.grid, np.array([0, 0, 0])) self.matrices["pinv"] = ( jnp.linalg.pinv(A, rtol=rcond) if A.size else np.zeros_like(A.T) ) elif self.method == "direct2": temp_modes = np.hstack([self.lm_modes, np.zeros((self.num_lm_modes, 1))]) A = self.basis.evaluate( - self.fft_nodes, np.array([0, 0, 0]), modes=temp_modes, unique=True + self.fft_grid, np.array([0, 0, 0]), modes=temp_modes ) temp_modes = np.hstack( [np.zeros((self.num_n_modes, 2)), self.n_modes[:, np.newaxis]] ) B = self.basis.evaluate( - self.dft_nodes, np.array([0, 0, 0]), modes=temp_modes, unique=True + self.dft_grid, np.array([0, 0, 0]), modes=temp_modes ) self.matrices["pinvA"] = ( jnp.linalg.pinv(A, rtol=rcond) if A.size else np.zeros_like(A.T) @@ -446,7 +347,7 @@ def build_pinv(self): elif self.method == "fft": temp_modes = np.hstack([self.lm_modes, np.zeros((self.num_lm_modes, 1))]) A = self.basis.evaluate( - self.fft_nodes, np.array([0, 0, 0]), modes=temp_modes, unique=True + self.fft_grid, np.array([0, 0, 0]), modes=temp_modes ) self.matrices["pinvA"] = ( jnp.linalg.pinv(A, rtol=rcond) if A.size else np.zeros_like(A.T) @@ -521,21 +422,19 @@ def transform(self, c, dr=0, dt=0, dz=0): # differentiate c_diff = c_mtrx[:, :: (-1) ** dz] * self.dk**dz * (-1) ** (dz > 1) # re-format in complex notation - c_real = jnp.pad( - (self.num_z_nodes / 2) - * (c_diff[:, self.N + 1 :] - 1j * c_diff[:, self.N - 1 :: -1]), - ((0, 0), (0, self.pad_dim)), - mode="constant", + c_cplx = (self.grid.num_zeta / 2) * ( + c_diff[:, self.basis.N + 1 :] - 1j * c_diff[:, self.basis.N - 1 :: -1] ) - c_cplx = jnp.hstack( + c_pad = jnp.hstack( ( - self.num_z_nodes * c_diff[:, self.N, jnp.newaxis], - c_real, - jnp.fliplr(jnp.conj(c_real)), + self.grid.num_zeta * c_diff[:, self.basis.N, jnp.newaxis], + c_cplx, + jnp.zeros((c_cplx.shape[0], self.pad_dim)), + jnp.fliplr(jnp.conj(c_cplx)), ) ) # transform coefficients - c_fft = jnp.real(jnp.fft.ifft(c_cplx)) + c_fft = jnp.real(jnp.fft.ifft(c_pad)) return (A @ c_fft).flatten(order="F") def fit(self, x): @@ -563,17 +462,16 @@ def fit(self, x): elif self.method == "direct2": Ainv = self.matrices["pinvA"] Binv = self.matrices["pinvB"] - yy = jnp.matmul(Ainv, x.reshape((-1, self.num_z_nodes), order="F")) + yy = jnp.matmul(Ainv, x.reshape((-1, self.grid.num_zeta), order="F")) c = jnp.matmul(Binv, yy.T).T.flatten()[self.fft_index] elif self.method == "fft": Ainv = self.matrices["pinvA"] c_fft = jnp.matmul(Ainv, x.reshape((Ainv.shape[1], -1), order="F")) c_cplx = jnp.fft.fft(c_fft) - c_real = c_cplx[:, 1 : c_cplx.shape[1] // 2 + 1] - c_unpad = c_real[:, : c_real.shape[1] - self.pad_dim] - c0 = c_cplx[:, :1].real / self.num_z_nodes - c2 = c_unpad.real / (self.num_z_nodes / 2) - c1 = -c_unpad.imag[:, ::-1] / (self.num_z_nodes / 2) + c_unpad = c_cplx[:, 1 : (c_cplx.shape[1] - self.pad_dim - 1) // 2 + 1] + c0 = c_cplx[:, :1].real / self.grid.num_zeta + c2 = c_unpad.real / (self.grid.num_zeta / 2) + c1 = -c_unpad.imag[:, ::-1] / (self.grid.num_zeta / 2) c_diff = jnp.hstack([c1, c0, c2]) c = c_diff.flatten()[self.fft_index] return c @@ -615,7 +513,7 @@ def project(self, y): elif self.method == "direct2": A = self.matrices["fft"][0][0] B = self.matrices["direct2"][0] - yy = jnp.matmul(A.T, y.reshape((-1, self.num_z_nodes), order="F")) + yy = jnp.matmul(A.T, y.reshape((-1, self.grid.num_zeta), order="F")) return jnp.matmul(yy, B).flatten()[self.fft_index] elif self.method == "fft": @@ -624,7 +522,7 @@ def project(self, y): # there might be a more efficient way... a = jnp.fft.fft(A.T @ y.reshape((A.shape[0], -1), order="F")) cdn = a[:, 0] - cr = a[:, 1 : 1 + self.N] + cr = a[:, 1 : 1 + self.basis.N] b = jnp.hstack( [-cr.imag[:, ::-1], cdn.real[:, np.newaxis], cr.real] ).flatten()[self.fft_index] diff --git a/desc/vmec_utils.py b/desc/vmec_utils.py index 18e42cd47e..6bf18631f1 100644 --- a/desc/vmec_utils.py +++ b/desc/vmec_utils.py @@ -844,9 +844,9 @@ def make_boozmn_output( # noqa: C901 chi.units = "Wb" chi[:] = np.linspace( 0, - eq.compute( - "chi", grid=LinearGrid(L=eq.L_grid, M=M_booz, N=N_booz, rho=1.0, NFP=eq.NFP) - )["chi"][-1] + eq.compute("chi", grid=LinearGrid(L=eq.L_grid, M=M_booz, N=N_booz, NFP=eq.NFP))[ + "chi" + ][-1] * 2 * np.pi, surfs, diff --git a/tests/test_plotting.py b/tests/test_plotting.py index 81966167d2..d69fa0ab0c 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -257,29 +257,26 @@ def test_3d_tz(self): assert "Z" in data.keys() assert "|F|" in data.keys() - return fig @pytest.mark.unit def test_3d_rz(self): """Test 3d plotting of pressure on toroidal cross section.""" eq = get("DSHAPE_CURRENT") grid = LinearGrid(rho=100, theta=0.0, zeta=100) - fig = plot_3d(eq, "p", grid=grid) - return fig + _ = plot_3d(eq, "p", grid=grid) @pytest.mark.unit def test_3d_rt(self): """Test 3d plotting of flux on poloidal ribbon.""" eq = get("DSHAPE_CURRENT") grid = LinearGrid(rho=100, theta=100, zeta=0.0) - fig = plot_3d(eq, "psi", grid=grid) - return fig + _ = plot_3d(eq, "psi", grid=grid) @pytest.mark.unit def test_plot_3d_surface(self): """Test 3d plotting of surface object.""" surf = FourierRZToroidalSurface() - fig = plot_3d( + _ = plot_3d( surf, "curvature_H_rho", showgrid=False, @@ -288,7 +285,6 @@ def test_plot_3d_surface(self): showticklabels=False, showaxislabels=False, ) - return fig @pytest.mark.unit def test_3d_plot_Bn(self): @@ -296,13 +292,12 @@ def test_3d_plot_Bn(self): eq = get("precise_QA") with pytest.warns(UserWarning, match="Reducing radial"): eq.change_resolution(M=4, N=4, L=4, M_grid=8, N_grid=8, L_grid=8) - fig = plot_3d( + _ = plot_3d( eq, "B*n", field=ToroidalMagneticField(1, 1), grid=LinearGrid(M=30, N=30, NFP=1, endpoint=True), ) - return fig class TestPlotFSA: @@ -866,7 +861,6 @@ def flatten_coils(coilset): for string in ["X", "Y", "Z"]: assert string in data.keys() assert len(data[string]) == len(coil_list) - return fig @pytest.mark.unit @@ -900,7 +894,6 @@ def flatten_coils(coilset): for string in ["X", "Y", "Z"]: assert string in data.keys() assert len(data[string]) == len(coil_list) - return fig @pytest.mark.unit diff --git a/tests/test_transform.py b/tests/test_transform.py index 188235d8a9..9984381436 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -437,28 +437,13 @@ def test_project(self): @pytest.mark.unit def test_fft_warnings(self): """Test that warnings are thrown when trying to use fft where it won't work.""" - g = LinearGrid(rho=2, theta=2, zeta=2) - b = ZernikePolynomial(L=0, M=0) - with pytest.warns(UserWarning): - t = Transform(g, b, method="fft") - assert t.method == "direct2" - g = Grid(np.array([[0, 0, 0], [1, 1, 0], [1, 1, 1]])) b = ZernikePolynomial(L=2, M=2) - with pytest.warns(UserWarning, match="same number of nodes on each zeta plane"): - t = Transform(g, b, method="fft") - assert t.method == "direct1" - - g = LinearGrid(rho=2, theta=2, zeta=2) - b = DoubleFourierSeries(M=2, N=2) - b._modes = b.modes[np.random.permutation(25)] - with pytest.warns( - UserWarning, match="zernike indices to be sorted by toroidal mode" - ): + with pytest.warns(UserWarning, match="compatible grid"): t = Transform(g, b, method="fft") assert t.method == "direct1" - g = LinearGrid(rho=2, theta=2, zeta=2, NFP=2) + g = LinearGrid(rho=2, M=2, N=2, NFP=2) b = DoubleFourierSeries(M=2, N=2) # this actually will emit 2 warnings, one for the NFP for # basis and grid not matching, and one for nodes completing 1 full period @@ -473,73 +458,34 @@ def test_fft_warnings(self): for r in record: if "Unequal number of field periods" in str(r.message): NFP_grid_basis_warning_exists = True - if "nodes complete 1 full field period" in str(r.message): + if "grid and basis to have the same NFP" in str(r.message): nodes_warning_exists = True assert NFP_grid_basis_warning_exists and nodes_warning_exists - g = LinearGrid(rho=2, theta=2, zeta=2) - b = DoubleFourierSeries(M=1, N=1) - b._modes[:, 2] = np.where(b._modes[:, 2] == 1, 2, b._modes[:, 2]) - with pytest.warns(UserWarning, match="toroidal modes are equally spaced in n"): - t = Transform(g, b, method="fft") - assert t.method == "direct1" - - g = LinearGrid(rho=2, theta=2, zeta=2) + g = LinearGrid(rho=2, M=2, N=2) b = DoubleFourierSeries(M=1, N=3) with pytest.warns(UserWarning, match="can not undersample in zeta"): t = Transform(g, b, method="fft") assert t.method == "direct2" - g = LinearGrid(rho=2, theta=2, zeta=4) - b = DoubleFourierSeries(M=1, N=1) - with pytest.warns( - UserWarning, match="requires an odd number of toroidal nodes" - ): - t = Transform(g, b, method="fft") - assert t.method == "direct2" - - g = LinearGrid(rho=2, theta=2, zeta=5) - b = DoubleFourierSeries(M=1, N=1) - g._nodes = g._nodes[::-1] - with pytest.warns(UserWarning, match="nodes to be sorted by toroidal angle"): + b._fft_toroidal = False + g = LinearGrid(2, 3, 4) + with pytest.warns(UserWarning, match="compatible basis"): t = Transform(g, b, method="fft") assert t.method == "direct1" - g = LinearGrid(rho=2, theta=2, zeta=5) - b = DoubleFourierSeries(M=1, N=1) - g._nodes[:, 2] = np.where(g._nodes[:, 2] == 0, 0.01, g.nodes[:, 2]) - with pytest.warns(UserWarning, match="nodes to be equally spaced in zeta"): - t = Transform(g, b, method="fft") - assert t.method == "direct2" - @pytest.mark.unit def test_direct2_warnings(self): """Test that warnings are thrown when trying to use direct2 if it won't work.""" - g = LinearGrid(rho=2, theta=2, zeta=5) - b = DoubleFourierSeries(M=1, N=1) - g._nodes = g._nodes[::-1] - with pytest.warns(UserWarning, match="nodes to be sorted by toroidal angle"): - t = Transform(g, b, method="direct2") - assert t.method == "direct1" - - g = Grid(np.array([[0, 0, 0], [1, 1, 0], [1, 1, 1]])) - b = ZernikePolynomial(L=2, M=2) - with pytest.warns(UserWarning, match="same number of nodes on each zeta plane"): - t = Transform(g, b, method="direct2") - assert t.method == "direct1" - g = Grid(np.array([[0, 0, -1], [1, 1, 0], [1, 1, 1]])) b = ZernikePolynomial(L=2, M=2) - with pytest.warns(UserWarning, match="node pattern is the same"): + with pytest.warns(UserWarning, match="requires compatible grid"): t = Transform(g, b, method="direct2") assert t.method == "direct1" - g = LinearGrid(rho=2, theta=2, zeta=2) - b = DoubleFourierSeries(M=2, N=2) - b._modes = b.modes[np.random.permutation(25)] - with pytest.warns( - UserWarning, match="zernike indices to be sorted by toroidal mode" - ): + b._fft_toroidal = False + g = LinearGrid(2, 3, 4) + with pytest.warns(UserWarning, match="compatible basis"): t = Transform(g, b, method="direct2") assert t.method == "direct1" @@ -611,6 +557,32 @@ def test_Z_projection(self): ) _ = tr["Z"].project(f) + @pytest.mark.unit + def test_fft_even_grid(self): + """Test fft method with even number of grid points.""" + for sym in ["cos", "sin", False]: + basis = FourierZernikeBasis(2, 2, 4, sym=sym) + c = np.random.random(basis.num_modes) + for N in range(9, 16): + grid = LinearGrid(L=2, M=2, zeta=N) + t1 = Transform(grid, basis, method="direct1", build_pinv=True) + t2 = Transform(grid, basis, method="fft", build_pinv=True) + x1 = t1.transform(c) + x2 = t2.transform(c) + np.testing.assert_allclose( + x1, x2, atol=1e-10, err_msg=f"N={N} sym={sym}" + ) + c1 = t1.fit(x1) + c2 = t2.fit(x2) + np.testing.assert_allclose( + c1, c2, atol=1e-10, err_msg=f"N={N} sym={sym}" + ) + y1 = t1.project(x1) + y2 = t2.project(x2) + np.testing.assert_allclose( + y1, y2, atol=1e-10, err_msg=f"N={N} sym={sym}" + ) + @pytest.mark.unit def test_transform_pytree(): @@ -643,10 +615,10 @@ def bar(x): def test_NFP_warning(): """Make sure we only warn about basis/grid NFP in cases where it matters.""" rho = np.linspace(0, 1, 20) - g01 = LinearGrid(rho=rho, L=5, N=0, NFP=1) - g02 = LinearGrid(rho=rho, L=5, N=0, NFP=2) - g21 = LinearGrid(rho=rho, L=5, N=5, NFP=1) - g22 = LinearGrid(rho=rho, L=5, N=5, NFP=2) + g01 = LinearGrid(rho=rho, M=5, N=0, NFP=1) + g02 = LinearGrid(rho=rho, M=5, N=0, NFP=2) + g21 = LinearGrid(rho=rho, M=5, N=5, NFP=1) + g22 = LinearGrid(rho=rho, M=5, N=5, NFP=2) b01 = FourierZernikeBasis(L=2, M=2, N=0, NFP=1) b02 = FourierZernikeBasis(L=2, M=2, N=0, NFP=2) b21 = FourierZernikeBasis(L=2, M=2, N=2, NFP=1)