Skip to content

Commit

Permalink
Merge branch 'master' into yge/less-factorize
Browse files Browse the repository at this point in the history
  • Loading branch information
unalmis authored Feb 17, 2025
2 parents 10d84ab + 5492090 commit 86543e7
Show file tree
Hide file tree
Showing 14 changed files with 524 additions and 64 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ New Features
for compatibility with other codes which expect such files from the Booz_Xform code.
- Renames compute quantity ``sqrt(g)_B`` to ``sqrt(g)_Boozer_DESC`` to more accurately reflect what the quantiy is (the jacobian from (rho,theta_B,zeta_B) to (rho,theta,zeta)), and adds a new function to compute ``sqrt(g)_Boozer`` which is the jacobian from (rho,theta_B,zeta_B) to (R,phi,Z).
- Allows specification of Nyquist spectrum maximum modenumbers when using ``VMECIO.save`` to save a DESC .h5 file as a VMEC-format wout file
- Adds a new objective ``desc.objectives.ExternalObjective`` for wrapping external codes with finite differences.
- DESC/JAX version and device info is no longer printed by default, but can be accessed with the function `desc.backend.print_backend_info()`.

Speed Improvements

Expand Down
2 changes: 2 additions & 0 deletions desc/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@ def main(cl_args=sys.argv[1:]):

import matplotlib.pyplot as plt

from desc.backend import print_backend_info
from desc.equilibrium import EquilibriaFamily, Equilibrium
from desc.plotting import plot_section, plot_surfaces

if ir.args.verbose:
print_backend_info()
print("Reading input from {}".format(ir.input_path))
print("Outputs will be written to {}".format(ir.output_path))

Expand Down
68 changes: 46 additions & 22 deletions desc/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import warnings

import numpy as np
from packaging.version import Version
from termcolor import colored

import desc
Expand All @@ -15,11 +16,6 @@
jnp = np
use_jax = False
set_device(kind="cpu")
print(
"DESC version {}, using numpy backend, version={}, dtype={}".format(
desc.__version__, np.__version__, np.linspace(0, 1).dtype
)
)
else:
if desc_config.get("device") is None:
set_device("cpu")
Expand All @@ -41,29 +37,31 @@
x = jnp.linspace(0, 5)
y = jnp.exp(x)
use_jax = True
print(
f"DESC version {desc.__version__},"
+ f"using JAX backend, jax version={jax.__version__}, "
+ f"jaxlib version={jaxlib.__version__}, dtype={y.dtype}"
)
del x, y
except ModuleNotFoundError:
jnp = np
x = jnp.linspace(0, 5)
y = jnp.exp(x)
use_jax = False
set_device(kind="cpu")
warnings.warn(colored("Failed to load JAX", "red"))


def print_backend_info():
"""Prints DESC version, backend type & version, device type & memory."""
print(f"DESC version={desc.__version__}.")
if use_jax:
print(
"DESC version {}, using NumPy backend, version={}, dtype={}".format(
desc.__version__, np.__version__, y.dtype
)
f"Using JAX backend: jax version={jax.__version__}, "
+ f"jaxlib version={jaxlib.__version__}, dtype={y.dtype}."
)
else:
print(f"Using NumPy backend: version={np.__version__}, dtype={y.dtype}.")
print(
"Using device: {}, with {:.2f} GB available memory.".format(
desc_config.get("device"), desc_config.get("avail_mem")
)
print(
"Using device: {}, with {:.2f} GB available memory".format(
desc_config.get("device"), desc_config.get("avail_mem")
)
)


if use_jax: # noqa: C901
from jax import custom_jvp, jit, vmap
Expand All @@ -85,13 +83,35 @@
treedef_is_leaf,
)

# TODO: update this when JAX min version >= 0.4.26
if hasattr(jnp, "trapezoid"):
trapezoid = jnp.trapezoid # for JAX 0.4.26 and later
elif hasattr(jax.scipy, "integrate"):
trapezoid = jax.scipy.integrate.trapezoid
else:
trapezoid = jnp.trapz # for older versions of JAX, deprecated by jax 0.4.16

# TODO: update this when JAX min version >= 0.4.35
if Version(jax.__version__) >= Version("0.4.35"):

def pure_callback(func, result_shape_dtype, *args, vectorized=False, **kwargs):
"""Wrapper for jax.pure_callback for versions >=0.4.35."""
return jax.pure_callback(
func,
result_shape_dtype,
*args,
vmap_method="expand_dims" if vectorized else "sequential",
**kwargs,
)

else:

def pure_callback(func, result_shape_dtype, *args, vectorized=False, **kwargs):
"""Wrapper for jax.pure_callback for versions <0.4.35."""
return jax.pure_callback(
func, result_shape_dtype, *args, vectorized=vectorized, **kwargs
)

def execute_on_cpu(func):
"""Decorator to set default device to CPU for a function.
Expand Down Expand Up @@ -481,6 +501,10 @@ def vmap(fun, in_axes=0, out_axes=0):
"""
return lambda xs: _map(fun, xs, in_axes=in_axes, out_axes=out_axes)

def pure_callback(*args, **kwargs):
"""IO callback for numpy backend."""
raise NotImplementedError

def tree_stack(*args, **kwargs):
"""Stack pytree for numpy backend."""
raise NotImplementedError
Expand Down Expand Up @@ -586,7 +610,7 @@ def fori_loop(lower, upper, body_fun, init_val):
val = body_fun(i, val)
return val

def cond(pred, true_fun, false_fun, *operand):
def cond(pred, true_fun, false_fun, *operands):
"""Conditionally apply true_fun or false_fun.
This version is for the numpy backend, for jax backend see jax.lax.cond
Expand All @@ -599,7 +623,7 @@ def cond(pred, true_fun, false_fun, *operand):
Function (A -> B), to be applied if pred is True.
false_fun: callable
Function (A -> B), to be applied if pred is False.
operand: any
operands: any
input to either branch depending on pred. The type can be a scalar, array,
or any pytree (nested Python tuple/list/dict) thereof.
Expand All @@ -612,9 +636,9 @@ def cond(pred, true_fun, false_fun, *operand):
"""
if pred:
return true_fun(*operand)
return true_fun(*operands)
else:
return false_fun(*operand)
return false_fun(*operands)

def switch(index, branches, operand):
"""Apply exactly one of branches given by index.
Expand Down
2 changes: 1 addition & 1 deletion desc/integrals/surface_integral.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def surface_integrals_map(grid, surface_label="rho", expand_out=True, tol=1e-14)
has_endpoint_dupe,
lambda _: put(mask, jnp.array([0, -1]), mask[0] | mask[-1]),
lambda _: mask,
operand=None,
None,
)
else:
# If we don't have the idx attributes, we are forced to expand out.
Expand Down
7 changes: 6 additions & 1 deletion desc/objectives/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@
)
from ._fast_ion import GammaC
from ._free_boundary import BoundaryError, VacuumBoundaryError
from ._generic import GenericObjective, LinearObjectiveFromUser, ObjectiveFromUser
from ._generic import (
ExternalObjective,
GenericObjective,
LinearObjectiveFromUser,
ObjectiveFromUser,
)
from ._geometry import (
AspectRatio,
BScaleLength,
Expand Down
Loading

0 comments on commit 86543e7

Please sign in to comment.