Skip to content

Commit

Permalink
Internal: Move error-scaling (by the error contraction rate) outside …
Browse files Browse the repository at this point in the history
…of the controller implementations (#795)

* Write a test for the controllers to ensure I don't break them

* Apply the controller repeatedly to tighten the test

* Clean up the source of proportional-integral control

* Make the control-inputs keywords to ensure that changing the signature raises erorrs

* Assign scaled error norm to a variable to prepare changing the signature

* Isolate error**(1/rate)

* Move error-power computation out of the controllers to reduce no. function arguments

* Merge error-normalisation with existing error normalisation

* Rename the error-scale function

* Add mypy extensions to the benchmark dependencies
  • Loading branch information
pnkraemer authored Oct 25, 2024
1 parent d17c76d commit a64ff3b
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 27 deletions.
5 changes: 4 additions & 1 deletion probdiffeq/backend/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
from typing import Any, Callable, Generic, Optional, TypeVar # noqa: F401

import jax
from typing_extensions import TypeAlias # typing.TypeAlias requires 3.10+
from mypy_extensions import NamedArg # noqa: F401

# typing.TypeAlias requires 3.10+
from typing_extensions import TypeAlias

# Array
Array: TypeAlias = jax.Array
50 changes: 24 additions & 26 deletions probdiffeq/ivpsolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
warnings,
)
from probdiffeq.backend import numpy as np
from probdiffeq.backend.typing import Any, Callable
from probdiffeq.backend.typing import Any, Callable, NamedArg


@containers.dataclass
Expand All @@ -24,7 +24,7 @@ class _Controller:
clip: Callable[[Any, float, float], Any]
"""(Optionally) clip the current step to not exceed t1."""

apply: Callable[[Any, float, float], Any]
apply: Callable[[Any, NamedArg(float, "error_power")], Any]
r"""Propose a time-step $\Delta t$."""

extract: Callable[[Any], float]
Expand All @@ -44,26 +44,27 @@ def control_proportional_integral(

class PIState(containers.NamedTuple):
dt: float
error_norm_previously_accepted: float
error_power_previously_accepted: float

def init(dt: float, /) -> PIState:
return PIState(dt, 1.0)

def apply(state: PIState, /, error_norm, error_contraction_rate) -> PIState:
dt_proposed, error_norm_prev = state
n1 = power_integral_unscaled / error_contraction_rate
n2 = power_proportional_unscaled / error_contraction_rate
def apply(state: PIState, /, *, error_power) -> PIState:
# error_power = error_norm ** (-1.0 / error_contraction_rate)
dt_proposed, error_power_prev = state

a1 = (1.0 / error_norm) ** n1
a2 = (error_norm_prev / error_norm) ** n2
a1 = error_power**power_integral_unscaled
a2 = (error_power / error_power_prev) ** power_proportional_unscaled
scale_factor_unclipped = safety * a1 * a2

scale_factor_clipped_min = np.minimum(scale_factor_unclipped, factor_max)
scale_factor = np.maximum(factor_min, scale_factor_clipped_min)
error_norm_prev = np.where(error_norm <= 1.0, error_norm, error_norm_prev)

# >= 1.0 because error_power is 1/scaled_error_norm
error_power_prev = np.where(error_power >= 1.0, error_power, error_power_prev)

dt_proposed = scale_factor * dt_proposed
return PIState(dt_proposed, error_norm_prev)
return PIState(dt_proposed, error_power_prev)

def extract(state: PIState, /) -> float:
dt_proposed, _error_norm_previously_accepted = state
Expand All @@ -90,8 +91,8 @@ def control_integral(
def init(dt, /):
return dt

def apply(dt, /, error_norm, error_contraction_rate):
error_power = error_norm ** (-1.0 / error_contraction_rate)
def apply(dt, /, *, error_power):
# error_power = error_norm ** (-1.0 / error_contraction_rate)
scale_factor_unclipped = safety * error_power

scale_factor_clipped_min = np.minimum(scale_factor_unclipped, factor_max)
Expand Down Expand Up @@ -175,16 +176,17 @@ def init(s0: _AdaState) -> _RejectionState:
def _inf_like(tree):
return tree_util.tree_map(lambda x: np.inf() * np.ones_like(x), tree)

larger_than_1 = 1.1
smaller_than_1 = 1.0 / 1.1 # the cond() must return True
return _RejectionState(
error_norm_proposed=larger_than_1,
error_norm_proposed=smaller_than_1,
control=s0.control,
proposed=_inf_like(s0.step_from),
step_from=s0.step_from,
)

def cond_fn(state: _RejectionState) -> bool:
return state.error_norm_proposed > 1.0
# error_norm_proposed is EEst ** (-1/rate), thus "<"
return state.error_norm_proposed < 1.0

def body_fn(state: _RejectionState) -> _RejectionState:
"""Attempt a step.
Expand All @@ -197,7 +199,6 @@ def body_fn(state: _RejectionState) -> _RejectionState:
state_control = self.control.clip(state.control, t=state.step_from.t, t1=t1)

# Perform the actual step.
# todo: error estimate should be a tuple (abs, rel)
error_estimate, state_proposed = self.solver.step(
state=state.step_from,
vector_field=vector_field,
Expand All @@ -207,26 +208,23 @@ def body_fn(state: _RejectionState) -> _RejectionState:
u_proposed = self.ssm.stats.qoi(state_proposed.strategy.hidden)[0]
u_step_from = self.ssm.stats.qoi(state_proposed.strategy.hidden)[0]
u = np.maximum(np.abs(u_proposed), np.abs(u_step_from))
error_norm = _normalise_error(error_estimate, u=u)
error_power = _error_scale_and_normalize(error_estimate, u=u)

# Propose a new step
state_control = self.control.apply(
state_control,
error_norm=error_norm,
error_contraction_rate=self.solver.error_contraction_rate,
)
state_control = self.control.apply(state_control, error_power=error_power)
return _RejectionState(
error_norm_proposed=error_norm, # new
error_norm_proposed=error_power, # new
proposed=state_proposed, # new
control=state_control, # new
step_from=state.step_from,
)

def _normalise_error(error_estimate, *, u):
def _error_scale_and_normalize(error_estimate, *, u):
error_relative = error_estimate / (self.atol + self.rtol * np.abs(u))
dim = np.atleast_1d(u).size
error_norm = linalg.vector_norm(error_relative, order=self.norm_ord)
return error_norm / np.sqrt(dim)
error_norm_rel = error_norm / np.sqrt(dim)
return error_norm_rel ** (-1.0 / self.solver.error_contraction_rate)

def extract(s: _RejectionState) -> _AdaState:
num_steps = state0.stats + 1
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ test =[
"diffeqzoo",
"diffrax",
"equinox",
"mypy_extensions",
]
format-and-lint =[
"pre-commit",
Expand All @@ -57,6 +58,7 @@ doc = [
"mkdocstrings-python",
"mkdocstrings",
"mkdocs-jupyter",
"mypy_extensions",
]


Expand Down
22 changes: 22 additions & 0 deletions tests/test_ivpsolve/test_controllers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Test the controllers."""

from probdiffeq import ivpsolve
from probdiffeq.backend import numpy as np


def test_equivalence_pi_vs_i(dt=0.1428, error_power=3.142, num_applies=4):
ctrl_pi = ivpsolve.control_proportional_integral(
power_integral_unscaled=1.0, power_proportional_unscaled=0.0
)
ctrl_i = ivpsolve.control_integral()

x_pi = ctrl_pi.init(dt)
for _ in range(num_applies):
x_pi = ctrl_pi.apply(x_pi, error_power=error_power)
x_pi = ctrl_pi.extract(x_pi)

x_i = ctrl_i.init(dt)
for _ in range(num_applies):
x_i = ctrl_i.apply(x_i, error_power=error_power)
x_i = ctrl_i.extract(x_i)
assert np.allclose(x_i, x_pi)

0 comments on commit a64ff3b

Please sign in to comment.