Skip to content

Commit

Permalink
Condense strategies into a single class to simplify the source-code
Browse files Browse the repository at this point in the history
  • Loading branch information
pnkraemer committed Oct 25, 2024
1 parent 7063aed commit 7891a46
Showing 1 changed file with 76 additions and 112 deletions.
188 changes: 76 additions & 112 deletions probdiffeq/ivpsolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,123 +545,56 @@ class _Strategy:
"""Estimation strategy."""

name: str
extrapolation: _ExtraImpl
correction: _Correction
ssm: Any

is_suitable_for_save_at: int
is_suitable_for_offgrid_marginals: int
is_suitable_for_save_every_step: int

prior: _MarkovProcess

@property
def num_derivatives(self):
return self.prior.num_derivatives

initial_condition: Callable
"""Construct an initial condition from a set of Taylor coefficients."""

init: Callable
"""Initialise a state from a posterior."""

begin: Callable
"""Predict the error of an upcoming step."""

complete: Callable
"""Complete the step after the error has been predicted."""

extract: Callable
"""Extract the solution from a state."""

case_interpolate_at_t1: Callable
"""Process the solution in case t=t_n."""

case_interpolate: Callable

offgrid_marginals: Callable
"""Compute offgrid_marginals."""


def strategy_smoother(prior, correction: _Correction, /, ssm) -> _Strategy:
"""Construct a smoother."""
extrapolation_impl = _ExtraImplSmoother(prior, ssm=ssm)
return _strategy(
extrapolation_impl,
correction,
ssm=ssm,
is_suitable_for_save_at=False,
is_suitable_for_save_every_step=True,
is_suitable_for_offgrid_marginals=True,
name=f"<Smoother with {extrapolation_impl}, {correction}>",
)


def strategy_fixedpoint(prior, correction: _Correction, /, ssm) -> _Strategy:
"""Construct a fixedpoint-smoother."""
extrapolation_impl = _ExtraImplFixedPoint(prior, ssm=ssm)
return _strategy(
extrapolation_impl,
correction,
ssm=ssm,
is_suitable_for_save_at=True,
is_suitable_for_save_every_step=False,
is_suitable_for_offgrid_marginals=False,
name=f"<Fixedpoint smoother with {extrapolation_impl}, {correction}>",
)


def strategy_filter(prior, correction: _Correction, /, *, ssm) -> _Strategy:
"""Construct a filter."""
extrapolation_impl = _ExtraImplFilter(prior, ssm=ssm)
return _strategy(
extrapolation_impl,
correction,
name=f"<Filter with {extrapolation_impl}, {correction}>",
is_suitable_for_save_at=True,
is_suitable_for_offgrid_marginals=True,
is_suitable_for_save_every_step=True,
ssm=ssm,
)


def _strategy(
extrapolation: _ExtraImpl,
correction: _Correction,
*,
name,
is_suitable_for_save_at,
is_suitable_for_save_every_step,
is_suitable_for_offgrid_marginals,
ssm,
):
def init(t, posterior, /) -> _StrategyState:
rv, extra = extrapolation.init(posterior)
rv, corr = correction.init(rv)
def init(self, t, posterior, /) -> _StrategyState:
"""Initialise a state from a posterior."""
rv, extra = self.extrapolation.init(posterior)
rv, corr = self.correction.init(rv)
return _StrategyState(t=t, hidden=rv, aux_extra=extra, aux_corr=corr)

def initial_condition():
return extrapolation.initial_condition()
def initial_condition(self):
"""Construct an initial condition from a set of Taylor coefficients."""
return self.extrapolation.initial_condition()

def begin(state: _StrategyState, /, *, dt, vector_field):
hidden, extra = extrapolation.begin(state.hidden, state.aux_extra, dt=dt)
def begin(self, state: _StrategyState, /, *, dt, vector_field):
"""Predict the error of an upcoming step."""
hidden, extra = self.extrapolation.begin(state.hidden, state.aux_extra, dt=dt)
t = state.t + dt
error, observed, corr = correction.estimate_error(
error, observed, corr = self.correction.estimate_error(
hidden, vector_field=vector_field, t=t
)
state = _StrategyState(t=t, hidden=hidden, aux_extra=extra, aux_corr=corr)
return error, observed, state

def complete(state, /, *, output_scale):
hidden, extra = extrapolation.complete(
def complete(self, state, /, *, output_scale):
"""Complete the step after the error has been predicted."""
hidden, extra = self.extrapolation.complete(
state.hidden, state.aux_extra, output_scale=output_scale
)
hidden, corr = correction.complete(hidden, state.aux_corr)
hidden, corr = self.correction.complete(hidden, state.aux_corr)
return _StrategyState(t=state.t, hidden=hidden, aux_extra=extra, aux_corr=corr)

def extract(state: _StrategyState, /):
hidden = correction.extract(state.hidden)
sol = extrapolation.extract(hidden, state.aux_extra)
def extract(self, state: _StrategyState, /):
"""Extract the solution from a state."""
hidden = self.correction.extract(state.hidden)
sol = self.extrapolation.extract(hidden, state.aux_extra)
return state.t, sol

def case_interpolate_at_t1(state_t1: _StrategyState) -> _InterpRes:
_tmp = extrapolation.interpolate_at_t1(state_t1.hidden, state_t1.aux_extra)
def case_interpolate_at_t1(self, state_t1: _StrategyState) -> _InterpRes:
"""Process the solution in case t=t_n."""
_tmp = self.extrapolation.interpolate_at_t1(state_t1.hidden, state_t1.aux_extra)
step_from, solution, interp_from = (
_tmp.step_from,
_tmp.interpolated,
Expand All @@ -679,11 +612,11 @@ def _state(x):
return _InterpRes(step_from, solution, interp_from)

def case_interpolate(
t, *, s0: _StrategyState, s1: _StrategyState, output_scale
self, t, *, s0: _StrategyState, s1: _StrategyState, output_scale
) -> _InterpRes:
"""Process the solution in case t>t_n."""
# Interpolate
interp = extrapolation.interpolate(
interp = self.extrapolation.interpolate(
state_t0=(s0.hidden, s0.aux_extra),
marginal_t1=s1.hidden,
dt0=t - s0.t,
Expand All @@ -704,15 +637,16 @@ def _state(t_, x):
step_from=step_from, interpolated=interpolated, interp_from=interp_from
)

def offgrid_marginals(*, t, marginals_t1, posterior_t0, t0, t1, output_scale):
if not is_suitable_for_offgrid_marginals:
def offgrid_marginals(self, *, t, marginals_t1, posterior_t0, t0, t1, output_scale):
"""Compute offgrid_marginals."""
if not self.is_suitable_for_offgrid_marginals:
raise NotImplementedError

dt0 = t - t0
dt1 = t1 - t
state_t0 = init(t0, posterior_t0)
state_t0 = self.init(t0, posterior_t0)

interp = extrapolation.interpolate(
interp = self.extrapolation.interpolate(
state_t0=(state_t0.hidden, state_t0.aux_extra),
marginal_t1=marginals_t1,
dt0=dt0,
Expand All @@ -721,22 +655,52 @@ def offgrid_marginals(*, t, marginals_t1, posterior_t0, t0, t1, output_scale):
)

(marginals, _aux) = interp.interpolated
u = ssm.stats.qoi(marginals)
u = self.ssm.stats.qoi(marginals)
return u, marginals


def strategy_smoother(prior, correction: _Correction, /, ssm) -> _Strategy:
"""Construct a smoother."""
extrapolation = _ExtraImplSmoother(prior, ssm=ssm)
return _Strategy(
name=name,
init=init,
initial_condition=initial_condition,
begin=begin,
complete=complete,
extract=extract,
case_interpolate_at_t1=case_interpolate_at_t1,
case_interpolate=case_interpolate,
offgrid_marginals=offgrid_marginals,
is_suitable_for_save_at=is_suitable_for_save_at,
is_suitable_for_save_every_step=is_suitable_for_save_every_step,
prior=extrapolation.prior,
extrapolation=extrapolation,
correction=correction,
prior=prior,
ssm=ssm,
is_suitable_for_save_at=False,
is_suitable_for_save_every_step=True,
is_suitable_for_offgrid_marginals=True,
name="Smoother",
)


def strategy_fixedpoint(prior, correction: _Correction, /, ssm) -> _Strategy:
"""Construct a fixedpoint-smoother."""
extrapolation = _ExtraImplFixedPoint(prior, ssm=ssm)
return _Strategy(
extrapolation=extrapolation,
correction=correction,
ssm=ssm,
prior=prior,
is_suitable_for_save_at=True,
is_suitable_for_save_every_step=False,
is_suitable_for_offgrid_marginals=False,
name="Fixed-point smoother",
)


def strategy_filter(prior, correction: _Correction, /, *, ssm) -> _Strategy:
"""Construct a filter."""
extrapolation = _ExtraImplFilter(prior, ssm=ssm)
return _Strategy(
name="Filter",
prior=prior,
extrapolation=extrapolation,
correction=correction,
is_suitable_for_save_at=True,
is_suitable_for_offgrid_marginals=True,
is_suitable_for_save_every_step=True,
ssm=ssm,
)


Expand Down

0 comments on commit 7891a46

Please sign in to comment.