Skip to content

Commit

Permalink
Make clip/no-clip an argument to the controllers to reduce code dupli…
Browse files Browse the repository at this point in the history
…cation (#781)

* Make clip/no-clip an argument to the controllers to reduce code duplication

* Run pre-commit autoupdate to use the newest versions of linters
  • Loading branch information
pnkraemer authored Oct 14, 2024
1 parent 7ffbb8a commit a2e9cde
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 119 deletions.
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,25 @@ default_language_version:
python: python3
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
rev: v5.0.0
hooks:
- id: end-of-file-fixer
- id: check-merge-conflict
- repo: https://github.com/lyz-code/yamlfix/
rev: 1.16.0
rev: 1.17.0
hooks:
- id: yamlfix
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.5.7
rev: v0.6.9
hooks:
# Run the linter.
- id: ruff
args: [--fix]
# Run the formatter.
- id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.11.1
rev: v1.11.2
hooks:
- id: mypy
args: [--ignore-missing-imports]
2 changes: 1 addition & 1 deletion docs/benchmarks/hires/run_hires.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def param_to_solution(tol):
ts1 = ivpsolvers.correction_ts1()
strategy = ivpsolvers.strategy_filter(ibm, ts1)
solver = ivpsolvers.solver_dynamic(strategy)
control = ivpsolve.control_proportional_integral_clipped()
control = ivpsolve.control_proportional_integral(clip=True)
adaptive_solver = ivpsolve.adaptive(
solver, atol=1e-2 * tol, rtol=tol, control=control
)
Expand Down
2 changes: 1 addition & 1 deletion docs/benchmarks/vanderpol/run_vanderpol.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def param_to_solution(tol):
ts0_or_ts1 = ivpsolvers.correction_ts1(ode_order=2)
strategy = ivpsolvers.strategy_filter(ibm, ts0_or_ts1)
solver = ivpsolvers.solver_dynamic(strategy)
control = ivpsolve.control_proportional_integral_clipped()
control = ivpsolve.control_proportional_integral(clip=True)
adaptive_solver = ivpsolve.adaptive(
solver, atol=1e-3 * tol, rtol=tol, control=control
)
Expand Down
158 changes: 48 additions & 110 deletions probdiffeq/ivpsolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,144 +34,82 @@ class _Controller:

def control_proportional_integral(
*,
safety=0.95,
factor_min=0.2,
factor_max=10.0,
power_integral_unscaled=0.3,
power_proportional_unscaled=0.4,
) -> _Controller:
"""Construct a proportional-integral-controller."""
init = _proportional_integral_init
apply = functools.partial(
_proportional_integral_apply,
safety=safety,
factor_min=factor_min,
factor_max=factor_max,
power_integral_unscaled=power_integral_unscaled,
power_proportional_unscaled=power_proportional_unscaled,
)
extract = _proportional_integral_extract
return _Controller(init=init, apply=apply, extract=extract, clip=_no_clip)


def control_proportional_integral_clipped(
*,
clip: bool = False,
safety=0.95,
factor_min=0.2,
factor_max=10.0,
power_integral_unscaled=0.3,
power_proportional_unscaled=0.4,
) -> _Controller:
"""Construct a proportional-integral-controller with time-clipping."""
init = _proportional_integral_init
apply = functools.partial(
_proportional_integral_apply,
safety=safety,
factor_min=factor_min,
factor_max=factor_max,
power_integral_unscaled=power_integral_unscaled,
power_proportional_unscaled=power_proportional_unscaled,
)
extract = _proportional_integral_extract
clip = _proportional_integral_clip
return _Controller(init=init, apply=apply, extract=extract, clip=clip)

def init(dt, /):
return dt, 1.0

def _proportional_integral_apply(
state: tuple[float, float],
/,
error_normalised,
*,
error_contraction_rate,
safety,
factor_min,
factor_max,
power_integral_unscaled,
power_proportional_unscaled,
) -> tuple[float, float]:
dt_proposed, error_norm_previously_accepted = state
n1 = power_integral_unscaled / error_contraction_rate
n2 = power_proportional_unscaled / error_contraction_rate

a1 = (1.0 / error_normalised) ** n1
a2 = (error_norm_previously_accepted / error_normalised) ** n2
scale_factor_unclipped = safety * a1 * a2

scale_factor_clipped_min = np.minimum(scale_factor_unclipped, factor_max)
scale_factor = np.maximum(factor_min, scale_factor_clipped_min)
error_norm_previously_accepted = np.where(
error_normalised <= 1.0, error_normalised, error_norm_previously_accepted
)

dt_proposed = scale_factor * dt_proposed
return dt_proposed, error_norm_previously_accepted
def apply(
state: tuple[float, float], /, error_normalised, error_contraction_rate
) -> tuple[float, float]:
dt_proposed, error_norm_previously_accepted = state
n1 = power_integral_unscaled / error_contraction_rate
n2 = power_proportional_unscaled / error_contraction_rate

a1 = (1.0 / error_normalised) ** n1
a2 = (error_norm_previously_accepted / error_normalised) ** n2
scale_factor_unclipped = safety * a1 * a2

def _proportional_integral_init(dt0, /):
return dt0, 1.0
scale_factor_clipped_min = np.minimum(scale_factor_unclipped, factor_max)
scale_factor = np.maximum(factor_min, scale_factor_clipped_min)
error_norm_previously_accepted = np.where(
error_normalised <= 1.0, error_normalised, error_norm_previously_accepted
)

dt_proposed = scale_factor * dt_proposed
return dt_proposed, error_norm_previously_accepted

def _proportional_integral_clip(
state: tuple[float, float], /, t, t1
) -> tuple[float, float]:
dt_proposed, error_norm_previously_accepted = state
dt = dt_proposed
dt_clipped = np.minimum(dt, t1 - t)
return dt_clipped, error_norm_previously_accepted
def extract(state: tuple[float, float], /):
dt_proposed, _error_norm_previously_accepted = state
return dt_proposed

if clip:

def _proportional_integral_extract(state: tuple[float, float], /):
dt_proposed, _error_norm_previously_accepted = state
return dt_proposed
def clip_fun(state: tuple[float, float], /, t, t1) -> tuple[float, float]:
dt_proposed, error_norm_previously_accepted = state
dt = dt_proposed
dt_clipped = np.minimum(dt, t1 - t)
return dt_clipped, error_norm_previously_accepted

return _Controller(init=init, apply=apply, extract=extract, clip=clip_fun)

def control_integral(*, safety=0.95, factor_min=0.2, factor_max=10.0) -> _Controller:
"""Construct an integral-controller."""
init = _integral_init
apply = functools.partial(
_integral_apply, safety=safety, factor_min=factor_min, factor_max=factor_max
)
extract = _integral_extract
return _Controller(init=init, apply=apply, extract=extract, clip=_no_clip)
return _Controller(init=init, apply=apply, extract=extract, clip=lambda v, **_kw: v)


def control_integral_clipped(
*, safety=0.95, factor_min=0.2, factor_max=10.0
def control_integral(
*, clip=False, safety=0.95, factor_min=0.2, factor_max=10.0
) -> _Controller:
"""Construct an integral-controller with time-clipping."""
init = functools.partial(_integral_init)
apply = functools.partial(
_integral_apply, safety=safety, factor_min=factor_min, factor_max=factor_max
)
extract = functools.partial(_integral_extract)
return _Controller(init=init, apply=apply, extract=extract, clip=_integral_clip)


def _integral_init(dt0, /):
return dt0

"""Construct an integral-controller."""

def _integral_clip(dt, /, t, t1):
return np.minimum(dt, t1 - t)
def init(dt, /):
return dt

def apply(dt, /, error_normalised, error_contraction_rate):
error_power = error_normalised ** (-1.0 / error_contraction_rate)
scale_factor_unclipped = safety * error_power

def _no_clip(dt, /, *_args, **_kwargs):
return dt
scale_factor_clipped_min = np.minimum(scale_factor_unclipped, factor_max)
scale_factor = np.maximum(factor_min, scale_factor_clipped_min)
return scale_factor * dt

def extract(dt, /):
return dt

def _integral_apply(
dt, /, error_normalised, *, error_contraction_rate, safety, factor_min, factor_max
):
error_power = error_normalised ** (-1.0 / error_contraction_rate)
scale_factor_unclipped = safety * error_power
if clip:

scale_factor_clipped_min = np.minimum(scale_factor_unclipped, factor_max)
scale_factor = np.maximum(factor_min, scale_factor_clipped_min)
return scale_factor * dt
def clip_fun(dt, /, t, t1):
return np.minimum(dt, t1 - t)

return _Controller(init=init, apply=apply, extract=extract, clip=clip_fun)

def _integral_extract(dt, /):
return dt
return _Controller(init=init, apply=apply, extract=extract, clip=lambda v, **_kw: v)


def adaptive(solver, atol=1e-4, rtol=1e-2, control=None, norm_ord=None):
Expand Down
2 changes: 1 addition & 1 deletion probdiffeq/ivpsolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,7 +816,7 @@ def _calibration_none() -> _Calibration:
def init(prior):
return prior

def update(_state, /, observed): # noqa: ARG001
def update(_state, /, observed):
raise NotImplementedError

def extract(state, /):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_ivpsolve/test_fixed_grid_vs_save_every_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def test_fixed_grid_result_matches_adaptive_grid_result(ssm):
ts0 = ivpsolvers.correction_ts0()
strategy = ivpsolvers.strategy_filter(ibm, ts0)
solver = ivpsolvers.solver_mle(strategy)
control = ivpsolve.control_integral_clipped() # Any clipped controller will do.
control = ivpsolve.control_integral(clip=True) # Any clipped controller will do.
adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, control=control)

tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), u0, num=2)
Expand Down
3 changes: 2 additions & 1 deletion tests/test_ivpsolve/test_save_every_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ def fixture_python_loop_solution(ssm):
ts0 = ivpsolvers.correction_ts0()
strategy = ivpsolvers.strategy_filter(ibm, ts0)
solver = ivpsolvers.solver_mle(strategy)
adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2)
control = ivpsolve.control_proportional_integral(clip=True)
adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, control=control)

dt0 = ivpsolve.dt0_adaptive(
vf, u0, t0=t0, atol=1e-2, rtol=1e-2, error_contraction_rate=5
Expand Down

0 comments on commit a2e9cde

Please sign in to comment.