Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix .change_resolution(sym=False) for FourierRZToroidalSurface, ZernikeRZToroidalSection and FourierRZCurve #1593

Merged
merged 10 commits into from
Feb 21, 2025
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ Bug Fixes
- Fixes the coil currents in ``desc.coils.initialize_modular_coils`` to now give the correct expected linking current.
- ``desc.objectives.PlasmaVesselDistance`` now correctly accounts for multiple field periods on both the equilibrium and the vessel surface. Previously it only considered distances within a single field period.
- Sets ``os.environ["JAX_PLATFORMS"] = "cpu"`` instead of ``os.environ["JAX_PLATFORM_NAME"] = "cpu"`` when doing ``set_device("cpu")``.

- Fixes bug when passing only `sym` into `.change_resolution` for ``FourierRZToroidalSurface``, ``ZernikeRZToroidalSection`` and ``FourierRZCurve``.

Performance Improvements

Expand Down
20 changes: 15 additions & 5 deletions desc/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@
"""
NFP = check_posint(NFP, "NFP")
self._NFP = NFP if NFP is not None else self.NFP
if N != self.N:
if N != self.N or (sym is not None and sym != self.sym):
self._N = check_nonnegint(N, "N", False)
self._sym = sym if sym is not None else self.sym
self._modes = self._get_modes(self.N)
Expand Down Expand Up @@ -690,7 +690,7 @@
"""
NFP = check_posint(NFP, "NFP")
self._NFP = NFP if NFP is not None else self.NFP
if M != self.M or N != self.N or sym != self.sym:
if M != self.M or N != self.N or (sym is not None and sym != self.sym):
self._M = check_nonnegint(M, "M", False)
self._N = check_nonnegint(N, "N", False)
self._sym = sym if sym is not None else self.sym
Expand Down Expand Up @@ -898,7 +898,7 @@
None

"""
if L != self.L or M != self.M or sym != self.sym:
if L != self.L or M != self.M or (sym is not None and sym != self.sym):

Check warning on line 901 in desc/basis.py

View check run for this annotation

Codecov / codecov/patch

desc/basis.py#L901

Added line #L901 was not covered by tests
self._L = check_nonnegint(L, "L", False)
self._M = check_nonnegint(M, "M", False)
self._sym = sym if sym is not None else self.sym
Expand Down Expand Up @@ -1073,7 +1073,12 @@
"""
NFP = check_posint(NFP, "NFP")
self._NFP = NFP if NFP is not None else self.NFP
if L != self.L or M != self.M or N != self.N or sym != self.sym:
if (

Check warning on line 1076 in desc/basis.py

View check run for this annotation

Codecov / codecov/patch

desc/basis.py#L1076

Added line #L1076 was not covered by tests
L != self.L
or M != self.M
or N != self.N
or (sym is not None and sym != self.sym)
):
self._L = check_nonnegint(L, "L", False)
self._M = check_nonnegint(M, "M", False)
self._N = check_nonnegint(N, "N", False)
Expand Down Expand Up @@ -1311,7 +1316,12 @@
"""
NFP = check_posint(NFP, "NFP")
self._NFP = NFP if NFP is not None else self.NFP
if L != self.L or M != self.M or N != self.N or sym != self.sym:
if (
L != self.L
or M != self.M
or N != self.N
or (sym is not None and sym != self.sym)
):
self._L = check_nonnegint(L, "L", False)
self._M = check_nonnegint(M, "M", False)
self._N = check_nonnegint(N, "N", False)
Expand Down
4 changes: 2 additions & 2 deletions desc/equilibrium/equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,8 @@ def __init__(
self._M_grid = setdefault(M_grid, 2 * self.M)
self._N_grid = setdefault(N_grid, 2 * self.N)

self._surface.change_resolution(self.L, self.M, self.N)
self._axis.change_resolution(self.N)
self._surface.change_resolution(self.L, self.M, self.N, sym=self.sym)
self._axis.change_resolution(self.N, sym=self.sym)

# bases
self._R_basis = FourierZernikeBasis(
Expand Down
3 changes: 1 addition & 2 deletions desc/geometry/curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,7 @@ def change_resolution(self, N=None, NFP=None, sym=None):
if (
((N is not None) and (N != self.N))
or ((NFP is not None) and (NFP != self.NFP))
or (sym is not None)
and (sym != self.sym)
or ((sym is not None) and (sym != self.sym))
):
self._NFP = int(NFP if NFP is not None else self.NFP)
self._sym = bool(sym) if sym is not None else self.sym
Expand Down
11 changes: 8 additions & 3 deletions desc/geometry/surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,13 +205,14 @@
N = check_nonnegint(N, "N")
NFP = check_posint(NFP, "NFP")
self._NFP = int(NFP if NFP is not None else self.NFP)
self._sym = sym if sym is not None else self.sym

if (
((N is not None) and (N != self.N))
or ((M is not None) and (M != self.M))
or (NFP is not None)
or ((sym is not None) and (sym != self.sym))
):
self._sym = sym if sym is not None else self.sym
M = int(M if M is not None else self.M)
N = int(N if N is not None else self.N)
R_modes_old = self.R_basis.modes
Expand Down Expand Up @@ -975,9 +976,13 @@

L = check_nonnegint(L, "L")
M = check_nonnegint(M, "M")
self._sym = sym if sym is not None else self.sym

if ((L is not None) and (L != self.L)) or ((M is not None) and (M != self.M)):
if (
((L is not None) and (L != self.L))
or ((M is not None) and (M != self.M))
or ((sym is not None) and (sym != self.sym))
):
self._sym = sym if sym is not None else self.sym

Check warning on line 985 in desc/geometry/surface.py

View check run for this annotation

Codecov / codecov/patch

desc/geometry/surface.py#L985

Added line #L985 was not covered by tests
L = int(L if L is not None else self.L)
M = int(M if M is not None else self.M)
R_modes_old = self.R_basis.modes
Expand Down
14 changes: 14 additions & 0 deletions tests/test_curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,20 @@ def test_to_FourierRZCurve(self):
with pytest.raises(ValueError):
xyz.to_FourierRZ(N=1, grid=grid)

@pytest.mark.unit
def test_change_symmetry(self):
"""Test correct sym changes when only sym is passed to change_resolution."""
c = FourierRZCurve(sym=False)
c.change_resolution(sym=True)
assert c.sym
assert c.R_basis.sym == "cos"
assert c.Z_basis.sym == "sin"

c.change_resolution(sym=False)
assert c.sym is False
assert c.R_basis.sym is False
assert c.Z_basis.sym is False


class TestFourierXYZCurve:
"""Tests for FourierXYZCurve class."""
Expand Down
58 changes: 51 additions & 7 deletions tests/test_equilibrium.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""Tests for Equilibrium class."""

import os
import pickle
import warnings

import numpy as np
import pytest
from qic import Qic

from desc.__main__ import main
from desc.backend import sign
Expand Down Expand Up @@ -263,6 +263,9 @@ def test_eq_change_symmetry():
assert eq.surface.sym
assert eq.surface.R_basis.sym == "cos"
assert eq.surface.Z_basis.sym == "sin"
assert eq.axis.sym
assert eq.axis.R_basis.sym == "cos"
assert eq.axis.Z_basis.sym == "sin"

# undo symmetry
eq.change_resolution(sym=False)
Expand All @@ -276,6 +279,9 @@ def test_eq_change_symmetry():
assert not eq.surface.sym
assert not eq.surface.R_basis.sym
assert not eq.surface.Z_basis.sym
assert eq.axis.sym is False
assert eq.axis.R_basis.sym is False
assert eq.axis.Z_basis.sym is False


@pytest.mark.unit
Expand All @@ -292,20 +298,58 @@ def test_resolution():
@pytest.mark.unit
def test_equilibrium_from_near_axis():
"""Test loading a solution from pyQSC/pyQIC."""
qsc_path = "./tests/inputs/qsc_r2section5.5.pkl"
file = open(qsc_path, "rb")
na = pickle.load(file)
file.close()
na = Qic.from_paper("r2 section 5.5", rs=[0, 1e-5], zc=[0, 1e-5])

r = 1e-2
eq = Equilibrium.from_near_axis(na, r=r, M=8, N=8)
grid = LinearGrid(M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP, sym=eq.sym)
data = eq.compute("|B|", grid=grid)

# get the sin/cos modes
eq_rc = eq.Ra_n[
np.where(
np.logical_and(
eq.axis.R_basis.modes[:, 2] >= 0,
eq.axis.R_basis.modes[:, 2] < na.nfourier,
)
)
]
eq_zc = eq.Za_n[
np.where(
np.logical_and(
eq.axis.Z_basis.modes[:, 2] >= 0,
eq.axis.Z_basis.modes[:, 2] < na.nfourier,
)
)
]
eq_rs = np.flipud(
eq.Ra_n[
np.where(
np.logical_and(
eq.axis.R_basis.modes[:, 2] < 0,
eq.axis.R_basis.modes[:, 2] > -na.nfourier,
)
)
]
)
eq_zs = np.flipud(
eq.Za_n[
np.where(
np.logical_and(
eq.axis.Z_basis.modes[:, 2] < 0,
eq.axis.Z_basis.modes[:, 2] > -na.nfourier,
)
)
]
)

assert eq.is_nested()
assert eq.NFP == na.nfp
np.testing.assert_allclose(eq.Ra_n[:2], na.rc, atol=1e-10)
np.testing.assert_allclose(eq.Za_n[-2:], na.zs, atol=1e-10)
np.testing.assert_allclose(eq_rc, na.rc, atol=1e-10)
# na.zs[0] is always 0, which DESC doesn't include
np.testing.assert_allclose(eq_zs, na.zs[1:], atol=1e-10)
np.testing.assert_allclose(eq_rs, na.rs[1:], atol=1e-10)
np.testing.assert_allclose(eq_zc, na.zc, atol=1e-10)
np.testing.assert_allclose(data["|B|"][0], na.B_mag(r, 0, 0), rtol=2e-2)


Expand Down
30 changes: 30 additions & 0 deletions tests/test_surfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,3 +514,33 @@ def test_surface_orientation():
assert surf._compute_orientation() == -1
eq = Equilibrium(M=surf.M, N=surf.N, surface=surf, check_orientation=False)
assert np.sign(eq.compute("sqrt(g)")["sqrt(g)"].mean()) == -1


@pytest.mark.unit
def test_surface_change_only_symmetry():
"""Test that sym correctly changes when only sym is passed to change_resolution."""
surf = FourierRZToroidalSurface(sym=False)
surf.change_resolution(sym=True)
assert surf.sym
assert surf.R_basis.sym == "cos"
assert surf.Z_basis.sym == "sin"

surf.change_resolution(sym=False)
assert surf.sym is False
assert surf.R_basis.sym is False
assert surf.Z_basis.sym is False


@pytest.mark.unit
def test_section_change_only_symmetry():
"""Test that sym correctly changes when only sym is passed to change_resolution."""
surf = ZernikeRZToroidalSection(sym=False)
surf.change_resolution(sym=True)
assert surf.sym
assert surf.R_basis.sym == "cos"
assert surf.Z_basis.sym == "sin"

surf.change_resolution(sym=False)
assert surf.sym is False
assert surf.R_basis.sym is False
assert surf.Z_basis.sym is False
Loading