Skip to content

Commit

Permalink
Add more difficulty-based steppers (#2)
Browse files Browse the repository at this point in the history
* Add polynomial difficulty stepper

* Implement difficulty-based general nonlinear stepper

* Add simple streamlit notebook to understand difficulties
  • Loading branch information
Ceyron authored May 6, 2024
1 parent 4bcd5b2 commit 9ab18cc
Show file tree
Hide file tree
Showing 5 changed files with 326 additions and 3 deletions.
154 changes: 154 additions & 0 deletions examples/understanding_normalized_and_difficulty.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
"""
This is a streamlit app.
"""
import jax
import matplotlib.pyplot as plt
import streamlit as st

import exponax as ex

jax.config.update("jax_platform_name", "cpu")

with st.sidebar:
num_points = st.slider("Number of points", 16, 256, 48)
num_steps = st.slider("Number of steps", 1, 300, 50)
num_modes_init = st.slider("Number of modes in the initial condition", 1, 40, 5)
num_substeps = st.slider("Number of substeps", 1, 100, 1)

use_difficulty = st.toggle("Use difficulty", value=True)

overall_scale = st.slider("Overall scale", 0.1, 50.0, 1.0)

a_0_cols = st.columns(3)
with a_0_cols[0]:
a_0_mantissa = st.slider("a_0 mantissa", 0.0, 10.0, 0.0)
with a_0_cols[1]:
a_0_exponent = st.slider("a_0 exponent", -5, 5, 0)
with a_0_cols[2]:
a_0_sign = st.select_slider("a_0 sign", options=["-", "+"])
a_0 = float(f"{a_0_sign}{a_0_mantissa}e{a_0_exponent}")

a_1_cols = st.columns(3)
with a_1_cols[0]:
a_1_mantissa = st.slider("a_1 mantissa", 0.0, 10.0, 0.1)
with a_1_cols[1]:
a_1_exponent = st.slider("a_1 exponent", -5, 5, 0)
with a_1_cols[2]:
a_1_sign = st.select_slider("a_1 sign", options=["-", "+"])
a_1 = float(f"{a_1_sign}{a_1_mantissa}e{a_1_exponent}")

a_2_cols = st.columns(3)
with a_2_cols[0]:
a_2_mantissa = st.slider("a_2 mantissa", 0.0, 10.0, 0.0)
with a_2_cols[1]:
a_2_exponent = st.slider("a_2 exponent", -5, 5, 0)
with a_2_cols[2]:
a_2_sign = st.select_slider("a_2 sign", options=["-", "+"])
a_2 = float(f"{a_2_sign}{a_2_mantissa}e{a_2_exponent}")

a_3_cols = st.columns(3)
with a_3_cols[0]:
a_3_mantissa = st.slider("a_3 mantissa", 0.0, 10.0, 0.0)
with a_3_cols[1]:
a_3_exponent = st.slider("a_3 exponent", -5, 5, 0)
with a_3_cols[2]:
a_3_sign = st.select_slider("a_3 sign", options=["-", "+"])
a_3 = float(f"{a_3_sign}{a_3_mantissa}e{a_3_exponent}")

a_4_cols = st.columns(3)
with a_4_cols[0]:
a_4_mantissa = st.slider("a_4 mantissa", 0.0, 10.0, 0.0)
with a_4_cols[1]:
a_4_exponent = st.slider("a_4 exponent", -5, 5, 0)
with a_4_cols[2]:
a_4_sign = st.select_slider("a_4 sign", options=["-", "+"])
a_4 = float(f"{a_4_sign}{a_4_mantissa}e{a_4_exponent}")

b_0_cols = st.columns(3)
with b_0_cols[0]:
b_0_mantissa = st.slider("b_0 mantissa", 0.0, 10.0, 0.0)
with b_0_cols[1]:
b_0_exponent = st.slider("b_0 exponent", -5, 5, 0)
with b_0_cols[2]:
b_0_sign = st.select_slider("b_0 sign", options=["-", "+"])
b_0 = float(f"{b_0_sign}{b_0_mantissa}e{b_0_exponent}")

b_1_cols = st.columns(3)
with b_1_cols[0]:
b_1_mantissa = st.slider("b_1 mantissa", 0.0, 10.0, 0.0)
with b_1_cols[1]:
b_1_exponent = st.slider("b_1 exponent", -5, 5, 0)
with b_1_cols[2]:
b_1_sign = st.select_slider("b_1 sign", options=["-", "+"])
b_1 = float(f"{b_1_sign}{b_1_mantissa}e{b_1_exponent}")

b_2_cols = st.columns(3)
with b_2_cols[0]:
b_2_mantissa = st.slider("b_2 mantissa", 0.0, 10.0, 0.0)
with b_2_cols[1]:
b_2_exponent = st.slider("b_2 exponent", -5, 5, 0)
with b_2_cols[2]:
b_2_sign = st.select_slider("b_2 sign", options=["-", "+"])
b_2 = float(f"{b_2_sign}{b_2_mantissa}e{b_2_exponent}")

# a_0 = st.slider("a_0", -10.0, 10.0, 0.0)
# a_1 = st.slider("a_1", -10.0, 10.0, 0.1)
# a_2 = st.slider("a_2", -10.0, 10.0, 0.0)
# a_3 = st.slider("a_3", -10.0, 10.0, 0.0)
# a_4 = st.slider("a_4", -10.0, 10.0, 0.0)
# b_0 = st.slider("b_0", -10.0, 10.0, 0.0)
# b_1 = st.slider("b_1", -10.0, 10.0, 0.0)
# b_2 = st.slider("b_2", -10.0, 10.0, 0.0)

linear_tuple = (a_0, a_1, a_2, a_3, a_4)
nonlinear_tuple = (b_0, b_1, b_2)

linear_tuple = tuple([overall_scale * x for x in linear_tuple])
nonlinear_tuple = tuple([overall_scale * x for x in nonlinear_tuple])

if use_difficulty:
stepper = ex.RepeatedStepper(
ex.normalized.DifficultyGeneralNonlinearStepper(
1,
num_points,
linear_difficulties=tuple(x / num_substeps for x in linear_tuple),
nonlinear_difficulties=tuple(x / num_substeps for x in nonlinear_tuple),
),
num_substeps,
)
else:
stepper = ex.RepeatedStepper(
ex.normalized.NormlizedGeneralNonlinearStepper(
1,
num_points,
normalized_coefficients_linear=tuple(
x / num_substeps for x in linear_tuple
),
normalized_coefficients_nonlinear=tuple(
x / num_substeps for x in nonlinear_tuple
),
),
num_substeps,
)

ic_gen = ex.ic.RandomSineWaves1d(1, cutoff=num_modes_init, max_one=True)
u_0 = ic_gen(num_points, key=jax.random.PRNGKey(0))

trj = ex.rollout(stepper, num_steps, include_init=True)(u_0)

v_range = st.slider("Colorbar range", 0.1, 10.0, 1.0)

fig, ax = plt.subplots()
ax.imshow(
trj[:, 0, :].T,
aspect="auto",
vmin=-v_range,
vmax=v_range,
cmap="RdBu_r",
origin="lower",
)

st.write(f"Linear: {linear_tuple}")
st.write(f"Nonlinear: {nonlinear_tuple}")

st.pyplot(fig)
9 changes: 7 additions & 2 deletions exponax/normalized/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,17 @@
slightly differently.
"""
from ._convection import DifficultyConvectionStepper, NormalizedConvectionStepper
from ._general_nonlinear import NormlizedGeneralNonlinearStepper
from ._general_nonlinear import (
DifficultyGeneralNonlinearStepper,
NormlizedGeneralNonlinearStepper,
)
from ._gradient_norm import DifficultyGradientNormStepper, NormalizedGradientNormStepper
from ._linear import (
DifficultyLinearStepper,
DiffultyLinearStepperSimple,
NormalizedLinearStepper,
)
from ._polynomial import NormalizedPolynomialStepper
from ._polynomial import DifficultyPolynomialStepper, NormalizedPolynomialStepper
from ._utils import (
denormalize_coefficients,
denormalize_convection_scale,
Expand All @@ -39,6 +42,8 @@
"DiffultyLinearStepperSimple",
"DifficultyConvectionStepper",
"DifficultyGradientNormStepper",
"DifficultyPolynomialStepper",
"DifficultyGeneralNonlinearStepper",
"NormalizedConvectionStepper",
"NormlizedGeneralNonlinearStepper",
"NormalizedGradientNormStepper",
Expand Down
65 changes: 64 additions & 1 deletion exponax/normalized/_general_nonlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@

from .._base_stepper import BaseStepper
from ..nonlin_fun import GeneralNonlinearFun
from ._utils import (
extract_normalized_coefficients_from_difficulty,
extract_normalized_nonlinear_scales_from_difficulty,
)


class NormlizedGeneralNonlinearStepper(BaseStepper):
Expand All @@ -15,7 +19,7 @@ def __init__(
num_spatial_dims: int,
num_points: int,
*,
normalized_coefficients_linear: tuple[float, ...] = (0.0, 0.0, 0.01 * 0.1),
normalized_coefficients_linear: tuple[float, ...] = (0.0, 0.0, 0.1 * 0.1),
normalized_coefficients_nonlinear: tuple[float, ...] = (0.0, -1.0 * 0.1, 0.0),
order=2,
dealiasing_fraction: float = 2 / 3,
Expand Down Expand Up @@ -68,3 +72,62 @@ def _build_nonlinear_fun(
scale_list=self.normalized_coefficients_nonlinear,
zero_mode_fix=True, # ToDo: check this
)


class DifficultyGeneralNonlinearStepper(NormlizedGeneralNonlinearStepper):
linear_difficulties: tuple[float, ...]
nonlinear_difficulties: tuple[float, ...]

def __init__(
self,
num_spatial_dims: int = 1,
num_points: int = 48,
*,
linear_difficulties: tuple[float, ...] = (
0.0,
0.0,
0.1 * 0.1 / 1.0 * 48**2 * 2,
),
nonlinear_difficulties: tuple[float, ...] = (
0.0,
-1.0 * 0.1 / 1.0 * 48,
0.0,
),
maximum_absolute: float = 1.0,
order: int = 2,
dealiasing_fraction: float = 2 / 3,
num_circle_points: int = 16,
circle_radius: float = 1.0,
):
"""
By default Burgers.
"""
self.linear_difficulties = linear_difficulties
self.nonlinear_difficulties = nonlinear_difficulties

normalized_coefficients_linear = (
extract_normalized_coefficients_from_difficulty(
linear_difficulties,
num_spatial_dims=num_spatial_dims,
num_points=num_points,
)
)
normalized_coefficients_nonlinear = (
extract_normalized_nonlinear_scales_from_difficulty(
nonlinear_difficulties,
num_spatial_dims=num_spatial_dims,
num_points=num_points,
maximum_absolute=maximum_absolute,
)
)

super().__init__(
num_spatial_dims=num_spatial_dims,
num_points=num_points,
normalized_coefficients_linear=normalized_coefficients_linear,
normalized_coefficients_nonlinear=normalized_coefficients_nonlinear,
order=order,
dealiasing_fraction=dealiasing_fraction,
num_circle_points=num_circle_points,
circle_radius=circle_radius,
)
51 changes: 51 additions & 0 deletions exponax/normalized/_polynomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from .._base_stepper import BaseStepper
from ..nonlin_fun import PolynomialNonlinearFun
from ._utils import extract_normalized_coefficients_from_difficulty


class NormalizedPolynomialStepper(BaseStepper):
Expand Down Expand Up @@ -72,3 +73,53 @@ def _build_nonlinear_fun(
coefficients=self.normalized_polynomial_scales,
dealiasing_fraction=self.dealiasing_fraction,
)


class DifficultyPolynomialStepper(NormalizedPolynomialStepper):
linear_difficulties: tuple[float, ...]
polynomial_difficulties: tuple[float, ...]

def __init__(
self,
num_spatial_dims: int = 1,
num_points: int = 48,
*,
linear_difficulties: tuple[float, ...] = (
10.0 * 0.001 / (10.0**0) * 48**0,
0.0,
1.0 * 0.001 / (10.0**2) * 48**2 * 2**1,
),
polynomial_difficulties: tuple[float, ...] = (
0.0,
0.0,
-10.0 * 0.001,
),
order: int = 2,
dealiasing_fraction: float = 2 / 3,
num_circle_points: int = 16,
circle_radius: float = 1.0,
):
"""
By default: Fisher-KPP
"""
self.linear_difficulties = linear_difficulties
self.polynomial_difficulties = polynomial_difficulties

normalized_coefficients = extract_normalized_coefficients_from_difficulty(
linear_difficulties,
num_spatial_dims=num_spatial_dims,
num_points=num_points,
)
# For polynomial nonlinearities, we have difficulties == normalized scales
normalized_polynomial_scales = polynomial_difficulties

super().__init__(
num_spatial_dims=num_spatial_dims,
num_points=num_points,
normalized_coefficients=normalized_coefficients,
normalized_polynomial_scales=normalized_polynomial_scales,
order=order,
dealiasing_fraction=dealiasing_fraction,
num_circle_points=num_circle_points,
circle_radius=circle_radius,
)
50 changes: 50 additions & 0 deletions exponax/normalized/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,3 +218,53 @@ def extract_normalized_gradient_norm_scale_from_difficulty(
maximum_absolute * jnp.square(num_points) * num_spatial_dims
)
return normalized_gradient_norm_scale


def reduce_normalized_nonlinear_scales_to_difficulty(
normalized_nonlinear_scales: tuple[float],
*,
num_spatial_dims: int,
num_points: int,
maximum_absolute: float,
):
nonlinear_difficulties = (
normalized_nonlinear_scales[0], # Polynomial: normalized == difficulty
reduce_normalized_convection_scale_to_difficulty(
normalized_nonlinear_scales[1],
num_spatial_dims=num_spatial_dims,
num_points=num_points,
maximum_absolute=maximum_absolute,
),
reduce_normalized_gradient_norm_scale_to_difficulty(
normalized_nonlinear_scales[2],
num_spatial_dims=num_spatial_dims,
num_points=num_points,
maximum_absolute=maximum_absolute,
),
)
return nonlinear_difficulties


def extract_normalized_nonlinear_scales_from_difficulty(
nonlinear_difficulties: tuple[float],
*,
num_spatial_dims: int,
num_points: int,
maximum_absolute: float,
):
normalized_nonlinear_scales = (
nonlinear_difficulties[0], # Polynomial: normalized == difficulty
extract_normalized_convection_scale_from_difficulty(
nonlinear_difficulties[1],
num_spatial_dims=num_spatial_dims,
num_points=num_points,
maximum_absolute=maximum_absolute,
),
extract_normalized_gradient_norm_scale_from_difficulty(
nonlinear_difficulties[2],
num_spatial_dims=num_spatial_dims,
num_points=num_points,
maximum_absolute=maximum_absolute,
),
)
return normalized_nonlinear_scales

0 comments on commit 9ab18cc

Please sign in to comment.