Skip to content

Commit

Permalink
Internal: Reduce unnecessary complexity in correction code (#792)
Browse files Browse the repository at this point in the history
* Replace needlessly-functional code with a few classes to improve hackability

* Remove unused arguments from the correction API

* Rename CorrectionTaylor to remove linebreaks

* Use dataclasses instead of modules
  • Loading branch information
pnkraemer authored Oct 25, 2024
1 parent 8b8982c commit 7063aed
Showing 1 changed file with 69 additions and 101 deletions.
170 changes: 69 additions & 101 deletions probdiffeq/ivpsolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,132 +427,100 @@ class _Correction:

name: str
ode_order: int
ssm: Any
linearize: Callable

init: Callable
"""Initialise the state from the solution."""

estimate_error: Callable
"""Perform all elements of the correction until the error estimate."""

complete: Callable
"""Complete what has been left out by `estimate_error`."""

extract: Callable
"""Extract the solution from the state."""

def init(self, x, /):
"""Initialise the state from the solution."""
raise NotImplementedError

def correction_ts0(*, ssm, ode_order=1) -> _Correction:
"""Zeroth-order Taylor linearisation."""
return _correction_constraint_ode_taylor(
ssm=ssm,
ode_order=ode_order,
linearise_fun=ssm.linearise.ode_taylor_0th(ode_order=ode_order),
name=f"<TS0 with ode_order={ode_order}>",
)
def estimate_error(self, x, /, vector_field, t):
"""Perform all elements of the correction until the error estimate."""
raise NotImplementedError

def complete(self, x, cache, /):
"""Complete what has been left out by `estimate_error`."""
raise NotImplementedError

def correction_ts1(*, ssm, ode_order=1) -> _Correction:
"""First-order Taylor linearisation."""
return _correction_constraint_ode_taylor(
ssm=ssm,
ode_order=ode_order,
linearise_fun=ssm.linearise.ode_taylor_1st(ode_order=ode_order),
name=f"<TS1 with ode_order={ode_order}>",
)
def extract(self, x, /):
"""Extract the solution from the state."""
raise NotImplementedError


def _correction_constraint_ode_taylor(
ode_order, linearise_fun, name, *, ssm
) -> _Correction:
def init(ssv, /):
obs_like = ssm.prototypes.observed()
return ssv, obs_like
@containers.dataclass
class _CorrectionTS(_Correction):
def init(self, x, /):
y = self.ssm.prototypes.observed()
return x, y

def estimate_error(hidden_state, _corr, /, vector_field, t):
def estimate_error(self, x, /, vector_field, t):
def f_wrapped(s):
return vector_field(*s, t=t)

A, b = linearise_fun(f_wrapped, hidden_state.mean)
observed = ssm.transform.marginalise(hidden_state, (A, b))
A, b = self.linearize(f_wrapped, x.mean)
observed = self.ssm.transform.marginalise(x, (A, b))

error_estimate = _estimate_error(observed, ssm=ssm)
return error_estimate, observed, (A, b)
error_estimate = _estimate_error(observed, ssm=self.ssm)
return error_estimate, observed, {"linearization": (A, b)}

def complete(hidden_state, corr, /):
A, b = corr
observed, (_gain, corrected) = ssm.transform.revert(hidden_state, (A, b))
def complete(self, x, cache, /):
A, b = cache["linearization"]
observed, (_gain, corrected) = self.ssm.transform.revert(x, (A, b))
return corrected, observed

def extract(ssv, _corr, /):
return ssv
def extract(self, x, /):
return x

return _Correction(
ode_order=ode_order,
name=name,
init=init,
estimate_error=estimate_error,
complete=complete,
extract=extract,
)

@containers.dataclass
class _CorrectionSLR(_Correction):
def init(self, x, /):
y = self.ssm.prototypes.observed()
return x, y

def correction_slr0(*, ssm, cubature_fun=cubature_third_order_spherical) -> _Correction:
"""Zeroth-order statistical linear regression."""
linearise_fun = ssm.linearise.ode_statistical_1st(cubature_fun)
return _correction_constraint_ode_statistical(
ssm=ssm,
ode_order=1,
linearise_fun=linearise_fun,
name=f"<SLR1 with ode_order={1}>",
)


def correction_slr1(*, ssm, cubature_fun=cubature_third_order_spherical) -> _Correction:
"""First-order statistical linear regression."""
linearise_fun = ssm.linearise.ode_statistical_0th(cubature_fun)
return _correction_constraint_ode_statistical(
ssm=ssm,
ode_order=1,
linearise_fun=linearise_fun,
name=f"<SLR0 with ode_order={1}>",
)


def _correction_constraint_ode_statistical(
ode_order, linearise_fun, name, *, ssm
) -> _Correction:
def init(ssv, /):
obs_like = ssm.prototypes.observed()
return ssv, obs_like

def estimate_error(hidden_state, _corr, /, vector_field, t):
def estimate_error(self, x, /, vector_field, t):
f_wrapped = functools.partial(vector_field, t=t)
A, b = linearise_fun(f_wrapped, hidden_state)
observed = ssm.conditional.marginalise(hidden_state, (A, b))
A, b = self.linearize(f_wrapped, x)
observed = self.ssm.conditional.marginalise(x, (A, b))

error_estimate = _estimate_error(observed, ssm=ssm)
error_estimate = _estimate_error(observed, ssm=self.ssm)
return error_estimate, observed, (A, b, f_wrapped)

def complete(hidden_state, corr, /):
def complete(self, x, cache, /):
# Re-linearise (because the linearisation point changed)
*_, f_wrapped = corr
A, b = linearise_fun(f_wrapped, hidden_state)
*_, f_wrapped = cache
A, b = self.linearize(f_wrapped, x)

# Condition
observed, (_gain, corrected) = ssm.conditional.revert(hidden_state, (A, b))
observed, (_gain, corrected) = self.ssm.conditional.revert(x, (A, b))
return corrected, observed

def extract(hidden_state, _corr, /):
return hidden_state
def extract(self, x, /):
return x

return _Correction(
ode_order=ode_order,
name=name,
init=init,
estimate_error=estimate_error,
complete=complete,
extract=extract,
)

def correction_ts0(*, ssm, ode_order=1) -> _Correction:
"""Zeroth-order Taylor linearisation."""
linearize = ssm.linearise.ode_taylor_0th(ode_order=ode_order)
return _CorrectionTS(name="TS0", ode_order=ode_order, ssm=ssm, linearize=linearize)


def correction_ts1(*, ssm, ode_order=1) -> _Correction:
"""First-order Taylor linearisation."""
linearize = ssm.linearise.ode_taylor_1st(ode_order=ode_order)
return _CorrectionTS(name="TS1", ode_order=ode_order, ssm=ssm, linearize=linearize)


def correction_slr0(*, ssm, cubature_fun=cubature_third_order_spherical) -> _Correction:
"""Zeroth-order statistical linear regression."""
linearize = ssm.linearise.ode_statistical_1st(cubature_fun)
return _CorrectionSLR(ssm=ssm, ode_order=1, linearize=linearize, name="SLR0")


def correction_slr1(*, ssm, cubature_fun=cubature_third_order_spherical) -> _Correction:
"""First-order statistical linear regression."""
linearize = ssm.linearise.ode_statistical_0th(cubature_fun)
return _CorrectionSLR(ssm=ssm, ode_order=1, linearize=linearize, name="SLR1")


def _estimate_error(observed, /, *, ssm):
Expand Down Expand Up @@ -675,7 +643,7 @@ def begin(state: _StrategyState, /, *, dt, vector_field):
hidden, extra = extrapolation.begin(state.hidden, state.aux_extra, dt=dt)
t = state.t + dt
error, observed, corr = correction.estimate_error(
hidden, state.aux_corr, vector_field=vector_field, t=t
hidden, vector_field=vector_field, t=t
)
state = _StrategyState(t=t, hidden=hidden, aux_extra=extra, aux_corr=corr)
return error, observed, state
Expand All @@ -688,7 +656,7 @@ def complete(state, /, *, output_scale):
return _StrategyState(t=state.t, hidden=hidden, aux_extra=extra, aux_corr=corr)

def extract(state: _StrategyState, /):
hidden = correction.extract(state.hidden, state.aux_corr)
hidden = correction.extract(state.hidden)
sol = extrapolation.extract(hidden, state.aux_extra)
return state.t, sol

Expand Down

0 comments on commit 7063aed

Please sign in to comment.