diff --git a/CHANGELOG.md b/CHANGELOG.md index 8cecccc2e8..79501a0f8e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -47,7 +47,7 @@ Bug Fixes - Fixes the coil currents in ``desc.coils.initialize_modular_coils`` to now give the correct expected linking current. - ``desc.objectives.PlasmaVesselDistance`` now correctly accounts for multiple field periods on both the equilibrium and the vessel surface. Previously it only considered distances within a single field period. - Sets ``os.environ["JAX_PLATFORMS"] = "cpu"`` instead of ``os.environ["JAX_PLATFORM_NAME"] = "cpu"`` when doing ``set_device("cpu")``. - +- Fixes bug when passing only `sym` into `.change_resolution` for ``FourierRZToroidalSurface``, ``ZernikeRZToroidalSection`` and ``FourierRZCurve``. Performance Improvements diff --git a/desc/basis.py b/desc/basis.py index eb882c39bd..65f203dba3 100644 --- a/desc/basis.py +++ b/desc/basis.py @@ -539,7 +539,7 @@ def change_resolution(self, N, NFP=None, sym=None): """ NFP = check_posint(NFP, "NFP") self._NFP = NFP if NFP is not None else self.NFP - if N != self.N: + if N != self.N or (sym is not None and sym != self.sym): self._N = check_nonnegint(N, "N", False) self._sym = sym if sym is not None else self.sym self._modes = self._get_modes(self.N) @@ -690,7 +690,7 @@ def change_resolution(self, M, N, NFP=None, sym=None): """ NFP = check_posint(NFP, "NFP") self._NFP = NFP if NFP is not None else self.NFP - if M != self.M or N != self.N or sym != self.sym: + if M != self.M or N != self.N or (sym is not None and sym != self.sym): self._M = check_nonnegint(M, "M", False) self._N = check_nonnegint(N, "N", False) self._sym = sym if sym is not None else self.sym @@ -898,7 +898,7 @@ def change_resolution(self, L, M, sym=None): None """ - if L != self.L or M != self.M or sym != self.sym: + if L != self.L or M != self.M or (sym is not None and sym != self.sym): self._L = check_nonnegint(L, "L", False) self._M = check_nonnegint(M, "M", False) self._sym = sym if sym is not None else self.sym @@ -1073,7 +1073,12 @@ def change_resolution(self, L, M, N, NFP=None, sym=None): """ NFP = check_posint(NFP, "NFP") self._NFP = NFP if NFP is not None else self.NFP - if L != self.L or M != self.M or N != self.N or sym != self.sym: + if ( + L != self.L + or M != self.M + or N != self.N + or (sym is not None and sym != self.sym) + ): self._L = check_nonnegint(L, "L", False) self._M = check_nonnegint(M, "M", False) self._N = check_nonnegint(N, "N", False) @@ -1311,7 +1316,12 @@ def change_resolution(self, L, M, N, NFP=None, sym=None): """ NFP = check_posint(NFP, "NFP") self._NFP = NFP if NFP is not None else self.NFP - if L != self.L or M != self.M or N != self.N or sym != self.sym: + if ( + L != self.L + or M != self.M + or N != self.N + or (sym is not None and sym != self.sym) + ): self._L = check_nonnegint(L, "L", False) self._M = check_nonnegint(M, "M", False) self._N = check_nonnegint(N, "N", False) diff --git a/desc/equilibrium/equilibrium.py b/desc/equilibrium/equilibrium.py index 84a3b6dfce..37c8cb520f 100644 --- a/desc/equilibrium/equilibrium.py +++ b/desc/equilibrium/equilibrium.py @@ -264,8 +264,8 @@ def __init__( self._M_grid = setdefault(M_grid, 2 * self.M) self._N_grid = setdefault(N_grid, 2 * self.N) - self._surface.change_resolution(self.L, self.M, self.N) - self._axis.change_resolution(self.N) + self._surface.change_resolution(self.L, self.M, self.N, sym=self.sym) + self._axis.change_resolution(self.N, sym=self.sym) # bases self._R_basis = FourierZernikeBasis( diff --git a/desc/geometry/curve.py b/desc/geometry/curve.py index 2e141710fb..077fbea5ca 100644 --- a/desc/geometry/curve.py +++ b/desc/geometry/curve.py @@ -127,8 +127,7 @@ def change_resolution(self, N=None, NFP=None, sym=None): if ( ((N is not None) and (N != self.N)) or ((NFP is not None) and (NFP != self.NFP)) - or (sym is not None) - and (sym != self.sym) + or ((sym is not None) and (sym != self.sym)) ): self._NFP = int(NFP if NFP is not None else self.NFP) self._sym = bool(sym) if sym is not None else self.sym diff --git a/desc/geometry/surface.py b/desc/geometry/surface.py index f6350ea210..d0aca17079 100644 --- a/desc/geometry/surface.py +++ b/desc/geometry/surface.py @@ -205,13 +205,14 @@ def change_resolution(self, *args, **kwargs): N = check_nonnegint(N, "N") NFP = check_posint(NFP, "NFP") self._NFP = int(NFP if NFP is not None else self.NFP) - self._sym = sym if sym is not None else self.sym if ( ((N is not None) and (N != self.N)) or ((M is not None) and (M != self.M)) or (NFP is not None) + or ((sym is not None) and (sym != self.sym)) ): + self._sym = sym if sym is not None else self.sym M = int(M if M is not None else self.M) N = int(N if N is not None else self.N) R_modes_old = self.R_basis.modes @@ -975,9 +976,13 @@ def change_resolution(self, *args, **kwargs): L = check_nonnegint(L, "L") M = check_nonnegint(M, "M") - self._sym = sym if sym is not None else self.sym - if ((L is not None) and (L != self.L)) or ((M is not None) and (M != self.M)): + if ( + ((L is not None) and (L != self.L)) + or ((M is not None) and (M != self.M)) + or ((sym is not None) and (sym != self.sym)) + ): + self._sym = sym if sym is not None else self.sym L = int(L if L is not None else self.L) M = int(M if M is not None else self.M) R_modes_old = self.R_basis.modes diff --git a/tests/test_curves.py b/tests/test_curves.py index 0b1af1d63e..8b617600c8 100644 --- a/tests/test_curves.py +++ b/tests/test_curves.py @@ -327,6 +327,20 @@ def test_to_FourierRZCurve(self): with pytest.raises(ValueError): xyz.to_FourierRZ(N=1, grid=grid) + @pytest.mark.unit + def test_change_symmetry(self): + """Test correct sym changes when only sym is passed to change_resolution.""" + c = FourierRZCurve(sym=False) + c.change_resolution(sym=True) + assert c.sym + assert c.R_basis.sym == "cos" + assert c.Z_basis.sym == "sin" + + c.change_resolution(sym=False) + assert c.sym is False + assert c.R_basis.sym is False + assert c.Z_basis.sym is False + class TestFourierXYZCurve: """Tests for FourierXYZCurve class.""" diff --git a/tests/test_equilibrium.py b/tests/test_equilibrium.py index 14e9cea903..456884d8b9 100644 --- a/tests/test_equilibrium.py +++ b/tests/test_equilibrium.py @@ -1,11 +1,11 @@ """Tests for Equilibrium class.""" import os -import pickle import warnings import numpy as np import pytest +from qic import Qic from desc.__main__ import main from desc.backend import sign @@ -263,6 +263,9 @@ def test_eq_change_symmetry(): assert eq.surface.sym assert eq.surface.R_basis.sym == "cos" assert eq.surface.Z_basis.sym == "sin" + assert eq.axis.sym + assert eq.axis.R_basis.sym == "cos" + assert eq.axis.Z_basis.sym == "sin" # undo symmetry eq.change_resolution(sym=False) @@ -276,6 +279,9 @@ def test_eq_change_symmetry(): assert not eq.surface.sym assert not eq.surface.R_basis.sym assert not eq.surface.Z_basis.sym + assert eq.axis.sym is False + assert eq.axis.R_basis.sym is False + assert eq.axis.Z_basis.sym is False @pytest.mark.unit @@ -292,20 +298,58 @@ def test_resolution(): @pytest.mark.unit def test_equilibrium_from_near_axis(): """Test loading a solution from pyQSC/pyQIC.""" - qsc_path = "./tests/inputs/qsc_r2section5.5.pkl" - file = open(qsc_path, "rb") - na = pickle.load(file) - file.close() + na = Qic.from_paper("r2 section 5.5", rs=[0, 1e-5], zc=[0, 1e-5]) r = 1e-2 eq = Equilibrium.from_near_axis(na, r=r, M=8, N=8) grid = LinearGrid(M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP, sym=eq.sym) data = eq.compute("|B|", grid=grid) + # get the sin/cos modes + eq_rc = eq.Ra_n[ + np.where( + np.logical_and( + eq.axis.R_basis.modes[:, 2] >= 0, + eq.axis.R_basis.modes[:, 2] < na.nfourier, + ) + ) + ] + eq_zc = eq.Za_n[ + np.where( + np.logical_and( + eq.axis.Z_basis.modes[:, 2] >= 0, + eq.axis.Z_basis.modes[:, 2] < na.nfourier, + ) + ) + ] + eq_rs = np.flipud( + eq.Ra_n[ + np.where( + np.logical_and( + eq.axis.R_basis.modes[:, 2] < 0, + eq.axis.R_basis.modes[:, 2] > -na.nfourier, + ) + ) + ] + ) + eq_zs = np.flipud( + eq.Za_n[ + np.where( + np.logical_and( + eq.axis.Z_basis.modes[:, 2] < 0, + eq.axis.Z_basis.modes[:, 2] > -na.nfourier, + ) + ) + ] + ) + assert eq.is_nested() assert eq.NFP == na.nfp - np.testing.assert_allclose(eq.Ra_n[:2], na.rc, atol=1e-10) - np.testing.assert_allclose(eq.Za_n[-2:], na.zs, atol=1e-10) + np.testing.assert_allclose(eq_rc, na.rc, atol=1e-10) + # na.zs[0] is always 0, which DESC doesn't include + np.testing.assert_allclose(eq_zs, na.zs[1:], atol=1e-10) + np.testing.assert_allclose(eq_rs, na.rs[1:], atol=1e-10) + np.testing.assert_allclose(eq_zc, na.zc, atol=1e-10) np.testing.assert_allclose(data["|B|"][0], na.B_mag(r, 0, 0), rtol=2e-2) diff --git a/tests/test_surfaces.py b/tests/test_surfaces.py index c6d3a23480..aa47500728 100644 --- a/tests/test_surfaces.py +++ b/tests/test_surfaces.py @@ -514,3 +514,33 @@ def test_surface_orientation(): assert surf._compute_orientation() == -1 eq = Equilibrium(M=surf.M, N=surf.N, surface=surf, check_orientation=False) assert np.sign(eq.compute("sqrt(g)")["sqrt(g)"].mean()) == -1 + + +@pytest.mark.unit +def test_surface_change_only_symmetry(): + """Test that sym correctly changes when only sym is passed to change_resolution.""" + surf = FourierRZToroidalSurface(sym=False) + surf.change_resolution(sym=True) + assert surf.sym + assert surf.R_basis.sym == "cos" + assert surf.Z_basis.sym == "sin" + + surf.change_resolution(sym=False) + assert surf.sym is False + assert surf.R_basis.sym is False + assert surf.Z_basis.sym is False + + +@pytest.mark.unit +def test_section_change_only_symmetry(): + """Test that sym correctly changes when only sym is passed to change_resolution.""" + surf = ZernikeRZToroidalSection(sym=False) + surf.change_resolution(sym=True) + assert surf.sym + assert surf.R_basis.sym == "cos" + assert surf.Z_basis.sym == "sin" + + surf.change_resolution(sym=False) + assert surf.sym is False + assert surf.R_basis.sym is False + assert surf.Z_basis.sym is False